diff options
Diffstat (limited to 'silx/gui/plot')
47 files changed, 4783 insertions, 2619 deletions
diff --git a/silx/gui/plot/ColorBar.py b/silx/gui/plot/ColorBar.py index fd4d34e..9798123 100644 --- a/silx/gui/plot/ColorBar.py +++ b/silx/gui/plot/ColorBar.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -251,10 +251,13 @@ class ColorBarWidget(qt.QWidget): def _defaultColormapChanged(self, event): """Handle plot default colormap changed""" - if (event['event'] == 'defaultColormapChanged' and - self.getPlot().getActiveImage() is None): - # No active image, take default colormap update into account - self._syncWithDefaultColormap() + if event['event'] == 'defaultColormapChanged': + plot = self.getPlot() + if (plot is not None and + plot.getActiveImage() is None and + plot._getActiveItem(kind='scatter') is None): + # No active item, take default colormap update into account + self._syncWithDefaultColormap() def _syncWithDefaultColormap(self, data=None): """Update colorbar according to plot default colormap""" @@ -801,7 +804,7 @@ class _TickBar(qt.QWidget): if self._norm == colors.Colormap.LINEAR: return 1 - (val - self._vmin) / (self._vmax - self._vmin) elif self._norm == colors.Colormap.LOGARITHM: - return 1 - (numpy.log10(val) - numpy.log10(self._vmin)) / (numpy.log10(self._vmax) - numpy.log(self._vmin)) + return 1 - (numpy.log10(val) - numpy.log10(self._vmin)) / (numpy.log10(self._vmax) - numpy.log10(self._vmin)) else: raise ValueError('Norm is not recognized') @@ -864,7 +867,7 @@ class _TickBar(qt.QWidget): def _guessType(self, font): """Try fo find the better format to display the tick's labels - :param QFont font: the font we want want to use durint the painting + :param QFont font: the font we want to use during the painting """ form = self._getStandardFormat() @@ -873,7 +876,7 @@ class _TickBar(qt.QWidget): for tick in self.ticks: width = max(fm.width(form.format(tick)), width) - # if the length of the string are too long we are mooving to scientific + # if the length of the string are too long we are moving to scientific # display if width > _TickBar._WIDTH_DISP_VAL - _TickBar._LINE_WIDTH: return self._getScientificForm() diff --git a/silx/gui/plot/CompareImages.py b/silx/gui/plot/CompareImages.py index 88b257d..f7c4899 100644 --- a/silx/gui/plot/CompareImages.py +++ b/silx/gui/plot/CompareImages.py @@ -30,6 +30,7 @@ __license__ = "MIT" __date__ = "23/07/2018" +import enum import logging import numpy import weakref @@ -42,7 +43,6 @@ from silx.gui import plot from silx.gui import icons from silx.gui.colors import Colormap from silx.gui.plot import tools -from silx.third_party import enum _logger = logging.getLogger(__name__) diff --git a/silx/gui/plot/ComplexImageView.py b/silx/gui/plot/ComplexImageView.py index bbcb0a5..2523cde 100644 --- a/silx/gui/plot/ComplexImageView.py +++ b/silx/gui/plot/ComplexImageView.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -365,7 +365,7 @@ class ComplexImageView(qt.QWidget): - log10_amplitude_phase: Color-coded phase with log10(amplitude) as alpha. - :rtype: tuple of str + :rtype: List[Mode] """ return tuple(ImageComplexData.Mode) @@ -375,7 +375,12 @@ class ComplexImageView(qt.QWidget): See :meth:`getSupportedVisualizationModes` for the list of supported modes. - :param str mode: The mode to use. + How-to change visualization mode:: + + widget = ComplexImageView() + widget.setVisualizationMode(ComplexImageView.Mode.PHASE) + + :param Mode mode: The mode to use. """ self._plotImage.setVisualizationMode(mode) diff --git a/silx/gui/plot/CurvesROIWidget.py b/silx/gui/plot/CurvesROIWidget.py index 81e684e..b426a23 100644 --- a/silx/gui/plot/CurvesROIWidget.py +++ b/silx/gui/plot/CurvesROIWidget.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2018 European Synchrotron Radiation Facility +# Copyright (c) 2004-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -22,50 +22,43 @@ # THE SOFTWARE. # # ###########################################################################*/ -"""Widget to handle regions of interest (ROI) on curves displayed in a PlotWindow. +""" +Widget to handle regions of interest (:class:`ROI`) on curves displayed in a +:class:`PlotWindow`. This widget is meant to work with :class:`PlotWindow`. - -ROI are defined by : - -- A name (`ROI` column) -- A type. The type is the label of the x axis. - This can be used to apply or not some ROI to a curve and do some post processing. -- The x coordinate of the left limit (`from` column) -- The x coordinate of the right limit (`to` column) -- Raw counts: Sum of the curve's values in the defined Region Of Intereset. - - .. image:: img/rawCounts.png - -- Net counts: Raw counts minus background - - .. image:: img/netCounts.png """ -__authors__ = ["V.A. Sole", "T. Vincent"] +__authors__ = ["V.A. Sole", "T. Vincent", "H. Payno"] __license__ = "MIT" -__date__ = "13/11/2017" +__date__ = "13/03/2018" from collections import OrderedDict - import logging import os import sys -import weakref - +import functools import numpy - from silx.io import dictdump from silx.utils import deprecation - +from silx.utils.weakref import WeakMethodProxy from .. import icons, qt +from silx.gui.plot.items.curve import Curve +from silx.math.combo import min_max +import weakref +from silx.gui.widgets.TableWidget import TableWidget _logger = logging.getLogger(__name__) class CurvesROIWidget(qt.QWidget): - """Widget displaying a table of ROI information. + """ + Widget displaying a table of ROI information. + + Implements also the following behavior: + + * if the roiTable has no ROI when showing create the default ICR one :param parent: See :class:`QWidget` :param str name: The title of this widget @@ -73,19 +66,18 @@ class CurvesROIWidget(qt.QWidget): sigROIWidgetSignal = qt.Signal(object) """Signal of ROIs modifications. - - Modification information if given as a dict with an 'event' key - providing the type of events. - - Type of events: - - - AddROI, DelROI, LoadROI and ResetROI with keys: 'roilist', 'roidict' - - - selectionChanged with keys: 'row', 'col' 'roi', 'key', 'colheader', - 'rowheader' + Modification information if given as a dict with an 'event' key + providing the type of events. + Type of events: + - AddROI, DelROI, LoadROI and ResetROI with keys: 'roilist', 'roidict' + - selectionChanged with keys: 'row', 'col' 'roi', 'key', 'colheader', + 'rowheader' """ sigROISignal = qt.Signal(object) + """Deprecated signal for backward compatibility with silx < 0.7. + Prefer connecting directly to :attr:`CurvesRoiWidget.sigRoiSignal` + """ def __init__(self, parent=None, name=None, plot=None): super(CurvesROIWidget, self).__init__(parent) @@ -93,6 +85,8 @@ class CurvesROIWidget(qt.QWidget): self.setWindowTitle(name) assert plot is not None self._plotRef = weakref.ref(plot) + self._showAllMarkers = False + self.currentROI = None layout = qt.QVBoxLayout(self) layout.setContentsMargins(0, 0, 0, 0) @@ -103,13 +97,22 @@ class CurvesROIWidget(qt.QWidget): self.setHeader() layout.addWidget(self.headerLabel) ############## - self.roiTable = ROITable(self) + widgetAllCheckbox = qt.QWidget(parent=self) + self._showAllCheckBox = qt.QCheckBox("show all ROI", + parent=widgetAllCheckbox) + widgetAllCheckbox.setLayout(qt.QHBoxLayout()) + spacer = qt.QWidget(parent=widgetAllCheckbox) + spacer.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed) + widgetAllCheckbox.layout().addWidget(spacer) + widgetAllCheckbox.layout().addWidget(self._showAllCheckBox) + layout.addWidget(widgetAllCheckbox) + ############## + self.roiTable = ROITable(self, plot=plot) rheight = self.roiTable.horizontalHeader().sizeHint().height() self.roiTable.setMinimumHeight(4 * rheight) - self.fillFromROIDict = self.roiTable.fillFromROIDict - self.getROIListAndDict = self.roiTable.getROIListAndDict layout.addWidget(self.roiTable) self._roiFileDir = qt.QDir.home().absolutePath() + self._showAllCheckBox.toggled.connect(self.roiTable.showAllMarkers) ################# hbox = qt.QWidget(self) @@ -127,7 +130,8 @@ class CurvesROIWidget(qt.QWidget): self.addButton.setToolTip('Remove the selected ROI') self.resetButton = qt.QPushButton(hbox) self.resetButton.setText("Reset") - self.addButton.setToolTip('Clear all created ROIs. We only let the default ROI') + self.addButton.setToolTip('Clear all created ROIs. We only let the ' + 'default ROI') hboxlayout.addWidget(self.addButton) hboxlayout.addWidget(self.delButton) @@ -149,19 +153,22 @@ class CurvesROIWidget(qt.QWidget): layout.addWidget(hbox) + # Signal / Slot connections self.addButton.clicked.connect(self._add) self.delButton.clicked.connect(self._del) self.resetButton.clicked.connect(self._reset) self.loadButton.clicked.connect(self._load) self.saveButton.clicked.connect(self._save) - self.roiTable.sigROITableSignal.connect(self._forward) - self.currentROI = None - self._middleROIMarkerFlag = False + self.roiTable.activeROIChanged.connect(self._emitCurrentROISignal) + self._isConnected = False # True if connected to plot signals self._isInit = False + # expose API + self.getROIListAndDict = self.roiTable.getROIListAndDict + def getPlotWidget(self): """Returns the associated PlotWidget or None @@ -173,10 +180,6 @@ class CurvesROIWidget(qt.QWidget): self._visibilityChangedHandler(visible=True) qt.QWidget.showEvent(self, event) - def hideEvent(self, event): - self._visibilityChangedHandler(visible=False) - qt.QWidget.hideEvent(self, event) - @property def roiFileDir(self): """The directory from which to load/save ROI from/to files.""" @@ -188,135 +191,81 @@ class CurvesROIWidget(qt.QWidget): def roiFileDir(self, roiFileDir): self._roiFileDir = str(roiFileDir) - def setRois(self, roidict, order=None): - """Set the ROIs by providing a dictionary of ROI information. - - The dictionary keys are the ROI names. - Each value is a sub-dictionary of ROI info with the following fields: - - - ``"from"``: x coordinate of the left limit, as a float - - ``"to"``: x coordinate of the right limit, as a float - - ``"type"``: type of ROI, as a string (e.g "channels", "energy") - - - :param roidict: Dictionary of ROIs - :param str order: Field used for ordering the ROIs. - One of "from", "to", "type". - None (default) for no ordering, or same order as specified - in parameter ``roidict`` if provided as an OrderedDict. - """ - if order is None or order.lower() == "none": - roilist = list(roidict.keys()) - else: - assert order in ["from", "to", "type"] - roilist = sorted(roidict.keys(), - key=lambda roi_name: roidict[roi_name].get(order)) - - return self.roiTable.fillFromROIDict(roilist, roidict) + def setRois(self, rois, order=None): + return self.roiTable.setRois(rois, order) def getRois(self, order=None): - """Return the currently defined ROIs, as an ordered dict. + return self.roiTable.getRois(order) - The dictionary keys are the ROI names. - Each value is a sub-dictionary of ROI info with the following fields: + def setMiddleROIMarkerFlag(self, flag=True): + return self.roiTable.setMiddleROIMarkerFlag(flag) - - ``"from"``: x coordinate of the left limit, as a float - - ``"to"``: x coordinate of the right limit, as a float - - ``"type"``: type of ROI, as a string (e.g "channels", "energy") + def _add(self): + """Add button clicked handler""" + def getNextRoiName(): + rois = self.roiTable.getRois(order=None) + roisNames = [] + [roisNames.append(roiName) for roiName in rois] + nrois = len(rois) + if nrois == 0: + return "ICR" + else: + i = 1 + newroi = "newroi %d" % i + while newroi in roisNames: + i += 1 + newroi = "newroi %d" % i + return newroi + roi = ROI(name=getNextRoiName()) - :param order: Field used for ordering the ROIs. - One of "from", "to", "type", "netcounts", "rawcounts". - None (default) to get the same order as displayed in the widget. - :return: Ordered dictionary of ROI information - """ - roilist, roidict = self.roiTable.getROIListAndDict() - if order is None or order.lower() == "none": - ordered_roilist = roilist + if roi.getName() == "ICR": + roi.setType("Default") else: - assert order in ["from", "to", "type", "netcounts", "rawcounts"] - ordered_roilist = sorted(roidict.keys(), - key=lambda roi_name: roidict[roi_name].get(order)) - - return OrderedDict([(name, roidict[name]) for name in ordered_roilist]) + roi.setType(self.getPlotWidget().getXAxis().getLabel()) - def setMiddleROIMarkerFlag(self, flag=True): - """Activate or deactivate middle marker. + xmin, xmax = self.getPlotWidget().getXAxis().getLimits() + fromdata = xmin + 0.25 * (xmax - xmin) + todata = xmin + 0.75 * (xmax - xmin) + if roi.isICR(): + fromdata, dummy0, todata, dummy1 = self._getAllLimits() + roi.setFrom(fromdata) + roi.setTo(todata) - This allows shifting both min and max limits at once, by dragging - a marker located in the middle. - - :param bool flag: True to activate middle ROI marker - """ - if flag: - self._middleROIMarkerFlag = True - else: - self._middleROIMarkerFlag = False + self.roiTable.addRoi(roi) - def _add(self): - """Add button clicked handler""" + # back compatibility pymca roi signals ddict = {} ddict['event'] = "AddROI" - roilist, roidict = self.roiTable.getROIListAndDict() - ddict['roilist'] = roilist - ddict['roidict'] = roidict + ddict['roilist'] = self.roiTable.roidict.values() + ddict['roidict'] = self.roiTable.roidict self.sigROIWidgetSignal.emit(ddict) + # end back compatibility pymca roi signals def _del(self): """Delete button clicked handler""" - row = self.roiTable.currentRow() - if row >= 0: - index = self.roiTable.labels.index('Type') - text = str(self.roiTable.item(row, index).text()) - if text.upper() != 'DEFAULT': - index = self.roiTable.labels.index('ROI') - key = str(self.roiTable.item(row, index).text()) - else: - # This is to prevent deleting ICR ROI, that is - # usually initialized as "Default" type. - return - roilist, roidict = self.roiTable.getROIListAndDict() - row = roilist.index(key) - del roilist[row] - del roidict[key] - if len(roilist) > 0: - currentroi = roilist[0] - else: - currentroi = None - - self.roiTable.fillFromROIDict(roilist=roilist, - roidict=roidict, - currentroi=currentroi) - ddict = {} - ddict['event'] = "DelROI" - ddict['roilist'] = roilist - ddict['roidict'] = roidict - self.sigROIWidgetSignal.emit(ddict) - - def _forward(self, ddict): - """Broadcast events from ROITable signal""" + self.roiTable.deleteActiveRoi() + + # back compatibility pymca roi signals + ddict = {} + ddict['event'] = "DelROI" + ddict['roilist'] = self.roiTable.roidict.values() + ddict['roidict'] = self.roiTable.roidict self.sigROIWidgetSignal.emit(ddict) + # end back compatibility pymca roi signals def _reset(self): """Reset button clicked handler""" + self.roiTable.clear() + self._add() + + # back compatibility pymca roi signals ddict = {} ddict['event'] = "ResetROI" - roilist0, roidict0 = self.roiTable.getROIListAndDict() - index = 0 - for key in roilist0: - if roidict0[key]['type'].upper() == 'DEFAULT': - index = roilist0.index(key) - break - roilist = [] - roidict = {} - if len(roilist0): - roilist.append(roilist0[index]) - roidict[roilist[0]] = {} - roidict[roilist[0]].update(roidict0[roilist[0]]) - self.roiTable.fillFromROIDict(roilist=roilist, roidict=roidict) - ddict['roilist'] = roilist - ddict['roidict'] = roidict + ddict['roilist'] = self.roiTable.roidict.values() + ddict['roidict'] = self.roiTable.roidict self.sigROIWidgetSignal.emit(ddict) + # end back compatibility pymca roi signals def _load(self): """Load button clicked handler""" @@ -334,32 +283,22 @@ class CurvesROIWidget(qt.QWidget): dialog.close() self.roiFileDir = os.path.dirname(outputFile) - self.load(outputFile) + self.roiTable.load(outputFile) + + # back compatibility pymca roi signals + ddict = {} + ddict['event'] = "LoadROI" + ddict['roilist'] = self.roiTable.roidict.values() + ddict['roidict'] = self.roiTable.roidict + self.sigROIWidgetSignal.emit(ddict) + # end back compatibility pymca roi signals def load(self, filename): """Load ROI widget information from a file storing a dict of ROI. :param str filename: The file from which to load ROI """ - rois = dictdump.load(filename) - currentROI = None - if self.roiTable.rowCount(): - item = self.roiTable.item(self.roiTable.currentRow(), 0) - if item is not None: - currentROI = str(item.text()) - - # Remove rawcounts and netcounts from ROIs - for roi in rois['ROI']['roidict'].values(): - roi.pop('rawcounts', None) - roi.pop('netcounts', None) - - self.roiTable.fillFromROIDict(roilist=rois['ROI']['roilist'], - roidict=rois['ROI']['roidict'], - currentroi=currentROI) - - roilist, roidict = self.roiTable.getROIListAndDict() - event = {'event': 'LoadROI', 'roilist': roilist, 'roidict': roidict} - self.sigROIWidgetSignal.emit(event) + self.roiTable.load(filename) def _save(self): """Save button clicked handler""" @@ -396,142 +335,24 @@ class CurvesROIWidget(qt.QWidget): :param str filename: The file to which to save the ROIs """ - roilist, roidict = self.roiTable.getROIListAndDict() - datadict = {'ROI': {'roilist': roilist, 'roidict': roidict}} - dictdump.dump(datadict, filename) + self.roiTable.save(filename) def setHeader(self, text='ROIs'): """Set the header text of this widget""" self.headerLabel.setText("<b>%s<\b>" % text) - def _roiSignal(self, ddict): - """Handle ROI widget signal""" - _logger.debug("CurvesROIWidget._roiSignal %s", str(ddict)) - plot = self.getPlotWidget() - if plot is None: - return - - if ddict['event'] == "AddROI": - xmin, xmax = plot.getXAxis().getLimits() - fromdata = xmin + 0.25 * (xmax - xmin) - todata = xmin + 0.75 * (xmax - xmin) - plot.remove('ROI min', kind='marker') - plot.remove('ROI max', kind='marker') - if self._middleROIMarkerFlag: - plot.remove('ROI middle', kind='marker') - roiList, roiDict = self.roiTable.getROIListAndDict() - nrois = len(roiList) - if nrois == 0: - newroi = "ICR" - fromdata, dummy0, todata, dummy1 = self._getAllLimits() - draggable = False - color = 'black' - else: - # find the next index free for newroi. - for i in range(nrois): - i += 1 - newroi = "newroi %d" % i - if newroi not in roiList: - break - color = 'blue' - draggable = True - plot.addXMarker(fromdata, - legend='ROI min', - text='ROI min', - color=color, - draggable=draggable) - plot.addXMarker(todata, - legend='ROI max', - text='ROI max', - color=color, - draggable=draggable) - if draggable and self._middleROIMarkerFlag: - pos = 0.5 * (fromdata + todata) - plot.addXMarker(pos, - legend='ROI middle', - text="", - color='yellow', - draggable=draggable) - roiList.append(newroi) - roiDict[newroi] = {} - if newroi == "ICR": - roiDict[newroi]['type'] = "Default" - else: - roiDict[newroi]['type'] = plot.getXAxis().getLabel() - roiDict[newroi]['from'] = fromdata - roiDict[newroi]['to'] = todata - self.roiTable.fillFromROIDict(roilist=roiList, - roidict=roiDict, - currentroi=newroi) - self.currentROI = newroi - self.calculateRois() - elif ddict['event'] in ['DelROI', "ResetROI"]: - plot.remove('ROI min', kind='marker') - plot.remove('ROI max', kind='marker') - if self._middleROIMarkerFlag: - plot.remove('ROI middle', kind='marker') - roiList, roiDict = self.roiTable.getROIListAndDict() - roiDictKeys = list(roiDict.keys()) - if len(roiDictKeys): - currentroi = roiDictKeys[0] - else: - # create again the ICR - ddict = {"event": "AddROI"} - return self._roiSignal(ddict) - - self.roiTable.fillFromROIDict(roilist=roiList, - roidict=roiDict, - currentroi=currentroi) - self.currentROI = currentroi - - elif ddict['event'] == 'LoadROI': - self.calculateRois() + @deprecation.deprecated(replacement="calculateRois", + reason="CamelCase convention", + since_version="0.7") + def calculateROIs(self, *args, **kw): + self.calculateRois(*args, **kw) - elif ddict['event'] == 'selectionChanged': - _logger.debug("Selection changed") - self.roilist, self.roidict = self.roiTable.getROIListAndDict() - fromdata = ddict['roi']['from'] - todata = ddict['roi']['to'] - plot.remove('ROI min', kind='marker') - plot.remove('ROI max', kind='marker') - if ddict['key'] == 'ICR': - draggable = False - color = 'black' - else: - draggable = True - color = 'blue' - plot.addXMarker(fromdata, - legend='ROI min', - text='ROI min', - color=color, - draggable=draggable) - plot.addXMarker(todata, - legend='ROI max', - text='ROI max', - color=color, - draggable=draggable) - if draggable and self._middleROIMarkerFlag: - pos = 0.5 * (fromdata + todata) - plot.addXMarker(pos, - legend='ROI middle', - text="", - color='yellow', - draggable=True) - self.currentROI = ddict['key'] - if ddict['colheader'] in ['From', 'To']: - dict0 = {} - dict0['event'] = "SetActiveCurveEvent" - dict0['legend'] = plot.getActiveCurve(just_legend=1) - plot.setActiveCurve(dict0['legend']) - elif ddict['colheader'] == 'Raw Counts': - pass - elif ddict['colheader'] == 'Net Counts': - pass - else: - self._emitCurrentROISignal() + def calculateRois(self, roiList=None, roiDict=None): + """Compute ROI information""" + return self.roiTable.calculateRois() - else: - _logger.debug("Unknown or ignored event %s", ddict['event']) + def showAllMarkers(self, _show=True): + self.roiTable.showAllMarkers(_show) def _getAllLimits(self): """Retrieve the limits based on the curves.""" @@ -565,429 +386,1121 @@ class CurvesROIWidget(qt.QWidget): return xmin, ymin, xmax, ymax - @deprecation.deprecated(replacement="calculateRois", - reason="CamelCase convention") - def calculateROIs(self, *args, **kw): - self.calculateRois(*args, **kw) + def showEvent(self, event): + self._visibilityChangedHandler(visible=True) + qt.QWidget.showEvent(self, event) - def calculateRois(self, roiList=None, roiDict=None): - """Compute ROI information""" - if roiList is None or roiDict is None: - roiList, roiDict = self.roiTable.getROIListAndDict() + def hideEvent(self, event): + self._visibilityChangedHandler(visible=False) + qt.QWidget.hideEvent(self, event) - plot = self.getPlotWidget() - if plot is None: - activeCurve = None - else: - activeCurve = plot.getActiveCurve(just_legend=False) + def _visibilityChangedHandler(self, visible): + """Handle widget's visibility updates. - if activeCurve is None: - xproc = None - yproc = None - self.setHeader() - else: - x = activeCurve.getXData(copy=False) - y = activeCurve.getYData(copy=False) - legend = activeCurve.getLegend() - idx = numpy.argsort(x, kind='mergesort') - xproc = numpy.take(x, idx) - yproc = numpy.take(y, idx) - self.setHeader('ROIs of %s' % legend) - - for key in roiList: - if key == 'ICR': - if xproc is not None: - roiDict[key]['from'] = xproc.min() - roiDict[key]['to'] = xproc.max() - else: - roiDict[key]['from'] = 0 - roiDict[key]['to'] = -1 - fromData = roiDict[key]['from'] - toData = roiDict[key]['to'] - if xproc is not None: - idx = numpy.nonzero((fromData <= xproc) & - (xproc <= toData))[0] - if len(idx): - xw = xproc[idx] - yw = yproc[idx] - rawCounts = yw.sum(dtype=numpy.float) - deltaX = xw[-1] - xw[0] - deltaY = yw[-1] - yw[0] - if deltaX > 0.0: - slope = (deltaY / deltaX) - background = yw[0] + slope * (xw - xw[0]) - netCounts = (rawCounts - - background.sum(dtype=numpy.float)) - else: - netCounts = 0.0 - else: - rawCounts = 0.0 - netCounts = 0.0 - roiDict[key]['rawcounts'] = rawCounts - roiDict[key]['netcounts'] = netCounts - else: - roiDict[key].pop('rawcounts', None) - roiDict[key].pop('netcounts', None) + It is connected to plot signals only when visible. + """ + if visible: + # if no ROI existing yet, add the default one + if self.roiTable.rowCount() is 0: + self._add() + self.calculateRois() - self.roiTable.fillFromROIDict( - roilist=roiList, - roidict=roiDict, - currentroi=self.currentROI if self.currentROI in roiList else None) + def fillFromROIDict(self, *args, **kwargs): + self.roiTable.fillFromROIDict(*args, **kwargs) def _emitCurrentROISignal(self): ddict = {} ddict['event'] = "currentROISignal" - _roiList, roiDict = self.roiTable.getROIListAndDict() - if self.currentROI in roiDict: - ddict['ROI'] = roiDict[self.currentROI] + if self.roiTable.activeRoi is not None: + ddict['ROI'] = self.roiTable.activeRoi.toDict() + ddict['current'] = self.roiTable.activeRoi.getName() else: - self.currentROI = None - ddict['current'] = self.currentROI + ddict['current'] = None self.sigROISignal.emit(ddict) - def _handleROIMarkerEvent(self, ddict): - """Handle plot signals related to marker events.""" - if ddict['event'] == 'markerMoved': + @property + def currentRoi(self): + return self.roiTable.activeRoi - label = ddict['label'] - if label not in ['ROI min', 'ROI max', 'ROI middle']: - return - roiList, roiDict = self.roiTable.getROIListAndDict() - if self.currentROI is None: - return - if self.currentROI not in roiDict: - return +class _FloatItem(qt.QTableWidgetItem): + """ + Simple QTableWidgetItem overloading the < operator to deal with ordering + """ + def __init__(self): + qt.QTableWidgetItem.__init__(self, type=qt.QTableWidgetItem.Type) - plot = self.getPlotWidget() - if plot is None: - return + def __lt__(self, other): + if self.text() in ('', ROITable.INFO_NOT_FOUND): + return False + if other.text() in ('', ROITable.INFO_NOT_FOUND): + return True + return float(self.text()) < float(other.text()) + + +class ROITable(TableWidget): + """Table widget displaying ROI information. + + See :class:`QTableWidget` for constructor arguments. - x = ddict['x'] - - if label == 'ROI min': - roiDict[self.currentROI]['from'] = x - if self._middleROIMarkerFlag: - pos = 0.5 * (roiDict[self.currentROI]['to'] + - roiDict[self.currentROI]['from']) - plot.addXMarker(pos, - legend='ROI middle', - text='', - color='yellow', - draggable=True) - elif label == 'ROI max': - roiDict[self.currentROI]['to'] = x - if self._middleROIMarkerFlag: - pos = 0.5 * (roiDict[self.currentROI]['to'] + - roiDict[self.currentROI]['from']) - plot.addXMarker(pos, - legend='ROI middle', - text='', - color='yellow', - draggable=True) - elif label == 'ROI middle': - delta = x - 0.5 * (roiDict[self.currentROI]['from'] + - roiDict[self.currentROI]['to']) - roiDict[self.currentROI]['from'] += delta - roiDict[self.currentROI]['to'] += delta - plot.addXMarker(roiDict[self.currentROI]['from'], - legend='ROI min', - text='ROI min', - color='blue', - draggable=True) - plot.addXMarker(roiDict[self.currentROI]['to'], - legend='ROI max', - text='ROI max', - color='blue', - draggable=True) + Behavior: listen at the active curve changed only when the widget is + visible. Otherwise won't compute the row and net counts... + """ + + activeROIChanged = qt.Signal() + """Signal emitted when the active roi changed or when the value of the + active roi are changing""" + + COLUMNS_INDEX = OrderedDict([ + ('ID', 0), + ('ROI', 1), + ('Type', 2), + ('From', 3), + ('To', 4), + ('Raw Counts', 5), + ('Net Counts', 6), + ('Raw Area', 7), + ('Net Area', 8), + ]) + + COLUMNS = list(COLUMNS_INDEX.keys()) + + INFO_NOT_FOUND = '????????' + + def __init__(self, parent=None, plot=None, rois=None): + super(ROITable, self).__init__(parent) + self._showAllMarkers = False + self._userIsEditingRoi = False + """bool used to avoid conflict when editing the ROI object""" + self._isConnected = False + self._roiToItems = {} + self._roiDict = {} + """dict of ROI object. Key is ROi id, value is the ROI object""" + self._markersHandler = _RoiMarkerManager() + + """ + Associate for each marker legend used when the `_showAllMarkers` option + is active a roi. + """ + self.setColumnCount(len(self.COLUMNS)) + self.setPlot(plot) + self.__setTooltip() + self.setSortingEnabled(True) + self.itemChanged.connect(self._itemChanged) + + @property + def roidict(self): + return self._getRoiDict() + + @property + def activeRoi(self): + return self._markersHandler._activeRoi + + def _getRoiDict(self): + ddict = {} + for id in self._roiDict: + ddict[self._roiDict[id].getName()] = self._roiDict[id] + return ddict + + def clear(self): + """ + .. note:: clear the interface only. keep the roidict... + """ + self._markersHandler.clear() + self._roiToItems = {} + self._roiDict = {} + + qt.QTableWidget.clear(self) + self.setRowCount(0) + self.setHorizontalHeaderLabels(self.COLUMNS) + header = self.horizontalHeader() + if hasattr(header, 'setSectionResizeMode'): # Qt5 + header.setSectionResizeMode(qt.QHeaderView.ResizeToContents) + else: # Qt4 + header.setResizeMode(qt.QHeaderView.ResizeToContents) + self.sortByColumn(0, qt.Qt.AscendingOrder) + self.hideColumn(self.COLUMNS_INDEX['ID']) + + def setPlot(self, plot): + self.clear() + self.plot = plot + + def __setTooltip(self): + self.horizontalHeaderItem(self.COLUMNS_INDEX['ROI']).setToolTip( + 'Region of interest identifier') + self.horizontalHeaderItem(self.COLUMNS_INDEX['Type']).setToolTip( + 'Type of the ROI') + self.horizontalHeaderItem(self.COLUMNS_INDEX['From']).setToolTip( + 'X-value of the min point') + self.horizontalHeaderItem(self.COLUMNS_INDEX['To']).setToolTip( + 'X-value of the max point') + self.horizontalHeaderItem(self.COLUMNS_INDEX['Raw Counts']).setToolTip( + 'Estimation of the integral between y=0 and the selected curve') + self.horizontalHeaderItem(self.COLUMNS_INDEX['Net Counts']).setToolTip( + 'Estimation of the integral between the segment [maxPt, minPt] ' + 'and the selected curve') + + def setRois(self, rois, order=None): + """Set the ROIs by providing a dictionary of ROI information. + + The dictionary keys are the ROI names. + Each value is a sub-dictionary of ROI info with the following fields: + + - ``"from"``: x coordinate of the left limit, as a float + - ``"to"``: x coordinate of the right limit, as a float + - ``"type"``: type of ROI, as a string (e.g "channels", "energy") + + + :param roidict: Dictionary of ROIs + :param str order: Field used for ordering the ROIs. + One of "from", "to", "type". + None (default) for no ordering, or same order as specified + in parameter ``roidict`` if provided as an OrderedDict. + """ + assert order in [None, "from", "to", "type"] + self.clear() + + # backward compatibility since 0.10.0 + if isinstance(rois, dict): + for roiName, roi in rois.items(): + roi['name'] = roiName + _roi = ROI._fromDict(roi) + self.addRoi(_roi) + else: + for roi in rois: + assert isinstance(roi, ROI) + self.addRoi(roi) + self._updateMarkers() + + def addRoi(self, roi): + """ + + :param :class:`ROI` roi: roi to add to the table + """ + assert isinstance(roi, ROI) + self._getItem(name='ID', row=None, roi=roi) + self._roiDict[roi.getID()] = roi + self._markersHandler.add(roi, _RoiMarkerHandler(roi, self.plot)) + self._updateRoiInfo(roi.getID()) + callback = functools.partial(WeakMethodProxy(self._updateRoiInfo), + roi.getID()) + roi.sigChanged.connect(callback) + # set it as the active one + self.setActiveRoi(roi) + + def _getItem(self, name, row, roi): + if row: + item = self.item(row, self.COLUMNS_INDEX[name]) + else: + item = None + if item: + return item + else: + if name == 'ID': + assert roi + if roi.getID() in self._roiToItems: + return self._roiToItems[roi.getID()] + else: + # create a new row + row = self.rowCount() + self.setRowCount(self.rowCount() + 1) + item = qt.QTableWidgetItem(str(roi.getID()), + type=qt.QTableWidgetItem.Type) + self._roiToItems[roi.getID()] = item + elif name == 'ROI': + item = qt.QTableWidgetItem(roi.getName() if roi else '', + type=qt.QTableWidgetItem.Type) + if roi.getName().upper() in ('ICR', 'DEFAULT'): + item.setFlags(qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled) + else: + item.setFlags(qt.Qt.ItemIsSelectable | + qt.Qt.ItemIsEnabled | + qt.Qt.ItemIsEditable) + elif name == 'Type': + item = qt.QTableWidgetItem(type=qt.QTableWidgetItem.Type) + item.setFlags((qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled)) + elif name in ('To', 'From'): + item = _FloatItem() + if roi.getName().upper() in ('ICR', 'DEFAULT'): + item.setFlags(qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled) + else: + item.setFlags(qt.Qt.ItemIsSelectable | + qt.Qt.ItemIsEnabled | + qt.Qt.ItemIsEditable) + elif name in ('Raw Counts', 'Net Counts', 'Raw Area', 'Net Area'): + item = _FloatItem() + item.setFlags((qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled)) else: - return - self.calculateRois(roiList, roiDict) - self._emitCurrentROISignal() + raise ValueError('item type not recognized') + + self.setItem(row, self.COLUMNS_INDEX[name], item) + return item + + def _itemChanged(self, item): + def getRoi(): + IDItem = self.item(item.row(), self.COLUMNS_INDEX['ID']) + assert IDItem + id = int(IDItem.text()) + assert id in self._roiDict + roi = self._roiDict[id] + return roi + + def signalChanged(roi): + if self.activeRoi and roi.getID() == self.activeRoi.getID(): + self.activeROIChanged.emit() + + self._userIsEditingRoi = True + if item.column() in (self.COLUMNS_INDEX['To'], self.COLUMNS_INDEX['From']): + roi = getRoi() + + if item.text() not in ('', self.INFO_NOT_FOUND): + try: + value = float(item.text()) + except ValueError: + value = 0 + changed = False + if item.column() == self.COLUMNS_INDEX['To']: + if value != roi.getTo(): + roi.setTo(value) + changed = True + else: + assert(item.column() == self.COLUMNS_INDEX['From']) + if value != roi.getFrom(): + roi.setFrom(value) + changed = True + if changed: + self._updateMarker(roi.getName()) + signalChanged(roi) + + if item.column() is self.COLUMNS_INDEX['ROI']: + roi = getRoi() + if roi.getName() != item.text(): + roi.setName(item.text()) + self._markersHandler.getMarkerHandler(roi.getID()).updateTexts() + signalChanged(roi) + + self._userIsEditingRoi = False + + def deleteActiveRoi(self): + """ + remove the current active roi + """ + activeItems = self.selectedItems() + if len(activeItems) is 0: + return + roiToRm = set() + for item in activeItems: + row = item.row() + itemID = self.item(row, self.COLUMNS_INDEX['ID']) + roiToRm.add(self._roiDict[int(itemID.text())]) + [self.removeROI(roi) for roi in roiToRm] + self.setActiveRoi(None) + + def removeROI(self, roi): + """ + remove the requested roi - def _visibilityChangedHandler(self, visible): - """Handle widget's visibility updates. + :param str name: the name of the roi to remove from the table + """ + if roi and roi.getID() in self._roiToItems: + item = self._roiToItems[roi.getID()] + self.removeRow(item.row()) + del self._roiToItems[roi.getID()] - It is connected to plot signals only when visible. + assert roi.getID() in self._roiDict + del self._roiDict[roi.getID()] + self._markersHandler.remove(roi) + + callback = functools.partial(WeakMethodProxy(self._updateRoiInfo), + roi.getID()) + roi.sigChanged.connect(callback) + + def setActiveRoi(self, roi): """ - plot = self.getPlotWidget() + Define the given roi as the active one. - if visible: - if not self._isInit: - # Deferred ROI widget init finalization - self._finalizeInit() - - if not self._isConnected and plot is not None: - plot.sigPlotSignal.connect(self._handleROIMarkerEvent) - plot.sigActiveCurveChanged.connect( - self._activeCurveChanged) - self._isConnected = True + .. warning:: this roi should already be registred / added to the table - self.calculateRois() + :param :class:`ROI` roi: the roi to defined as active + """ + if roi is None: + self.clearSelection() + self._markersHandler.setActiveRoi(None) + self.activeROIChanged.emit() else: - if self._isConnected: - if plot is not None: - plot.sigPlotSignal.disconnect(self._handleROIMarkerEvent) - plot.sigActiveCurveChanged.disconnect( - self._activeCurveChanged) - self._isConnected = False + assert isinstance(roi, ROI) + if roi and roi.getID() in self._roiToItems.keys(): + self.selectRow(self._roiToItems[roi.getID()].row()) + self._markersHandler.setActiveRoi(roi) + self.activeROIChanged.emit() + + def _updateRoiInfo(self, roiID): + if self._userIsEditingRoi is True: + return + if roiID not in self._roiDict: + return + roi = self._roiDict[roiID] + if roi.isICR(): + activeCurve = self.plot.getActiveCurve() + if activeCurve: + xData = activeCurve.getXData() + if len(xData) > 0: + min, max = min_max(xData) + roi.blockSignals(True) + roi.setFrom(min) + roi.setTo(max) + roi.blockSignals(False) + + itemID = self._getItem(name='ID', roi=roi, row=None) + itemName = self._getItem(name='ROI', row=itemID.row(), roi=roi) + itemName.setText(roi.getName()) + + itemType = self._getItem(name='Type', row=itemID.row(), roi=roi) + itemType.setText(roi.getType() or self.INFO_NOT_FOUND) + + itemFrom = self._getItem(name='From', row=itemID.row(), roi=roi) + fromdata = str(roi.getFrom()) if roi.getFrom() is not None else self.INFO_NOT_FOUND + itemFrom.setText(fromdata) + + itemTo = self._getItem(name='To', row=itemID.row(), roi=roi) + todata = str(roi.getTo()) if roi.getTo() is not None else self.INFO_NOT_FOUND + itemTo.setText(todata) + + rawCounts, netCounts = roi.computeRawAndNetCounts( + curve=self.plot.getActiveCurve(just_legend=False)) + itemRawCounts = self._getItem(name='Raw Counts', row=itemID.row(), + roi=roi) + rawCounts = str(rawCounts) if rawCounts is not None else self.INFO_NOT_FOUND + itemRawCounts.setText(rawCounts) + + itemNetCounts = self._getItem(name='Net Counts', row=itemID.row(), + roi=roi) + netCounts = str(netCounts) if netCounts is not None else self.INFO_NOT_FOUND + itemNetCounts.setText(netCounts) + + rawArea, netArea = roi.computeRawAndNetArea( + curve=self.plot.getActiveCurve(just_legend=False)) + itemRawArea = self._getItem(name='Raw Area', row=itemID.row(), + roi=roi) + rawArea = str(rawArea) if rawArea is not None else self.INFO_NOT_FOUND + itemRawArea.setText(rawArea) + + itemNetArea = self._getItem(name='Net Area', row=itemID.row(), + roi=roi) + netArea = str(netArea) if netArea is not None else self.INFO_NOT_FOUND + itemNetArea.setText(netArea) + + if self.activeRoi and roi.getID() == self.activeRoi.getID(): + self.activeROIChanged.emit() + + def currentChanged(self, current, previous): + if previous and current.row() != previous.row() and current.row() >= 0: + roiItem = self.item(current.row(), + self.COLUMNS_INDEX['ID']) + + assert roiItem + self.setActiveRoi(self._roiDict[int(roiItem.text())]) + self._markersHandler.updateAllMarkers() + qt.QTableWidget.currentChanged(self, current, previous) + + @deprecation.deprecated(reason="Removed", + replacement="roidict and roidict.values()", + since_version="0.10.0") + def getROIListAndDict(self): + """ - def _activeCurveChanged(self, *args): - """Recompute ROIs when active curve changed.""" - self.calculateRois() + :return: the list of roi objects and the dictionary of roi name to roi + object. + """ + roidict = self._roiDict + return list(roidict.values()), roidict - def _finalizeInit(self): - self._isInit = True - self.sigROIWidgetSignal.connect(self._roiSignal) - # initialize with the ICR if no ROi existing yet - if len(self.getRois()) is 0: - self._roiSignal({'event': "AddROI"}) + def calculateRois(self, roiList=None, roiDict=None): + """ + Update values of all registred rois (raw and net counts in particular) + :param roiList: deprecated parameter + :param roiDict: deprecated parameter + """ + if roiDict: + deprecation.deprecated_warning(name='roiDict', type_='Parameter', + reason='Unused parameter', + since_version="0.10.0") + if roiList: + deprecation.deprecated_warning(name='roiList', type_='Parameter', + reason='Unused parameter', + since_version="0.10.0") + + for roiID in self._roiDict: + self._updateRoiInfo(roiID) + + def _updateMarker(self, roiID): + """Make sure the marker of the given roi name is updated""" + if self._showAllMarkers or (self.activeRoi + and self.activeRoi.getName() == roiID): + self._updateMarkers() + + def _updateMarkers(self): + if self._showAllMarkers is True: + self._markersHandler.updateMarkers() + else: + if not self.activeRoi or not self.plot: + return + assert isinstance(self.activeRoi, ROI) + markerHandler = self._markersHandler.getMarkerHandler(self.activeRoi.getID()) + if markerHandler is not None: + markerHandler.updateMarkers() -class ROITable(qt.QTableWidget): - """Table widget displaying ROI information. + def getRois(self, order): + """ + Return the currently defined ROIs, as an ordered dict. - See :class:`QTableWidget` for constructor arguments. - """ + The dictionary keys are the ROI names. + Each value is a :class:`ROI` object.. - sigROITableSignal = qt.Signal(object) - """Signal of ROI table modifications. - """ + :param order: Field used for ordering the ROIs. + One of "from", "to", "type", "netcounts", "rawcounts". + None (default) to get the same order as displayed in the widget. + :return: Ordered dictionary of ROI information + """ - def __init__(self, *args, **kwargs): - super(ROITable, self).__init__(*args, **kwargs) - self.setRowCount(1) - self.labels = 'ROI', 'Type', 'From', 'To', 'Raw Counts', 'Net Counts' - self.setColumnCount(len(self.labels)) - self.setSortingEnabled(False) + if order is None or order.lower() == "none": + ordered_roilist = list(self._roiDict.values()) + res = OrderedDict([(roi.getName(), self._roiDict[roi.getID()]) for roi in ordered_roilist]) + else: + assert order in ["from", "to", "type", "netcounts", "rawcounts"] + ordered_roilist = sorted(self._roiDict.keys(), + key=lambda roi_id: self._roiDict[roi_id].get(order)) + res = OrderedDict([(roi.getName(), self._roiDict[id]) for id in ordered_roilist]) + + return res + + def save(self, filename): + """ + Save current ROIs of the widget as a dict of ROI to a file. + + :param str filename: The file to which to save the ROIs + """ + roilist = [] + roidict = {} + for roiID, roi in self._roiDict.items(): + roilist.append(roi.toDict()) + roidict[roi.getName()] = roi.toDict() + datadict = {'ROI': {'roilist': roilist, 'roidict': roidict}} + dictdump.dump(datadict, filename) - for index, label in enumerate(self.labels): - item = self.horizontalHeaderItem(index) - if item is None: - item = qt.QTableWidgetItem(label, - qt.QTableWidgetItem.Type) - item.setText(label) - self.setHorizontalHeaderItem(index, item) + def load(self, filename): + """ + Load ROI widget information from a file storing a dict of ROI. - self.roidict = {} - self.roilist = [] + :param str filename: The file from which to load ROI + """ + roisDict = dictdump.load(filename) + rois = [] - self.building = False - self.fillFromROIDict(roilist=self.roilist, roidict=self.roidict) + # Remove rawcounts and netcounts from ROIs + for roiDict in roisDict['ROI']['roidict'].values(): + roiDict.pop('rawcounts', None) + roiDict.pop('netcounts', None) + rois.append(ROI._fromDict(roiDict)) - self.cellClicked[(int, int)].connect(self._cellClickedSlot) - self.cellChanged[(int, int)].connect(self._cellChangedSlot) - verticalHeader = self.verticalHeader() - verticalHeader.sectionClicked[int].connect(self._rowChangedSlot) + self.setRois(rois) - self.__setTooltip() + def showAllMarkers(self, _show=True): + """ - def __setTooltip(self): - assert(self.labels[0] == 'ROI') - self.horizontalHeaderItem(0).setToolTip('Region of interest identifier') - assert(self.labels[1] == 'Type') - self.horizontalHeaderItem(1).setToolTip('Type of the ROI') - assert(self.labels[2] == 'From') - self.horizontalHeaderItem(2).setToolTip('X-value of the min point') - assert(self.labels[3] == 'To') - self.horizontalHeaderItem(3).setToolTip('X-value of the max point') - assert(self.labels[4] == 'Raw Counts') - self.horizontalHeaderItem(4).setToolTip('Estimation of the integral \ - between y=0 and the selected curve') - assert(self.labels[5] == 'Net Counts') - self.horizontalHeaderItem(5).setToolTip('Estimation of the integral \ - between the segment [maxPt, minPt] and the selected curve') + :param bool _show: if true show all the markers of all the ROIs + boundaries otherwise will only show the one of + the active ROI. + """ + self._markersHandler.setShowAllMarkers(_show) + + def setMiddleROIMarkerFlag(self, flag=True): + """ + Activate or deactivate middle marker. + + This allows shifting both min and max limits at once, by dragging + a marker located in the middle. + + :param bool flag: True to activate middle ROI marker + """ + self._markersHandler._middleROIMarkerFlag = flag + + def _handleROIMarkerEvent(self, ddict): + """Handle plot signals related to marker events.""" + if ddict['event'] == 'markerMoved': + label = ddict['label'] + roiID = self._markersHandler.getRoiID(markerID=label) + if roiID: + self._markersHandler.changePosition(markerID=label, + x=ddict['x']) + self._updateRoiInfo(roiID) + + def showEvent(self, event): + self._visibilityChangedHandler(visible=True) + qt.QWidget.showEvent(self, event) + + def hideEvent(self, event): + self._visibilityChangedHandler(visible=False) + qt.QWidget.hideEvent(self, event) + + def _visibilityChangedHandler(self, visible): + """Handle widget's visibility updates. + + It is connected to plot signals only when visible. + """ + if visible: + assert self.plot + if self._isConnected is False: + self.plot.sigPlotSignal.connect(self._handleROIMarkerEvent) + self.plot.sigActiveCurveChanged.connect(self._activeCurveChanged) + self._isConnected = True + self.calculateRois() + else: + if self._isConnected: + self.plot.sigPlotSignal.disconnect(self._handleROIMarkerEvent) + self.plot.sigActiveCurveChanged.disconnect(self._activeCurveChanged) + self._isConnected = False + + def _activeCurveChanged(self, curve): + self.calculateRois() + + def setCountsVisible(self, visible): + """ + Display the columns relative to areas or not + + :param bool visible: True if the columns 'Raw Area' and 'Net Area' + should be visible. + """ + if visible is True: + self.showColumn(self.COLUMNS_INDEX['Raw Counts']) + self.showColumn(self.COLUMNS_INDEX['Net Counts']) + else: + self.hideColumn(self.COLUMNS_INDEX['Raw Counts']) + self.hideColumn(self.COLUMNS_INDEX['Net Counts']) + + def setAreaVisible(self, visible): + """ + Display the columns relative to areas or not + + :param bool visible: True if the columns 'Raw Area' and 'Net Area' + should be visible. + """ + if visible is True: + self.showColumn(self.COLUMNS_INDEX['Raw Area']) + self.showColumn(self.COLUMNS_INDEX['Net Area']) + else: + self.hideColumn(self.COLUMNS_INDEX['Raw Area']) + self.hideColumn(self.COLUMNS_INDEX['Net Area']) def fillFromROIDict(self, roilist=(), roidict=None, currentroi=None): - """Set the ROIs by providing a list of ROIÂ names and a dictionary - of ROI information for each ROI. + """ + This function API is kept for compatibility. + But `setRois` should be preferred. + Set the ROIs by providing a list of ROI names and a dictionary + of ROI information for each ROI. The ROI names must match an existing dictionary key. The name list is used to provide an order for the ROIs. - The dictionary's values are sub-dictionaries containing 3 mandatory fields: - - ``"from"``: x coordinate of the left limit, as a float - - ``"to"``: x coordinate of the right limit, as a float - - ``"type"``: type of ROI, as a string (e.g "channels", "energy") + - ``"from"``: x coordinate of the left limit, as a float + - ``"to"``: x coordinate of the right limit, as a float + - ``"type"``: type of ROI, as a string (e.g "channels", "energy") :param roilist: List of ROI names (keys of roidict) :type roilist: List :param dict roidict: Dict of ROI information :param currentroi: Name of the selected ROI or None (no selection) """ - if roidict is None: - roidict = {} - - self.building = True - line0 = 0 - self.roilist = [] - self.roidict = {} - for key in roilist: - if key in roidict.keys(): - roi = roidict[key] - self.roilist.append(key) - self.roidict[key] = {} - self.roidict[key].update(roi) - line0 = line0 + 1 - nlines = self.rowCount() - if (line0 > nlines): - self.setRowCount(line0) - line = line0 - 1 - self.roidict[key]['line'] = line - ROI = key - roitype = "%s" % roi['type'] - fromdata = "%6g" % (roi['from']) - todata = "%6g" % (roi['to']) - if 'rawcounts' in roi: - rawcounts = "%6g" % (roi['rawcounts']) - else: - rawcounts = " ?????? " - if 'netcounts' in roi: - netcounts = "%6g" % (roi['netcounts']) - else: - netcounts = " ?????? " - fields = [ROI, roitype, fromdata, todata, rawcounts, netcounts] - col = 0 - for field in fields: - key2 = self.item(line, col) - if key2 is None: - key2 = qt.QTableWidgetItem(field, - qt.QTableWidgetItem.Type) - self.setItem(line, col, key2) - else: - key2.setText(field) - if (ROI.upper() == 'ICR') or (ROI.upper() == 'DEFAULT'): - key2.setFlags(qt.Qt.ItemIsSelectable | - qt.Qt.ItemIsEnabled) - else: - if col in [0, 2, 3]: - key2.setFlags(qt.Qt.ItemIsSelectable | - qt.Qt.ItemIsEnabled | - qt.Qt.ItemIsEditable) - else: - key2.setFlags(qt.Qt.ItemIsSelectable | - qt.Qt.ItemIsEnabled) - col = col + 1 - self.setRowCount(line0) - i = 0 - for _label in self.labels: - self.resizeColumnToContents(i) - i = i + 1 - self.sortByColumn(2, qt.Qt.AscendingOrder) - for i in range(len(self.roilist)): - key = str(self.item(i, 0).text()) - self.roilist[i] = key - self.roidict[key]['line'] = i - if len(self.roilist) == 1: - self.selectRow(0) + if roidict is not None: + self.setRois(roidict) else: - if currentroi in self.roidict.keys(): - self.selectRow(self.roidict[currentroi]['line']) - _logger.debug("Qt4 ensureCellVisible to be implemented") - self.building = False + self.setRois(roilist) + if currentroi: + self.setActiveRoi(currentroi) - def getROIListAndDict(self): - """Return the currently defined ROIs, as a 2-tuple - ``(roiList, roiDict)`` - ``roiList`` is a list of ROI names. - ``roiDict`` is a dictionary of ROI info. +_indexNextROI = 0 - The ROI names must match an existing dictionary key. - The name list is used to provide an order for the ROIs. - The dictionary's values are sub-dictionaries containing 3 - fields: +class ROI(qt.QObject): + """The Region Of Interest is defined by: - - ``"from"``: x coordinate of the left limit, as a float - - ``"to"``: x coordinate of the right limit, as a float - - ``"type"``: type of ROI, as a string (e.g "channels", "energy") + - A name + - A type. The type is the label of the x axis. This can be used to apply or + not some ROI to a curve and do some post processing. + - The x coordinate of the left limit (fromdata) + - The x coordinate of the right limit (todata) + :param str: name of the ROI + :param fromdata: left limit of the roi + :param todata: right limit of the roi + :param type: type of the ROI + """ + + sigChanged = qt.Signal() + """Signal emitted when the ROI is edited""" + + def __init__(self, name, fromdata=None, todata=None, type_=None): + qt.QObject.__init__(self) + assert type(name) is str + global _indexNextROI + self._id = _indexNextROI + _indexNextROI += 1 + + self._name = name + self._fromdata = fromdata + self._todata = todata + self._type = type_ or 'Default' - :return: ordered dict as a tuple of (list of ROI names, dict of info) + def getID(self): """ - return self.roilist, self.roidict - def _cellClickedSlot(self, *var, **kw): - # selection changed event, get the current selection - row = self.currentRow() - col = self.currentColumn() - if row >= 0 and row < len(self.roilist): - item = self.item(row, 0) - text = '' if item is None else str(item.text()) - self.roilist[row] = text - self._emitSelectionChangedSignal(row, col) + :return int: the unique ID of the ROI + """ + return self._id - def _rowChangedSlot(self, row): - self._emitSelectionChangedSignal(row, 0) + def setType(self, type_): + """ - def _cellChangedSlot(self, row, col): - _logger.debug("_cellChangedSlot(%d, %d)", row, col) - if self.building: - return - if col == 0: - self.nameSlot(row, col) + :param str type_: + """ + if self._type != type_: + self._type = type_ + self.sigChanged.emit() + + def getType(self): + """ + + :return str: the type of the ROI. + """ + return self._type + + def setName(self, name): + """ + Set the name of the :class:`ROI` + + :param str name: + """ + if self._name != name: + self._name = name + self.sigChanged.emit() + + def getName(self): + """ + + :return str: name of the :class:`ROI` + """ + return self._name + + def setFrom(self, frm): + """ + + :param frm: set x coordinate of the left limit + """ + if self._fromdata != frm: + self._fromdata = frm + self.sigChanged.emit() + + def getFrom(self): + """ + + :return: x coordinate of the left limit + """ + return self._fromdata + + def setTo(self, to): + """ + + :param to: x coordinate of the right limit + """ + if self._todata != to: + self._todata = to + self.sigChanged.emit() + + def getTo(self): + """ + + :return: x coordinate of the right limit + """ + return self._todata + + def getMiddle(self): + """ + + :return: middle position between 'from' and 'to' values + """ + return 0.5 * (self.getFrom() + self.getTo()) + + def toDict(self): + """ + + :return: dict containing the roi parameters + """ + ddict = { + 'type': self._type, + 'name': self._name, + 'from': self._fromdata, + 'to': self._todata, + } + if hasattr(self, '_extraInfo'): + ddict.update(self._extraInfo) + return ddict + + @staticmethod + def _fromDict(dic): + assert 'name' in dic + roi = ROI(name=dic['name']) + roi._extraInfo = {} + for key in dic: + if key == 'from': + roi.setFrom(dic['from']) + elif key == 'to': + roi.setTo(dic['to']) + elif key == 'type': + roi.setType(dic['type']) + else: + roi._extraInfo[key] = dic[key] + + return roi + + def isICR(self): + """ + + :return: True if the ROI is the `ICR` + """ + return self._name == 'ICR' + + def computeRawAndNetCounts(self, curve): + """Compute the Raw and net counts in the ROI for the given curve. + + - Raw count: Points values sum of the curve in the defined Region Of + Interest. + + .. image:: img/rawCounts.png + + - Net count: Raw counts minus background + + .. image:: img/netCounts.png + + :param CurveItem curve: + :return tuple: rawCount, netCount + """ + assert isinstance(curve, Curve) or curve is None + + if curve is None: + return None, None + + x = curve.getXData(copy=False) + y = curve.getYData(copy=False) + + idx = numpy.nonzero((self._fromdata <= x) & + (x <= self._todata))[0] + if len(idx): + xw = x[idx] + yw = y[idx] + rawCounts = yw.sum(dtype=numpy.float) + deltaX = xw[-1] - xw[0] + deltaY = yw[-1] - yw[0] + if deltaX > 0.0: + slope = (deltaY / deltaX) + background = yw[0] + slope * (xw - xw[0]) + netCounts = (rawCounts - + background.sum(dtype=numpy.float)) + else: + netCounts = 0.0 else: - self._valueChanged(row, col) + rawCounts = 0.0 + netCounts = 0.0 + return rawCounts, netCounts + + def computeRawAndNetArea(self, curve): + """Compute the Raw and net counts in the ROI for the given curve. + + - Raw area: integral of the curve between the min ROI point and the + max ROI point to the y = 0 line. - def _valueChanged(self, row, col): - if col not in [2, 3]: + .. image:: img/rawArea.png + + - Net area: Raw counts minus background + + .. image:: img/netArea.png + + :param CurveItem curve: + :return tuple: rawArea, netArea + """ + assert isinstance(curve, Curve) or curve is None + + if curve is None: + return None, None + + x = curve.getXData(copy=False) + y = curve.getYData(copy=False) + + y = y[(x >= self._fromdata) & (x <= self._todata)] + x = x[(x >= self._fromdata) & (x <= self._todata)] + + if x.size is 0: + return 0.0, 0.0 + + rawArea = numpy.trapz(y, x=x) + # to speed up and avoid an intersection calculation we are taking the + # closest index to the ROI + closestXLeftIndex = (numpy.abs(x - self.getFrom())).argmin() + closestXRightIndex = (numpy.abs(x - self.getTo())).argmin() + yBackground = y[closestXLeftIndex], y[closestXRightIndex] + background = numpy.trapz(yBackground, x=x) + netArea = rawArea - background + return rawArea, netArea + + +class _RoiMarkerManager(object): + """ + Deal with all the ROI markers + """ + def __init__(self): + self._roiMarkerHandlers = {} + self._middleROIMarkerFlag = False + self._showAllMarkers = False + self._activeRoi = None + + def setActiveRoi(self, roi): + self._activeRoi = roi + self.updateAllMarkers() + + def setShowAllMarkers(self, show): + if show != self._showAllMarkers: + self._showAllMarkers = show + self.updateAllMarkers() + + def add(self, roi, markersHandler): + assert isinstance(roi, ROI) + assert isinstance(markersHandler, _RoiMarkerHandler) + if roi.getID() in self._roiMarkerHandlers: + raise ValueError('roi with the same ID already existing') + else: + self._roiMarkerHandlers[roi.getID()] = markersHandler + + def getMarkerHandler(self, roiID): + if roiID in self._roiMarkerHandlers: + return self._roiMarkerHandlers[roiID] + else: + return None + + def clear(self): + roisHandler = list(self._roiMarkerHandlers.values()) + for roiHandler in roisHandler: + self.remove(roiHandler.roi) + + def remove(self, roi): + if roi is None: return - item = self.item(row, col) - if item is None: + assert isinstance(roi, ROI) + if roi.getID() in self._roiMarkerHandlers: + self._roiMarkerHandlers[roi.getID()].clear() + del self._roiMarkerHandlers[roi.getID()] + + def hasMarker(self, markerID): + assert type(markerID) is str + return self.getMarker(markerID) is not None + + def changePosition(self, markerID, x): + markerHandler = self.getMarker(markerID) + if markerHandler is None: + raise ValueError('Marker %s not register' % markerID) + markerHandler.changePosition(markerID=markerID, x=x) + + def updateMarker(self, markerID): + markerHandler = self.getMarker(markerID) + if markerHandler is None: + raise ValueError('Marker %s not register' % markerID) + roiID = self.getRoiID(markerID) + visible = (self._activeRoi and self._activeRoi.getID() == roiID) or self._showAllMarkers is True + markerHandler.setVisible(visible) + markerHandler.updateAllMarkers() + + def updateRoiMarkers(self, roiID): + if roiID in self._roiMarkerHandlers: + visible = ((self._activeRoi and self._activeRoi.getID() == roiID) + or self._showAllMarkers is True) + _roi = self._roiMarkerHandlers[roiID]._roi() + if _roi and not _roi.isICR(): + self._roiMarkerHandlers[roiID].showMiddleMarker(self._middleROIMarkerFlag) + self._roiMarkerHandlers[roiID].setVisible(visible) + self._roiMarkerHandlers[roiID].updateMarkers() + + def getMarker(self, markerID): + assert type(markerID) is str + for marker in list(self._roiMarkerHandlers.values()): + if marker.hasMarker(markerID): + return marker + + def updateMarkers(self): + for markerHandler in list(self._roiMarkerHandlers.values()): + markerHandler.updateMarkers() + + def getRoiID(self, markerID): + for roiID, markerHandler in self._roiMarkerHandlers.items(): + if markerHandler.hasMarker(markerID): + return roiID + return None + + def setShowMiddleMarkers(self, show): + self._middleROIMarkerFlag = show + self._roiMarkerHandlers.updateAllMarkers() + + def updateAllMarkers(self): + for roiID in self._roiMarkerHandlers: + self.updateRoiMarkers(roiID) + + def getVisibleRois(self): + res = {} + for roiID, roiHandler in self._roiMarkerHandlers.items(): + markers = (roiHandler.getMarker('min'), roiHandler.getMarker('max'), + roiHandler.getMarker('middle')) + for marker in markers: + if marker.isVisible(): + if roiID not in res: + res[roiID] = [] + res[roiID].append(marker) + return res + + +class _RoiMarkerHandler(object): + """Used to deal with ROI markers used in ROITable""" + def __init__(self, roi, plot): + assert roi and isinstance(roi, ROI) + assert plot + + self._roi = weakref.ref(roi) + self._plot = weakref.ref(plot) + self._draggable = False if roi.isICR() else True + self._color = 'black' if roi.isICR() else 'blue' + self._displayMidMarker = False + self._visible = True + + @property + def draggable(self): + return self._draggable + + @property + def plot(self): + return self._plot() + + def clear(self): + if self.plot and self.roi: + self.plot.removeMarker(self._markerID('min')) + self.plot.removeMarker(self._markerID('max')) + self.plot.removeMarker(self._markerID('middle')) + + @property + def roi(self): + return self._roi() + + def setVisible(self, visible): + if visible != self._visible: + self._visible = visible + self.updateMarkers() + + def showMiddleMarker(self, visible): + if self.draggable is False and visible is True: + _logger.warning("ROI is not draggable. Won't display middle marker") return - text = str(item.text()) - try: - value = float(text) - except: + self._displayMidMarker = visible + self.getMarker('middle').setVisible(self._displayMidMarker) + + def updateMarkers(self): + if self.roi is None: return - if row >= len(self.roilist): - _logger.debug("deleting???") + self._updateMinMarkerPos() + self._updateMaxMarkerPos() + self._updateMiddleMarkerPos() + + def _updateMinMarkerPos(self): + self.getMarker('min').setPosition(x=self.roi.getFrom(), y=None) + self.getMarker('min').setVisible(self._visible) + + def _updateMaxMarkerPos(self): + self.getMarker('max').setPosition(x=self.roi.getTo(), y=None) + self.getMarker('max').setVisible(self._visible) + + def _updateMiddleMarkerPos(self): + self.getMarker('middle').setPosition(x=self.roi.getMiddle(), y=None) + self.getMarker('middle').setVisible(self._displayMidMarker and self._visible) + + def getMarker(self, markerType): + if self.plot is None: + return None + assert markerType in ('min', 'max', 'middle') + if self.plot._getMarker(self._markerID(markerType)) is None: + assert self.roi + if markerType == 'min': + val = self.roi.getFrom() + elif markerType == 'max': + val = self.roi.getTo() + else: + val = self.roi.getMiddle() + + _color = self._color + if markerType == 'middle': + _color = 'yellow' + self.plot.addXMarker(val, + legend=self._markerID(markerType), + text=self.getMarkerName(markerType), + color=_color, + draggable=self.draggable) + return self.plot._getMarker(self._markerID(markerType)) + + def _markerID(self, markerType): + assert markerType in ('min', 'max', 'middle') + assert self.roi + return '_'.join((str(self.roi.getID()), markerType)) + + def getMarkerName(self, markerType): + assert markerType in ('min', 'max', 'middle') + assert self.roi + return ' '.join((self.roi.getName(), markerType)) + + def updateTexts(self): + self.getMarker('min').setText(self.getMarkerName('min')) + self.getMarker('max').setText(self.getMarkerName('max')) + self.getMarker('middle').setText(self.getMarkerName('middle')) + + def changePosition(self, markerID, x): + assert self.hasMarker(markerID) + markerType = self._getMarkerType(markerID) + assert markerType is not None + if self.roi is None: return - item = self.item(row, 0) - if item is None: - text = "" + if markerType == 'min': + self.roi.setFrom(x) + self._updateMiddleMarkerPos() + elif markerType == 'max': + self.roi.setTo(x) + self._updateMiddleMarkerPos() else: - text = str(item.text()) - if not len(text): - return - if col == 2: - self.roidict[text]['from'] = value - elif col == 3: - self.roidict[text]['to'] = value - self._emitSelectionChangedSignal(row, col) - - def nameSlot(self, row, col): - if col != 0: - return - if row >= len(self.roilist): - _logger.debug("deleting???") - return - item = self.item(row, col) - if item is None: - text = "" + delta = x - 0.5 * (self.roi.getFrom() + self.roi.getTo()) + self.roi.setFrom(self.roi.getFrom() + delta) + self.roi.setTo(self.roi.getTo() + delta) + self._updateMinMarkerPos() + self._updateMaxMarkerPos() + + def hasMarker(self, marker): + return marker in (self._markerID('min'), + self._markerID('max'), + self._markerID('middle')) + + def _getMarkerType(self, markerID): + if markerID.endswith('_min'): + return 'min' + elif markerID.endswith('_max'): + return 'max' + elif markerID.endswith('_middle'): + return 'middle' else: - text = str(item.text()) - if len(text) and (text not in self.roilist): - old = self.roilist[row] - self.roilist[row] = text - self.roidict[text] = {} - self.roidict[text].update(self.roidict[old]) - del self.roidict[old] - self._emitSelectionChangedSignal(row, col) - - def _emitSelectionChangedSignal(self, row, col): - ddict = {} - ddict['event'] = "selectionChanged" - ddict['row'] = row - ddict['col'] = col - ddict['roi'] = self.roidict[self.roilist[row]] - ddict['key'] = self.roilist[row] - ddict['colheader'] = self.labels[col] - ddict['rowheader'] = "%d" % row - self.sigROITableSignal.emit(ddict) + return None class CurvesROIDockWidget(qt.QDockWidget): @@ -1007,6 +1520,8 @@ class CurvesROIDockWidget(qt.QDockWidget): def __init__(self, parent=None, plot=None, name=None): super(CurvesROIDockWidget, self).__init__(name, parent) + assert plot is not None + self.plot = plot self.roiWidget = CurvesROIWidget(self, name, plot=plot) """Main widget of type :class:`CurvesROIWidget`""" @@ -1016,12 +1531,15 @@ class CurvesROIDockWidget(qt.QDockWidget): self.calculateROIs = self.calculateRois = self.roiWidget.calculateRois self.setRois = self.roiWidget.setRois self.getRois = self.roiWidget.getRois + self.roiWidget.sigROISignal.connect(self._forwardSigROISignal) - self.currentROI = self.roiWidget.currentROI self.layout().setContentsMargins(0, 0, 0, 0) self.setWidget(self.roiWidget) + self.setAreaVisible = self.roiWidget.roiTable.setAreaVisible + self.setCountsVisible = self.roiWidget.roiTable.setCountsVisible + def _forwardSigROISignal(self, ddict): # emit deprecated signal for backward compatibility (silx < 0.7) self.sigROISignal.emit(ddict) @@ -1042,3 +1560,7 @@ class CurvesROIDockWidget(qt.QDockWidget): """ self.raise_() qt.QDockWidget.showEvent(self, event) + + @property + def currentROI(self): + return self.roiWidget.currentRoi diff --git a/silx/gui/plot/MaskToolsWidget.py b/silx/gui/plot/MaskToolsWidget.py index 990e479..9d727e7 100644 --- a/silx/gui/plot/MaskToolsWidget.py +++ b/silx/gui/plot/MaskToolsWidget.py @@ -35,7 +35,7 @@ from __future__ import division __authors__ = ["T. Vincent", "P. Knobel"] __license__ = "MIT" -__date__ = "29/08/2018" +__date__ = "15/02/2019" import os @@ -57,10 +57,7 @@ from .. import qt from silx.third_party.EdfFile import EdfFile from silx.third_party.TiffIO import TiffIO -try: - import fabio -except ImportError: - fabio = None +import fabio _logger = logging.getLogger(__name__) @@ -135,8 +132,6 @@ class ImageMask(BaseMask): self._saveToHdf5(filename, self.getMask(copy=False)) elif kind == 'msk': - if fabio is None: - raise ImportError("Fit2d mask files can't be written: Fabio module is not available") try: data = self.getMask(copy=False) image = fabio.fabioimage.FabioImage(data=data) @@ -250,6 +245,19 @@ class ImageMask(BaseMask): rows, cols = shapes.circle_fill(crow, ccol, radius) self.updatePoints(level, rows, cols, mask) + def updateEllipse(self, level, crow, ccol, radius_r, radius_c, mask=True): + """Mask/Unmask an ellipse of the given mask level. + + :param int level: Mask level to update. + :param int crow: Row of the center of the ellipse + :param int ccol: Column of the center of the ellipse + :param float radius_r: Radius of the ellipse in the row + :param float radius_c: Radius of the ellipse in the column + :param bool mask: True to mask (default), False to unmask. + """ + rows, cols = shapes.ellipse_fill(crow, ccol, radius_r, radius_c) + self.updatePoints(level, rows, cols, mask) + def updateLine(self, level, row0, col0, row1, col1, width, mask=True): """Mask/Unmask a line of the given mask level. @@ -300,6 +308,10 @@ class MaskToolsWidget(BaseMaskToolsWidget): _logger.error('Not an image, shape: %d', len(mask.shape)) return None + # Handle mask with single level + if self.multipleMasks() == 'single': + mask = numpy.array(mask != 0, dtype=numpy.uint8) + # if mask has not changed, do nothing if numpy.array_equal(mask, self.getSelectionMask()): return mask.shape @@ -501,8 +513,6 @@ class MaskToolsWidget(BaseMaskToolsWidget): _logger.debug("Backtrace", exc_info=True) raise e elif extension == "msk": - if fabio is None: - raise ImportError("Fit2d mask files can't be read: Fabio module is not available") try: mask = fabio.open(filename).data except Exception as e: @@ -682,41 +692,51 @@ class MaskToolsWidget(BaseMaskToolsWidget): level = self.levelSpinBox.value() - if (self._drawingMode == 'rectangle' and - event['event'] == 'drawingFinished'): - # Convert from plot to array coords - doMask = self._isMasking() - ox, oy = self._origin - sx, sy = self._scale - - height = int(abs(event['height'] / sy)) - width = int(abs(event['width'] / sx)) - - row = int((event['y'] - oy) / sy) - if sy < 0: - row -= height - - col = int((event['x'] - ox) / sx) - if sx < 0: - col -= width - - self._mask.updateRectangle( - level, - row=row, - col=col, - height=height, - width=width, - mask=doMask) - self._mask.commit() + if self._drawingMode == 'rectangle': + if event['event'] == 'drawingFinished': + # Convert from plot to array coords + doMask = self._isMasking() + ox, oy = self._origin + sx, sy = self._scale + + height = int(abs(event['height'] / sy)) + width = int(abs(event['width'] / sx)) + + row = int((event['y'] - oy) / sy) + if sy < 0: + row -= height + + col = int((event['x'] - ox) / sx) + if sx < 0: + col -= width + + self._mask.updateRectangle( + level, + row=row, + col=col, + height=height, + width=width, + mask=doMask) + self._mask.commit() - elif (self._drawingMode == 'polygon' and - event['event'] == 'drawingFinished'): - doMask = self._isMasking() - # Convert from plot to array coords - vertices = (event['points'] - self._origin) / self._scale - vertices = vertices.astype(numpy.int)[:, (1, 0)] # (row, col) - self._mask.updatePolygon(level, vertices, doMask) - self._mask.commit() + elif self._drawingMode == 'ellipse': + if event['event'] == 'drawingFinished': + doMask = self._isMasking() + # Convert from plot to array coords + center = (event['points'][0] - self._origin) / self._scale + size = event['points'][1] / self._scale + center = center.astype(numpy.int) # (row, col) + self._mask.updateEllipse(level, center[1], center[0], size[1], size[0], doMask) + self._mask.commit() + + elif self._drawingMode == 'polygon': + if event['event'] == 'drawingFinished': + doMask = self._isMasking() + # Convert from plot to array coords + vertices = (event['points'] - self._origin) / self._scale + vertices = vertices.astype(numpy.int)[:, (1, 0)] # (row, col) + self._mask.updatePolygon(level, vertices, doMask) + self._mask.commit() elif self._drawingMode == 'pencil': doMask = self._isMasking() @@ -743,6 +763,8 @@ class MaskToolsWidget(BaseMaskToolsWidget): self._lastPencilPos = None else: self._lastPencilPos = row, col + else: + _logger.error("Drawing mode %s unsupported", self._drawingMode) def _loadRangeFromColormapTriggered(self): """Set range from active image colormap range""" diff --git a/silx/gui/plot/PlotInteraction.py b/silx/gui/plot/PlotInteraction.py index 356bda6..27abd10 100644 --- a/silx/gui/plot/PlotInteraction.py +++ b/silx/gui/plot/PlotInteraction.py @@ -26,7 +26,7 @@ __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "24/04/2018" +__date__ = "15/02/2019" import math @@ -96,10 +96,18 @@ class _PlotInteraction(object): fill = fill != 'none' # TODO not very nice either + greyed = colors.greyed(color)[0] + if greyed < 0.5: + color2 = "white" + else: + color2 = "black" + self.plot.addItem(points[:, 0], points[:, 1], legend=legend, replace=False, - shape=shape, color=color, fill=fill, + shape=shape, fill=fill, + color=color, linebgcolor=color2, linestyle="--", overlay=True) + self._selectionAreas.add(legend) def resetSelectionArea(self): @@ -274,6 +282,8 @@ class Zoom(_ZoomOnWheel): and zoom on mouse wheel. """ + SURFACE_THRESHOLD = 5 + def __init__(self, plot, color): self.color = color @@ -347,35 +357,44 @@ class Zoom(_ZoomOnWheel): self.setSelectionArea(corners, fill='none', color=self.color) - def endDrag(self, startPos, endPos): - x0, y0 = startPos - x1, y1 = endPos + def _zoom(self, x0, y0, x1, y1): + """Zoom to the rectangle view x0,y0 x1,y1. + """ + startPos = x0, y0 + endPos = x1, y1 + + # Store current zoom state in stack + self.plot.getLimitsHistory().push() - if x0 != x1 or y0 != y1: # Avoid empty zoom area - # Store current zoom state in stack - self.plot.getLimitsHistory().push() + if self.plot.isKeepDataAspectRatio(): + x0, y0, x1, y1 = self._areaWithAspectRatio(x0, y0, x1, y1) + + # Convert to data space and set limits + x0, y0 = self.plot.pixelToData(x0, y0, check=False) - if self.plot.isKeepDataAspectRatio(): - x0, y0, x1, y1 = self._areaWithAspectRatio(x0, y0, x1, y1) + dataPos = self.plot.pixelToData( + startPos[0], startPos[1], axis="right", check=False) + y2_0 = dataPos[1] - # Convert to data space and set limits - x0, y0 = self.plot.pixelToData(x0, y0, check=False) + x1, y1 = self.plot.pixelToData(x1, y1, check=False) - dataPos = self.plot.pixelToData( - startPos[0], startPos[1], axis="right", check=False) - y2_0 = dataPos[1] + dataPos = self.plot.pixelToData( + endPos[0], endPos[1], axis="right", check=False) + y2_1 = dataPos[1] - x1, y1 = self.plot.pixelToData(x1, y1, check=False) + xMin, xMax = min(x0, x1), max(x0, x1) + yMin, yMax = min(y0, y1), max(y0, y1) + y2Min, y2Max = min(y2_0, y2_1), max(y2_0, y2_1) - dataPos = self.plot.pixelToData( - endPos[0], endPos[1], axis="right", check=False) - y2_1 = dataPos[1] + self.plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max) - xMin, xMax = min(x0, x1), max(x0, x1) - yMin, yMax = min(y0, y1), max(y0, y1) - y2Min, y2Max = min(y2_0, y2_1), max(y2_0, y2_1) + def endDrag(self, startPos, endPos): + x0, y0 = startPos + x1, y1 = endPos - self.plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max) + if abs(x0 - x1) * abs(y0 - y1) >= self.SURFACE_THRESHOLD: + # Avoid empty zoom area + self._zoom(x0, y0, x1, y1) self.resetSelectionArea() @@ -544,7 +563,6 @@ class SelectPolygon(Select): return self.DRAG_THRESHOLD_DIST * ratio - class Select2Points(Select): """Base class for drawing selection based on 2 input points.""" class Idle(State): @@ -603,6 +621,87 @@ class Select2Points(Select): self.cancelSelect() +class SelectEllipse(Select2Points): + """Drawing ellipse selection area state machine.""" + def beginSelect(self, x, y): + self.center = self.plot.pixelToData(x, y) + assert self.center is not None + + def _getEllipseSize(self, pointInEllipse): + """ + Returns the size from the center to the bounding box of the ellipse. + + :param Tuple[float,float] pointInEllipse: A point of the ellipse + :rtype: Tuple[float,float] + """ + x = abs(self.center[0] - pointInEllipse[0]) + y = abs(self.center[1] - pointInEllipse[1]) + if x == 0 or y == 0: + return x, y + # Ellipse definitions + # e: eccentricity + # a: length fron center to bounding box width + # b: length fron center to bounding box height + # Equations + # (1) b < a + # (2) For x,y a point in the ellipse: x^2/a^2 + y^2/b^2 = 1 + # (3) b = a * sqrt(1-e^2) + # (4) e = sqrt(a^2 - b^2) / a + + # The eccentricity of the ellipse defined by a,b=x,y is the same + # as the one we are searching for. + swap = x < y + if swap: + x, y = y, x + e = math.sqrt(x**2 - y**2) / x + # From (2) using (3) to replace b + # a^2 = x^2 + y^2 / (1-e^2) + a = math.sqrt(x**2 + y**2 / (1.0 - e**2)) + b = a * math.sqrt(1 - e**2) + if swap: + a, b = b, a + return a, b + + def select(self, x, y): + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + width, height = self._getEllipseSize(dataPos) + + # Circle used for circle preview + nbpoints = 27. + angles = numpy.arange(nbpoints) * numpy.pi * 2.0 / nbpoints + circleShape = numpy.array((numpy.cos(angles) * width, + numpy.sin(angles) * height)).T + circleShape += numpy.array(self.center) + + self.setSelectionArea(circleShape, + shape="polygon", + fill='hatch', + color=self.color) + + eventDict = prepareDrawingSignal('drawingProgress', + 'ellipse', + (self.center, (width, height)), + self.parameters) + self.plot.notify(**eventDict) + + def endSelect(self, x, y): + self.resetSelectionArea() + + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + width, height = self._getEllipseSize(dataPos) + + eventDict = prepareDrawingSignal('drawingFinished', + 'ellipse', + (self.center, (width, height)), + self.parameters) + self.plot.notify(**eventDict) + + def cancelSelect(self): + self.resetSelectionArea() + + class SelectRectangle(Select2Points): """Drawing rectangle selection area state machine.""" def beginSelect(self, x, y): @@ -1488,6 +1587,7 @@ class PlotInteraction(object): _DRAW_MODES = { 'polygon': SelectPolygon, 'rectangle': SelectRectangle, + 'ellipse': SelectEllipse, 'line': SelectLine, 'vline': SelectVLine, 'hline': SelectHLine, diff --git a/silx/gui/plot/PlotToolButtons.py b/silx/gui/plot/PlotToolButtons.py index f6291b5..bf6b8ce 100644 --- a/silx/gui/plot/PlotToolButtons.py +++ b/silx/gui/plot/PlotToolButtons.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2018 European Synchrotron Radiation Facility +# Copyright (c) 2004-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -45,6 +45,7 @@ import weakref from .. import icons from .. import qt +from ... import config from .items import SymbolMixIn @@ -250,24 +251,24 @@ class ProfileOptionToolButton(PlotToolButton): self.STATE = {} # is down self.STATE['sum', "icon"] = icons.getQIcon('math-sigma') - self.STATE['sum', "state"] = "compute profile sum" - self.STATE['sum', "action"] = "compute profile sum" + self.STATE['sum', "state"] = "Compute profile sum" + self.STATE['sum', "action"] = "Compute profile sum" # keep ration self.STATE['mean', "icon"] = icons.getQIcon('math-mean') - self.STATE['mean', "state"] = "compute profile mean" - self.STATE['mean', "action"] = "compute profile mean" + self.STATE['mean', "state"] = "Compute profile mean" + self.STATE['mean', "action"] = "Compute profile mean" - sumAction = self._createAction('sum') - sumAction.triggered.connect(self.setSum) - sumAction.setIconVisibleInMenu(True) + self.sumAction = self._createAction('sum') + self.sumAction.triggered.connect(self.setSum) + self.sumAction.setIconVisibleInMenu(True) - meanAction = self._createAction('mean') - meanAction.triggered.connect(self.setMean) - meanAction.setIconVisibleInMenu(True) + self.meanAction = self._createAction('mean') + self.meanAction.triggered.connect(self.setMean) + self.meanAction.setIconVisibleInMenu(True) menu = qt.QMenu(self) - menu.addAction(sumAction) - menu.addAction(meanAction) + menu.addAction(self.sumAction) + menu.addAction(self.meanAction) self.setMenu(menu) self.setPopupMode(qt.QToolButton.InstantPopup) self.setMean() @@ -370,7 +371,7 @@ class SymbolToolButton(PlotToolButton): slider = qt.QSlider(qt.Qt.Horizontal) slider.setRange(1, 20) - slider.setValue(SymbolMixIn._DEFAULT_SYMBOL_SIZE) + slider.setValue(config.DEFAULT_PLOT_SYMBOL_SIZE) slider.setTracking(False) slider.valueChanged.connect(self._sizeChanged) widgetAction = qt.QWidgetAction(menu) diff --git a/silx/gui/plot/PlotWidget.py b/silx/gui/plot/PlotWidget.py index e023a21..cfe39fa 100644 --- a/silx/gui/plot/PlotWidget.py +++ b/silx/gui/plot/PlotWidget.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2018 European Synchrotron Radiation Facility +# Copyright (c) 2004-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -31,7 +31,7 @@ from __future__ import division __authors__ = ["V.A. Sole", "T. Vincent"] __license__ = "MIT" -__date__ = "12/10/2018" +__date__ = "21/12/2018" from collections import OrderedDict, namedtuple @@ -44,7 +44,6 @@ import numpy import silx from silx.utils.weakref import WeakMethodProxy -from silx.utils import deprecation from silx.utils.property import classproperty from silx.utils.deprecation import deprecated # Import matplotlib backend here to init matplotlib our way @@ -99,7 +98,7 @@ class PlotWidget(qt.QMainWindow): # TODO: Can be removed for silx 0.10 @classproperty - @deprecation.deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2) + @deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2) def DEFAULT_BACKEND(self): """Class attribute setting the default backend for all instances.""" return silx.config.DEFAULT_PLOT_BACKEND @@ -193,21 +192,12 @@ class PlotWidget(qt.QMainWindow): It provides the visible state. """ - def __init__(self, parent=None, backend=None, - legends=False, callback=None, **kw): + def __init__(self, parent=None, backend=None): self._autoreplot = False self._dirty = False self._cursorInPlot = False self.__muteActiveItemChanged = False - if kw: - _logger.warning( - 'deprecated: __init__ extra arguments: %s', str(kw)) - if legends: - _logger.warning('deprecated: __init__ legend argument') - if callback: - _logger.warning('deprecated: __init__ callback argument') - self._panWithArrowKeys = True self._viewConstrains = None @@ -218,27 +208,8 @@ class PlotWidget(qt.QMainWindow): else: self.setWindowTitle('PlotWidget') - if backend is None: - backend = silx.config.DEFAULT_PLOT_BACKEND - - if hasattr(backend, "__call__"): - self._backend = backend(self, parent) - - elif hasattr(backend, "lower"): - lowerCaseString = backend.lower() - if lowerCaseString in ("matplotlib", "mpl"): - backendClass = BackendMatplotlibQt - elif lowerCaseString in ('gl', 'opengl'): - from .backends.BackendOpenGL import BackendOpenGL - backendClass = BackendOpenGL - elif lowerCaseString == 'none': - from .backends.BackendBase import BackendBase as backendClass - else: - raise ValueError("Backend not supported %s" % backend) - self._backend = backendClass(self, parent) - - else: - raise ValueError("Backend not supported %s" % str(backend)) + self._backend = None + self._setBackend(backend) self.setCallback() # set _callback @@ -258,6 +229,12 @@ class PlotWidget(qt.QMainWindow): self._activeLegend = {'curve': None, 'image': None, 'scatter': None} + # plot colors (updated later to sync backend) + self._foregroundColor = 0., 0., 0., 1. + self._gridColor = .7, .7, .7, 1. + self._backgroundColor = 1., 1., 1., 1. + self._dataBackgroundColor = None + # default properties self._cursorConfiguration = None @@ -275,7 +252,7 @@ class PlotWidget(qt.QMainWindow): self.setDefaultColormap() # Init default colormap - self.setDefaultPlotPoints(False) + self.setDefaultPlotPoints(silx.config.DEFAULT_PLOT_CURVE_SYMBOL_MODE) self.setDefaultPlotLines(True) self._limitsHistory = LimitsHistory(self) @@ -306,9 +283,41 @@ class PlotWidget(qt.QMainWindow): self.setGraphYLimits(0., 100., axis='right') self.setGraphYLimits(0., 100., axis='left') + # Sync backend colors with default ones + self._foregroundColorsUpdated() + self._backgroundColorsUpdated() + + def _setBackend(self, backend): + """Setup a new backend""" + assert(self._backend is None) + + if backend is None: + backend = silx.config.DEFAULT_PLOT_BACKEND + + if hasattr(backend, "__call__"): + backend = backend(self, self) + + elif hasattr(backend, "lower"): + lowerCaseString = backend.lower() + if lowerCaseString in ("matplotlib", "mpl"): + backendClass = BackendMatplotlibQt + elif lowerCaseString in ('gl', 'opengl'): + from .backends.BackendOpenGL import BackendOpenGL + backendClass = BackendOpenGL + elif lowerCaseString == 'none': + from .backends.BackendBase import BackendBase as backendClass + else: + raise ValueError("Backend not supported %s" % backend) + backend = backendClass(self, self) + + else: + raise ValueError("Backend not supported %s" % str(backend)) + + self._backend = backend + # TODO: Can be removed for silx 0.10 @staticmethod - @deprecation.deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2) + @deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2) def setDefaultBackend(backend): """Set system wide default plot backend. @@ -349,6 +358,119 @@ class PlotWidget(qt.QMainWindow): if self._autoreplot and not wasDirty and self.isVisible(): self._backend.postRedisplay() + def _foregroundColorsUpdated(self): + """Handle change of foreground/grid color""" + if self._gridColor is None: + gridColor = self._foregroundColor + else: + gridColor = self._gridColor + self._backend.setForegroundColors( + self._foregroundColor, gridColor) + self._setDirtyPlot() + + def getForegroundColor(self): + """Returns the RGBA colors used to display the foreground of this widget + + :rtype: qt.QColor + """ + return qt.QColor.fromRgbF(*self._foregroundColor) + + def setForegroundColor(self, color): + """Set the foreground color of this widget. + + :param Union[List[int],List[float],QColor] color: + The new RGB(A) color. + """ + color = colors.rgba(color) + if self._foregroundColor != color: + self._foregroundColor = color + self._foregroundColorsUpdated() + + def getGridColor(self): + """Returns the RGBA colors used to display the grid lines + + An invalid QColor is returned if there is no grid color, + in which case the foreground color is used. + + :rtype: qt.QColor + """ + if self._gridColor is None: + return qt.QColor() # An invalid color + else: + return qt.QColor.fromRgbF(*self._gridColor) + + def setGridColor(self, color): + """Set the grid lines color + + :param Union[List[int],List[float],QColor,None] color: + The new RGB(A) color. + """ + if isinstance(color, qt.QColor) and not color.isValid(): + color = None + if color is not None: + color = colors.rgba(color) + if self._gridColor != color: + self._gridColor = color + self._foregroundColorsUpdated() + + def _backgroundColorsUpdated(self): + """Handle change of background/data background color""" + if self._dataBackgroundColor is None: + dataBGColor = self._backgroundColor + else: + dataBGColor = self._dataBackgroundColor + self._backend.setBackgroundColors( + self._backgroundColor, dataBGColor) + self._setDirtyPlot() + + def getBackgroundColor(self): + """Returns the RGBA colors used to display the background of this widget. + + :rtype: qt.QColor + """ + return qt.QColor.fromRgbF(*self._backgroundColor) + + def setBackgroundColor(self, color): + """Set the background color of this widget. + + :param Union[List[int],List[float],QColor] color: + The new RGB(A) color. + """ + color = colors.rgba(color) + if self._backgroundColor != color: + self._backgroundColor = color + self._backgroundColorsUpdated() + + def getDataBackgroundColor(self): + """Returns the RGBA colors used to display the background of the plot + view displaying the data. + + An invalid QColor is returned if there is no data background color. + + :rtype: qt.QColor + """ + if self._dataBackgroundColor is None: + # An invalid color + return qt.QColor() + else: + return qt.QColor.fromRgbF(*self._dataBackgroundColor) + + def setDataBackgroundColor(self, color): + """Set the background color of this widget. + + Set to None or an invalid QColor to use the background color. + + :param Union[List[int],List[float],QColor,None] color: + The new RGB(A) color. + """ + if isinstance(color, qt.QColor) and not color.isValid(): + color = None + if color is not None: + color = colors.rgba(color) + if self._dataBackgroundColor != color: + self._dataBackgroundColor = color + self._backgroundColorsUpdated() + def showEvent(self, event): if self._autoreplot and self._dirty: self._backend.postRedisplay() @@ -528,13 +650,13 @@ class PlotWidget(qt.QMainWindow): # This value is used when curve is updated either internally or by user. def addCurve(self, x, y, legend=None, info=None, - replace=False, replot=None, + replace=False, color=None, symbol=None, linewidth=None, linestyle=None, xlabel=None, ylabel=None, yaxis=None, xerror=None, yerror=None, z=None, selectable=None, fill=None, resetzoom=True, - histogram=None, copy=True, **kw): + histogram=None, copy=True): """Add a 1D curve given by x an y to the graph. Curves are uniquely identified by their legend. @@ -617,15 +739,6 @@ class PlotWidget(qt.QMainWindow): False to use provided arrays. :returns: The key string identify this curve """ - # Deprecation warnings - if replot is not None: - _logger.warning( - 'addCurve deprecated replot argument, use resetzoom instead') - resetzoom = replot and resetzoom - - if kw: - _logger.warning('addCurve: deprecated extra arguments') - # This is an histogram, use addHistogram if histogram is not None: histoLegend = self.addHistogram(histogram=y, @@ -825,13 +938,13 @@ class PlotWidget(qt.QMainWindow): return legend def addImage(self, data, legend=None, info=None, - replace=False, replot=None, - xScale=None, yScale=None, z=None, + replace=False, + z=None, selectable=None, draggable=None, colormap=None, pixmap=None, xlabel=None, ylabel=None, origin=None, scale=None, - resetzoom=True, copy=True, **kw): + resetzoom=True, copy=True): """Add a 2D dataset or an image to the plot. It displays either an array of data using a colormap or a RGB(A) image. @@ -883,28 +996,6 @@ class PlotWidget(qt.QMainWindow): False to use provided arrays. :returns: The key string identify this image """ - # Deprecation warnings - if xScale is not None or yScale is not None: - _logger.warning( - 'addImage deprecated xScale and yScale arguments,' - 'use origin, scale arguments instead.') - if origin is None and scale is None: - origin = xScale[0], yScale[0] - scale = xScale[1], yScale[1] - else: - _logger.warning( - 'addCurve: xScale, yScale and origin, scale arguments' - ' are conflicting. xScale and yScale are ignored.' - ' Use only origin, scale arguments.') - - if replot is not None: - _logger.warning( - 'addImage deprecated replot argument, use resetzoom instead') - resetzoom = replot and resetzoom - - if kw: - _logger.warning('addImage: deprecated extra arguments') - legend = "Unnamed Image 1.1" if legend is None else str(legend) # Check if image was previously active @@ -1090,7 +1181,8 @@ class PlotWidget(qt.QMainWindow): def addItem(self, xdata, ydata, legend=None, info=None, replace=False, shape="polygon", color='black', fill=True, - overlay=False, z=None, **kw): + overlay=False, z=None, linestyle="-", linewidth=1.0, + linebgcolor=None): """Add an item (i.e. a shape) to the plot. Items are uniquely identified by their legend. @@ -1114,13 +1206,23 @@ class PlotWidget(qt.QMainWindow): This allows for rendering optimization if this item is changed often. :param int z: Layer on which to draw the item (default: 2) + :param str linestyle: Style of the line. + Only relevant for line markers where X or Y is None. + Value in: + + - ' ' no line + - '-' solid line + - '--' dashed line + - '-.' dash-dot line + - ':' dotted line + :param float linewidth: Width of the line. + Only relevant for line markers where X or Y is None. + :param str linebgcolor: Background color of the line, e.g., 'blue', 'b', + '#FF0000'. It is used to draw dotted line using a second color. :returns: The key string identify this item """ # expected to receive the same parameters as the signal - if kw: - _logger.warning('addItem deprecated parameters: %s', str(kw)) - legend = "Unnamed Item 1.1" if legend is None else str(legend) z = int(z) if z is not None else 2 @@ -1138,6 +1240,9 @@ class PlotWidget(qt.QMainWindow): item.setOverlay(overlay) item.setZValue(z) item.setPoints(numpy.array((xdata, ydata)).T) + item.setLineStyle(linestyle) + item.setLineWidth(linewidth) + item.setLineBgColor(linebgcolor) self._add(item) @@ -1148,8 +1253,7 @@ class PlotWidget(qt.QMainWindow): color=None, selectable=False, draggable=False, - constraint=None, - **kw): + constraint=None): """Add a vertical line marker to the plot. Markers are uniquely identified by their legend. @@ -1177,10 +1281,6 @@ class PlotWidget(qt.QMainWindow): and that returns the filtered coordinates. :return: The key string identify this marker """ - if kw: - _logger.warning( - 'addXMarker deprecated extra parameters: %s', str(kw)) - return self._addMarker(x=x, y=None, legend=legend, text=text, color=color, selectable=selectable, draggable=draggable, @@ -1192,8 +1292,7 @@ class PlotWidget(qt.QMainWindow): color=None, selectable=False, draggable=False, - constraint=None, - **kw): + constraint=None): """Add a horizontal line marker to the plot. Markers are uniquely identified by their legend. @@ -1221,10 +1320,6 @@ class PlotWidget(qt.QMainWindow): and that returns the filtered coordinates. :return: The key string identify this marker """ - if kw: - _logger.warning( - 'addYMarker deprecated extra parameters: %s', str(kw)) - return self._addMarker(x=None, y=y, legend=legend, text=text, color=color, selectable=selectable, draggable=draggable, @@ -1236,8 +1331,7 @@ class PlotWidget(qt.QMainWindow): selectable=False, draggable=False, symbol='+', - constraint=None, - **kw): + constraint=None): """Add a point marker to the plot. Markers are uniquely identified by their legend. @@ -1277,10 +1371,6 @@ class PlotWidget(qt.QMainWindow): and that returns the filtered coordinates. :return: The key string identify this marker """ - if kw: - _logger.warning( - 'addMarker deprecated extra parameters: %s', str(kw)) - if x is None: xmin, xmax = self._xAxis.getLimits() x = 0.5 * (xmax + xmin) @@ -1368,7 +1458,7 @@ class PlotWidget(qt.QMainWindow): curve = self._getItem('curve', legend) return curve is not None and not curve.isVisible() - def hideCurve(self, legend, flag=True, replot=None): + def hideCurve(self, legend, flag=True): """Show/Hide the curve associated to legend. Even when hidden, the curve is kept in the list of curves. @@ -1376,9 +1466,6 @@ class PlotWidget(qt.QMainWindow): :param str legend: The legend associated to the curve to be hidden :param bool flag: True (default) to hide the curve, False to show it """ - if replot is not None: - _logger.warning('hideCurve deprecated replot parameter') - curve = self._getItem('curve', legend) if curve is None: _logger.warning('Curve not in plot: %s', legend) @@ -1660,16 +1747,13 @@ class PlotWidget(qt.QMainWindow): return self._getActiveItem(kind='curve', just_legend=just_legend) - def setActiveCurve(self, legend, replot=None): + def setActiveCurve(self, legend): """Make the curve associated to legend the active curve. :param legend: The legend associated to the curve or None to have no active curve. :type legend: str or None """ - if replot is not None: - _logger.warning('setActiveCurve deprecated replot parameter') - if not self.isActiveCurveHandling(): return if legend is None and self.getActiveCurveSelectionMode() == "legacy": @@ -1723,15 +1807,12 @@ class PlotWidget(qt.QMainWindow): """ return self._getActiveItem(kind='image', just_legend=just_legend) - def setActiveImage(self, legend, replot=None): + def setActiveImage(self, legend): """Make the image associated to legend the active image. :param str legend: The legend associated to the image or None to have no active image. """ - if replot is not None: - _logger.warning('setActiveImage deprecated replot parameter') - return self._setActiveItem(kind='image', legend=legend) def _getActiveItem(self, kind, just_legend=False): @@ -2028,14 +2109,12 @@ class PlotWidget(qt.QMainWindow): """ return self._backend.getGraphXLimits() - def setGraphXLimits(self, xmin, xmax, replot=None): + def setGraphXLimits(self, xmin, xmax): """Set the graph X (bottom) limits. :param float xmin: minimum bottom axis value :param float xmax: maximum bottom axis value """ - if replot is not None: - _logger.warning('setGraphXLimits deprecated replot parameter') self._xAxis.setLimits(xmin, xmax) def getGraphYLimits(self, axis='left'): @@ -2049,7 +2128,7 @@ class PlotWidget(qt.QMainWindow): yAxis = self._yAxis if axis == 'left' else self._yRightAxis return yAxis.getLimits() - def setGraphYLimits(self, ymin, ymax, axis='left', replot=None): + def setGraphYLimits(self, ymin, ymax, axis='left'): """Set the graph Y limits. :param float ymin: minimum bottom axis value @@ -2057,8 +2136,6 @@ class PlotWidget(qt.QMainWindow): :param str axis: The axis for which to get the limits: Either 'left' or 'right' """ - if replot is not None: - _logger.warning('setGraphYLimits deprecated replot parameter') assert axis in ('left', 'right') yAxis = self._yAxis if axis == 'left' else self._yRightAxis return yAxis.setLimits(ymin, ymax) @@ -2192,36 +2269,6 @@ class PlotWidget(qt.QMainWindow): def _isAxesDisplayed(self): return self._backend.isAxesDisplayed() - @property - @deprecated(since_version='0.6') - def sigSetYAxisInverted(self): - """Signal emitted when Y axis orientation has changed""" - return self._yAxis.sigInvertedChanged - - @property - @deprecated(since_version='0.6') - def sigSetXAxisLogarithmic(self): - """Signal emitted when X axis scale has changed""" - return self._xAxis._sigLogarithmicChanged - - @property - @deprecated(since_version='0.6') - def sigSetYAxisLogarithmic(self): - """Signal emitted when Y axis scale has changed""" - return self._yAxis._sigLogarithmicChanged - - @property - @deprecated(since_version='0.6') - def sigSetXAxisAutoScale(self): - """Signal emitted when X axis autoscale has changed""" - return self._xAxis.sigAutoScaleChanged - - @property - @deprecated(since_version='0.6') - def sigSetYAxisAutoScale(self): - """Signal emitted when Y axis autoscale has changed""" - return self._yAxis.sigAutoScaleChanged - def setYAxisInverted(self, flag=True): """Set the Y axis orientation. @@ -2290,6 +2337,8 @@ class PlotWidget(qt.QMainWindow): :param bool flag: True to respect data aspect ratio """ flag = bool(flag) + if flag == self.isKeepDataAspectRatio(): + return self._backend.setKeepDataAspectRatio(flag=flag) self._setDirtyPlot() self._forceResetZoom() @@ -2323,8 +2372,8 @@ class PlotWidget(qt.QMainWindow): # Defaults def isDefaultPlotPoints(self): - """Return True if default Curve symbol is 'o', False for no symbol.""" - return self._defaultPlotPoints == 'o' + """Return True if the default Curve symbol is set and False if not.""" + return self._defaultPlotPoints == silx.config.DEFAULT_PLOT_SYMBOL def setDefaultPlotPoints(self, flag): """Set the default symbol of all curves. @@ -2334,7 +2383,7 @@ class PlotWidget(qt.QMainWindow): :param bool flag: True to use 'o' as the default curve symbol, False to use no symbol. """ - self._defaultPlotPoints = 'o' if flag else '' + self._defaultPlotPoints = silx.config.DEFAULT_PLOT_SYMBOL if flag else '' # Reset symbol of all curves curves = self.getAllCurves(just_legend=False, withhidden=True) @@ -2510,7 +2559,7 @@ class PlotWidget(qt.QMainWindow): elif ddict['event'] == 'mouseClicked' and ddict['button'] == 'left': self.setActiveCurve(None) - def saveGraph(self, filename, fileFormat=None, dpi=None, **kw): + def saveGraph(self, filename, fileFormat=None, dpi=None): """Save a snapshot of the plot. Supported file formats depends on the backend in use. @@ -2523,9 +2572,6 @@ class PlotWidget(qt.QMainWindow): :param str fileFormat: String specifying the format :return: False if cannot save the plot, True otherwise """ - if kw: - _logger.warning('Extra parameters ignored: %s', str(kw)) - if fileFormat is None: if not hasattr(filename, 'lower'): _logger.warning( @@ -3080,149 +3126,3 @@ class PlotWidget(qt.QMainWindow): # Only call base class implementation when key is not handled. # See QWidget.keyPressEvent for details. super(PlotWidget, self).keyPressEvent(event) - - # Deprecated # - - def isDrawModeEnabled(self): - """Deprecated, use :meth:`getInteractiveMode` instead. - - Return True if the current interactive state is drawing.""" - _logger.warning( - 'isDrawModeEnabled deprecated, use getInteractiveMode instead') - return self.getInteractiveMode()['mode'] == 'draw' - - def setDrawModeEnabled(self, flag=True, shape='polygon', label=None, - color=None, **kwargs): - """Deprecated, use :meth:`setInteractiveMode` instead. - - Set the drawing mode if flag is True and its parameters. - - If flag is False, only item selection is enabled. - - Warning: Zoom and drawing are not compatible and cannot be enabled - simultaneously. - - :param bool flag: True to enable drawing and disable zoom and select. - :param str shape: Type of item to be drawn in: - hline, vline, rectangle, polygon (default) - :param str label: Associated text for identifying draw signals - :param color: The color to use to draw the selection area - :type color: string ("#RRGGBB") or 4 column unsigned byte array or - one of the predefined color names defined in colors.py - """ - _logger.warning( - 'setDrawModeEnabled deprecated, use setInteractiveMode instead') - - if kwargs: - _logger.warning('setDrawModeEnabled ignores additional parameters') - - if color is None: - color = 'black' - - if flag: - self.setInteractiveMode('draw', shape=shape, - label=label, color=color) - elif self.getInteractiveMode()['mode'] == 'draw': - self.setInteractiveMode('select') - - def getDrawMode(self): - """Deprecated, use :meth:`getInteractiveMode` instead. - - Return the draw mode parameters as a dict of None. - - It returns None if the interactive mode is not a drawing mode, - otherwise, it returns a dict containing the drawing mode parameters - as provided to :meth:`setDrawModeEnabled`. - """ - _logger.warning( - 'getDrawMode deprecated, use getInteractiveMode instead') - mode = self.getInteractiveMode() - return mode if mode['mode'] == 'draw' else None - - def isZoomModeEnabled(self): - """Deprecated, use :meth:`getInteractiveMode` instead. - - Return True if the current interactive state is zooming.""" - _logger.warning( - 'isZoomModeEnabled deprecated, use getInteractiveMode instead') - return self.getInteractiveMode()['mode'] == 'zoom' - - def setZoomModeEnabled(self, flag=True, color=None): - """Deprecated, use :meth:`setInteractiveMode` instead. - - Set the zoom mode if flag is True, else item selection is enabled. - - Warning: Zoom and drawing are not compatible and cannot be enabled - simultaneously - - :param bool flag: If True, enable zoom and select mode. - :param color: The color to use to draw the selection area. - (Default: 'black') - :param color: The color to use to draw the selection area - :type color: string ("#RRGGBB") or 4 column unsigned byte array or - one of the predefined color names defined in colors.py - """ - _logger.warning( - 'setZoomModeEnabled deprecated, use setInteractiveMode instead') - if color is None: - color = 'black' - - if flag: - self.setInteractiveMode('zoom', color=color) - elif self.getInteractiveMode()['mode'] == 'zoom': - self.setInteractiveMode('select') - - def insertMarker(self, *args, **kwargs): - """Deprecated, use :meth:`addMarker` instead.""" - _logger.warning( - 'insertMarker deprecated, use addMarker instead.') - return self.addMarker(*args, **kwargs) - - def insertXMarker(self, *args, **kwargs): - """Deprecated, use :meth:`addXMarker` instead.""" - _logger.warning( - 'insertXMarker deprecated, use addXMarker instead.') - return self.addXMarker(*args, **kwargs) - - def insertYMarker(self, *args, **kwargs): - """Deprecated, use :meth:`addYMarker` instead.""" - _logger.warning( - 'insertYMarker deprecated, use addYMarker instead.') - return self.addYMarker(*args, **kwargs) - - def isActiveCurveHandlingEnabled(self): - """Deprecated, use :meth:`isActiveCurveHandling` instead.""" - _logger.warning( - 'isActiveCurveHandlingEnabled deprecated, ' - 'use isActiveCurveHandling instead.') - return self.isActiveCurveHandling() - - def enableActiveCurveHandling(self, *args, **kwargs): - """Deprecated, use :meth:`setActiveCurveHandling` instead.""" - _logger.warning( - 'enableActiveCurveHandling deprecated, ' - 'use setActiveCurveHandling instead.') - return self.setActiveCurveHandling(*args, **kwargs) - - def invertYAxis(self, *args, **kwargs): - """Deprecated, use :meth:`Axis.setInverted` instead.""" - _logger.warning('invertYAxis deprecated, ' - 'use getYAxis().setInverted instead.') - return self.getYAxis().setInverted(*args, **kwargs) - - def showGrid(self, flag=True): - """Deprecated, use :meth:`setGraphGrid` instead.""" - _logger.warning("showGrid deprecated, use setGraphGrid instead") - if flag in (0, False): - flag = None - elif flag in (1, True): - flag = 'major' - else: - flag = 'both' - return self.setGraphGrid(flag) - - def keepDataAspectRatio(self, *args, **kwargs): - """Deprecated, use :meth:`setKeepDataAspectRatio`.""" - _logger.warning('keepDataAspectRatio deprecated,' - 'use setKeepDataAspectRatio instead') - return self.setKeepDataAspectRatio(*args, **kwargs) diff --git a/silx/gui/plot/PlotWindow.py b/silx/gui/plot/PlotWindow.py index 23ea399..b44a512 100644 --- a/silx/gui/plot/PlotWindow.py +++ b/silx/gui/plot/PlotWindow.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2018 European Synchrotron Radiation Facility +# Copyright (c) 2004-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -29,7 +29,7 @@ The :class:`PlotWindow` is a subclass of :class:`.PlotWidget`. __authors__ = ["V.A. Sole", "T. Vincent"] __license__ = "MIT" -__date__ = "12/10/2018" +__date__ = "21/12/2018" import collections import logging @@ -217,10 +217,8 @@ class PlotWindow(PlotWidget): # Make colorbar background white self._colorbar.setAutoFillBackground(True) - palette = self._colorbar.palette() - palette.setColor(qt.QPalette.Background, qt.Qt.white) - palette.setColor(qt.QPalette.Window, qt.Qt.white) - self._colorbar.setPalette(palette) + self._sigAxesVisibilityChanged.connect(self._updateColorBarBackground) + self._updateColorBarBackground() gridLayout = qt.QGridLayout() gridLayout.setSpacing(0) @@ -294,6 +292,43 @@ class PlotWindow(PlotWidget): for action in toolbar.actions(): self.addAction(action) + def setBackgroundColor(self, color): + super(PlotWindow, self).setBackgroundColor(color) + self._updateColorBarBackground() + + setBackgroundColor.__doc__ = PlotWidget.setBackgroundColor.__doc__ + + def setDataBackgroundColor(self, color): + super(PlotWindow, self).setDataBackgroundColor(color) + self._updateColorBarBackground() + + setDataBackgroundColor.__doc__ = PlotWidget.setDataBackgroundColor.__doc__ + + def setForegroundColor(self, color): + super(PlotWindow, self).setForegroundColor(color) + self._updateColorBarBackground() + + setForegroundColor.__doc__ = PlotWidget.setForegroundColor.__doc__ + + def _updateColorBarBackground(self): + """Update the colorbar background according to the state of the plot""" + if self._isAxesDisplayed(): + color = self.getBackgroundColor() + else: + color = self.getDataBackgroundColor() + if not color.isValid(): + # If no color defined, use the background one + color = self.getBackgroundColor() + + foreground = self.getForegroundColor() + + palette = self._colorbar.palette() + palette.setColor(qt.QPalette.Background, color) + palette.setColor(qt.QPalette.Window, color) + palette.setColor(qt.QPalette.WindowText, foreground) + palette.setColor(qt.QPalette.Text, foreground) + self._colorbar.setPalette(palette) + def getInteractiveModeToolBar(self): """Returns QToolBar controlling interactive mode. @@ -457,10 +492,6 @@ class PlotWindow(PlotWidget): return self._colorbar # getters for dock widgets - @property - @deprecated(replacement="getLegendsDockWidget()", since_version="0.4.0") - def legendsDockWidget(self): - return self.getLegendsDockWidget() def getLegendsDockWidget(self): """DockWidget with Legend panel""" @@ -470,11 +501,6 @@ class PlotWindow(PlotWidget): self.addTabbedDockWidget(self._legendsDockWidget) return self._legendsDockWidget - @property - @deprecated(replacement="getCurvesRoiWidget()", since_version="0.4.0") - def curvesROIDockWidget(self): - return self.getCurvesRoiDockWidget() - def getCurvesRoiDockWidget(self): # Undocumented for a "soft deprecation" in version 0.7.0 # (still used internally for lazy loading) @@ -496,11 +522,6 @@ class PlotWindow(PlotWidget): """ return self.getCurvesRoiDockWidget().roiWidget - @property - @deprecated(replacement="getMaskToolsDockWidget()", since_version="0.4.0") - def maskToolsDockWidget(self): - return self.getMaskToolsDockWidget() - def getMaskToolsDockWidget(self): """DockWidget with image mask panel (lazy-loaded).""" if self._maskToolsDockWidget is None: @@ -539,11 +560,6 @@ class PlotWindow(PlotWidget): def panModeAction(self): return self.getInteractiveModeToolBar().getPanModeAction() - @property - @deprecated(replacement="getConsoleAction()", since_version="0.4.0") - def consoleAction(self): - return self.getConsoleAction() - def getConsoleAction(self): """QAction handling the IPython console activation. @@ -563,11 +579,6 @@ class PlotWindow(PlotWidget): self._consoleAction.setEnabled(False) return self._consoleAction - @property - @deprecated(replacement="getCrosshairAction()", since_version="0.4.0") - def crosshairAction(self): - return self.getCrosshairAction() - def getCrosshairAction(self): """Action toggling crosshair cursor mode. @@ -577,11 +588,6 @@ class PlotWindow(PlotWidget): self._crosshairAction = actions.control.CrosshairAction(self, color='red') return self._crosshairAction - @property - @deprecated(replacement="getMaskAction()", since_version="0.4.0") - def maskAction(self): - return self.getMaskAction() - def getMaskAction(self): """QAction toggling image mask dock widget @@ -589,12 +595,6 @@ class PlotWindow(PlotWidget): """ return self.getMaskToolsDockWidget().toggleViewAction() - @property - @deprecated(replacement="getPanWithArrowKeysAction()", - since_version="0.4.0") - def panWithArrowKeysAction(self): - return self.getPanWithArrowKeysAction() - def getPanWithArrowKeysAction(self): """Action toggling pan with arrow keys. @@ -604,11 +604,6 @@ class PlotWindow(PlotWidget): self._panWithArrowKeysAction = actions.control.PanWithArrowKeysAction(self) return self._panWithArrowKeysAction - @property - @deprecated(replacement="getRoiAction()", since_version="0.4.0") - def roiAction(self): - return self.getRoiAction() - def getStatsAction(self): if self._statsAction is None: self._statsAction = qt.QAction('Curves stats', self) diff --git a/silx/gui/plot/PrintPreviewToolButton.py b/silx/gui/plot/PrintPreviewToolButton.py index b48505d..d857c18 100644 --- a/silx/gui/plot/PrintPreviewToolButton.py +++ b/silx/gui/plot/PrintPreviewToolButton.py @@ -111,10 +111,11 @@ from .. import icons from . import PlotWidget from ..widgets.PrintPreview import PrintPreviewDialog, SingletonPrintPreviewDialog from ..widgets.PrintGeometryDialog import PrintGeometryDialog +from silx.utils.deprecation import deprecated __authors__ = ["P. Knobel"] __license__ = "MIT" -__date__ = "18/07/2017" +__date__ = "20/12/2018" _logger = logging.getLogger(__name__) # _logger.setLevel(logging.DEBUG) @@ -132,19 +133,19 @@ class PrintPreviewToolButton(qt.QToolButton): if not isinstance(plot, PlotWidget): raise TypeError("plot parameter must be a PlotWidget") - self.plot = plot + self._plot = plot self.setIcon(icons.getQIcon('document-print')) printGeomAction = qt.QAction("Print geometry", self) printGeomAction.setToolTip("Define a print geometry prior to sending " "the plot to the print preview dialog") - printGeomAction.setIcon(icons.getQIcon('shape-rectangle')) # fixme: icon not displayed in menu + printGeomAction.setIcon(icons.getQIcon('shape-rectangle')) printGeomAction.triggered.connect(self._setPrintConfiguration) printPreviewAction = qt.QAction("Print preview", self) printPreviewAction.setToolTip("Send plot to the print preview dialog") - printPreviewAction.setIcon(icons.getQIcon('document-print')) # fixme: icon not displayed + printPreviewAction.setIcon(icons.getQIcon('document-print')) printPreviewAction.triggered.connect(self._plotToPrintPreview) menu = qt.QMenu(self) @@ -172,24 +173,64 @@ class PrintPreviewToolButton(qt.QToolButton): self._printPreviewDialog = PrintPreviewDialog(self.parent()) return self._printPreviewDialog + def getTitle(self): + """Implement this method to fetch the title in the plot. + + :return: Title to be printed above the plot, or None (no title added) + :rtype: str or None + """ + return None + + def getCommentAndPosition(self): + """Implement this method to fetch the legend to be printed below the + figure and its position. + + :return: Legend to be printed below the figure and its position: + "CENTER", "LEFT" or "RIGHT" + :rtype: (str, str) or (None, None) + """ + return None, None + + @property + @deprecated(since_version="0.10", + replacement="getPlot()") + def plot(self): + return self._plot + + def getPlot(self): + """Return the :class:`.PlotWidget` associated with this tool button. + + :rtype: :class:`.PlotWidget` + """ + return self._plot + def _plotToPrintPreview(self): """Grab the plot widget and send it to the print preview dialog. Make sure the print preview dialog is shown and raised.""" if not self.printPreviewDialog.ensurePrinterIsSet(): return + comment, commentPosition = self.getCommentAndPosition() + if qt.HAS_SVG: svgRenderer, viewBox = self._getSvgRendererAndViewbox() self.printPreviewDialog.addSvgItem(svgRenderer, - viewBox=viewBox) + title=self.getTitle(), + comment=comment, + commentPosition=commentPosition, + viewBox=viewBox, + keepRatio=self._printGeometry["keepAspectRatio"]) else: _logger.warning("Missing QtSvg library, using a raster image") if qt.BINDING in ["PyQt4", "PySide"]: - pixmap = qt.QPixmap.grabWidget(self.plot.centralWidget()) + pixmap = qt.QPixmap.grabWidget(self._plot.centralWidget()) else: # PyQt5 and hopefully PyQt6+ - pixmap = self.plot.centralWidget().grab() - self.printPreviewDialog.addPixmap(pixmap) + pixmap = self._plot.centralWidget().grab() + self.printPreviewDialog.addPixmap(pixmap, + title=self.getTitle(), + comment=comment, + commentPosition=commentPosition) self.printPreviewDialog.show() self.printPreviewDialog.raise_() @@ -201,7 +242,7 @@ class PrintPreviewToolButton(qt.QToolButton): and to the geometry configuration (width, height, ratio) specified by the user.""" imgData = StringIO() - assert self.plot.saveGraph(imgData, fileFormat="svg"), \ + assert self._plot.saveGraph(imgData, fileFormat="svg"), \ "Unable to save graph" imgData.flush() imgData.seek(0) @@ -310,7 +351,7 @@ class PrintPreviewToolButton(qt.QToolButton): self._printGeometry = self._printConfigurationDialog.getPrintGeometry() def _getPlotAspectRatio(self): - widget = self.plot.centralWidget() + widget = self._plot.centralWidget() graphWidth = float(widget.width()) graphHeight = float(widget.height()) return graphHeight / graphWidth diff --git a/silx/gui/plot/Profile.py b/silx/gui/plot/Profile.py index 182cf60..46e4523 100644 --- a/silx/gui/plot/Profile.py +++ b/silx/gui/plot/Profile.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2018 European Synchrotron Radiation Facility +# Copyright (c) 2004-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -180,7 +180,8 @@ def createProfile(roiInfo, currentData, origin, scale, lineWidth, method): :type scale: 2-tuple of float :param int lineWidth: width of the profile line :param str method: method to compute the profile. Can be 'mean' or 'sum' - :return: `profile, area, profileName, xLabel`, where: + :return: `coords, profile, area, profileName, xLabel`, where: + - coords is the X coordinate to use to display the profile - profile is a 2D array of the profiles of the stack of images. For a single image, the profile is a curve, so this parameter has a shape *(1, len(curve))* @@ -188,10 +189,9 @@ def createProfile(roiInfo, currentData, origin, scale, lineWidth, method): the effective ROI area corners in plot coords. - profileName is a string describing the ROI, meant to be used as title of the profile plot - - xLabel is a string describing the meaning of the X axis on the - profile plot ("rows", "columns", "distance") + - xLabel the label for X in the profile window - :rtype: tuple(ndarray, (ndarray, ndarray), str, str) + :rtype: tuple(ndarray,ndarray,(ndarray,ndarray),str) """ if currentData is None or roiInfo is None or lineWidth is None: raise ValueError("createProfile called with invalide arguments") @@ -212,12 +212,15 @@ def createProfile(roiInfo, currentData, origin, scale, lineWidth, method): axis=0, method=method) + coords = numpy.arange(len(profile[0]), dtype=numpy.float32) + coords = coords * scale[0] + origin[0] + yMin, yMax = min(area[1]), max(area[1]) - 1 if roiWidth <= 1: profileName = 'Y = %g' % yMin else: profileName = 'Y = [%g, %g]' % (yMin, yMax) - xLabel = 'Columns' + xLabel = 'X' elif lineProjectionMode == 'Y': # Vertical profile on the whole image profile, area = _alignedFullProfile(currentData3D, @@ -226,12 +229,15 @@ def createProfile(roiInfo, currentData, origin, scale, lineWidth, method): axis=1, method=method) + coords = numpy.arange(len(profile[0]), dtype=numpy.float32) + coords = coords * scale[1] + origin[1] + xMin, xMax = min(area[0]), max(area[0]) - 1 if roiWidth <= 1: profileName = 'X = %g' % xMin else: profileName = 'X = [%g, %g]' % (xMin, xMax) - xLabel = 'Rows' + xLabel = 'Y' else: # Free line profile @@ -306,35 +312,52 @@ def createProfile(roiInfo, currentData, origin, scale, lineWidth, method): dCol = (endPt[1] - startPt[1]) / length # Extend ROI with half a pixel on each end - startPt = startPt[0] - 0.5 * dRow, startPt[1] - 0.5 * dCol - endPt = endPt[0] + 0.5 * dRow, endPt[1] + 0.5 * dCol + roiStartPt = startPt[0] - 0.5 * dRow, startPt[1] - 0.5 * dCol + roiEndPt = endPt[0] + 0.5 * dRow, endPt[1] + 0.5 * dCol # Rotate deltas by 90 degrees to apply line width dRow, dCol = dCol, -dRow area = ( - numpy.array((startPt[1] - 0.5 * roiWidth * dCol, - startPt[1] + 0.5 * roiWidth * dCol, - endPt[1] + 0.5 * roiWidth * dCol, - endPt[1] - 0.5 * roiWidth * dCol), + numpy.array((roiStartPt[1] - 0.5 * roiWidth * dCol, + roiStartPt[1] + 0.5 * roiWidth * dCol, + roiEndPt[1] + 0.5 * roiWidth * dCol, + roiEndPt[1] - 0.5 * roiWidth * dCol), dtype=numpy.float32) * scale[0] + origin[0], - numpy.array((startPt[0] - 0.5 * roiWidth * dRow, - startPt[0] + 0.5 * roiWidth * dRow, - endPt[0] + 0.5 * roiWidth * dRow, - endPt[0] - 0.5 * roiWidth * dRow), + numpy.array((roiStartPt[0] - 0.5 * roiWidth * dRow, + roiStartPt[0] + 0.5 * roiWidth * dRow, + roiEndPt[0] + 0.5 * roiWidth * dRow, + roiEndPt[0] - 0.5 * roiWidth * dRow), dtype=numpy.float32) * scale[1] + origin[1]) - y0, x0 = startPt - y1, x1 = endPt - if x1 == x0 or y1 == y0: - profileName = 'From (%g, %g) to (%g, %g)' % (x0, y0, x1, y1) + # Convert start and end points back to plot coords + y0 = startPt[0] * scale[1] + origin[1] + x0 = startPt[1] * scale[0] + origin[0] + y1 = endPt[0] * scale[1] + origin[1] + x1 = endPt[1] * scale[0] + origin[0] + + if startPt[1] == endPt[1]: + profileName = 'X = %g; Y = [%g, %g]' % (x0, y0, y1) + coords = numpy.arange(len(profile[0]), dtype=numpy.float32) + coords = coords * scale[1] + y0 + xLabel = 'Y' + + elif startPt[0] == endPt[0]: + profileName = 'Y = %g; X = [%g, %g]' % (y0, x0, x1) + coords = numpy.arange(len(profile[0]), dtype=numpy.float32) + coords = coords * scale[0] + x0 + xLabel = 'X' + else: m = (y1 - y0) / (x1 - x0) b = y0 - m * x0 profileName = 'y = %g * x %+g ; width=%d' % (m, b, roiWidth) - xLabel = 'Distance' + coords = numpy.linspace(x0, x1, len(profile[0]), + endpoint=True, + dtype=numpy.float32) + xLabel = 'X' - return profile, area, profileName, xLabel + return coords, profile, area, profileName, xLabel # ProfileToolBar ############################################################## @@ -458,7 +481,7 @@ class ProfileToolBar(qt.QToolBar): self.addWidget(self.lineWidthSpinBox) self.methodsButton = ProfileOptionToolButton(parent=self, plot=self) - self.addWidget(self.methodsButton) + self.__profileOptionToolAction = self.addWidget(self.methodsButton) # TODO: add connection with the signal self.methodsButton.sigMethodChanged.connect(self.setProfileMethod) @@ -650,7 +673,7 @@ class ProfileToolBar(qt.QToolBar): if self._roiInfo is None: return - profile, area, profileName, xLabel = createProfile( + coords, profile, area, profileName, xLabel = createProfile( roiInfo=self._roiInfo, currentData=currentData, origin=origin, @@ -658,28 +681,25 @@ class ProfileToolBar(qt.QToolBar): lineWidth=self.lineWidthSpinBox.value(), method=method) - self.getProfilePlot().setGraphTitle(profileName) + profilePlot = self.getProfilePlot() + + profilePlot.setGraphTitle(profileName) + profilePlot.getXAxis().setLabel(xLabel) dataIs3D = len(currentData.shape) > 2 if dataIs3D: - self.getProfilePlot().addImage(profile, - legend=profileName, - xlabel=xLabel, - ylabel="Frame index (depth)", - colormap=colormap) + profileScale = (coords[-1] - coords[0]) / profile.shape[1], 1 + profilePlot.addImage(profile, + legend=profileName, + colormap=colormap, + origin=(coords[0], 0), + scale=profileScale) + profilePlot.getYAxis().setLabel("Frame index (depth)") else: - coords = numpy.arange(len(profile[0]), dtype=numpy.float32) - # Scale horizontal and vertical profile coordinates - if self._roiInfo[2] == 'X': - coords = coords * scale[0] + origin[0] - elif self._roiInfo[2] == 'Y': - coords = coords * scale[1] + origin[1] - - self.getProfilePlot().addCurve(coords, - profile[0], - legend=profileName, - xlabel=xLabel, - color=self.overlayColor) + profilePlot.addCurve(coords, + profile[0], + legend=profileName, + color=self.overlayColor) self.plot.addItem(area[0], area[1], legend=self._POLYGON_LEGEND, @@ -732,6 +752,9 @@ class ProfileToolBar(qt.QToolBar): def getProfileMethod(self): return self._method + def getProfileOptionToolAction(self): + return self.__profileOptionToolAction + class Profile3DToolBar(ProfileToolBar): def __init__(self, parent=None, stackview=None, diff --git a/silx/gui/plot/ScatterMaskToolsWidget.py b/silx/gui/plot/ScatterMaskToolsWidget.py index de645be..0c6797f 100644 --- a/silx/gui/plot/ScatterMaskToolsWidget.py +++ b/silx/gui/plot/ScatterMaskToolsWidget.py @@ -35,7 +35,7 @@ from __future__ import division __authors__ = ["P. Knobel"] __license__ = "MIT" -__date__ = "24/04/2018" +__date__ = "15/02/2019" import math @@ -152,6 +152,22 @@ class ScatterMask(BaseMask): stencil = (y - cy)**2 + (x - cx)**2 < radius**2 self.updateStencil(level, stencil, mask) + def updateEllipse(self, level, crow, ccol, radius_r, radius_c, mask=True): + """Mask/Unmask an ellipse of the given mask level. + + :param int level: Mask level to update. + :param int crow: Row of the center of the ellipse + :param int ccol: Column of the center of the ellipse + :param float radius_r: Radius of the ellipse in the row + :param float radius_c: Radius of the ellipse in the column + :param bool mask: True to mask (default), False to unmask. + """ + def is_inside(px, py): + return (px - ccol)**2 / radius_c**2 + (py - crow)**2 / radius_r**2 <= 1.0 + x, y = self._getXY() + indices_inside = [idx for idx in range(len(x)) if is_inside(x[idx], y[idx])] + self.updatePoints(level, indices_inside, mask) + def updateLine(self, level, y0, x0, y1, x1, width, mask=True): """Mask/Unmask points inside a rectangle defined by a line (two end points) and a width. @@ -490,26 +506,35 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget): level = self.levelSpinBox.value() - if (self._drawingMode == 'rectangle' and - event['event'] == 'drawingFinished'): - doMask = self._isMasking() + if self._drawingMode == 'rectangle': + if event['event'] == 'drawingFinished': + doMask = self._isMasking() + + self._mask.updateRectangle( + level, + y=event['y'], + x=event['x'], + height=abs(event['height']), + width=abs(event['width']), + mask=doMask) + self._mask.commit() - self._mask.updateRectangle( - level, - y=event['y'], - x=event['x'], - height=abs(event['height']), - width=abs(event['width']), - mask=doMask) - self._mask.commit() + elif self._drawingMode == 'ellipse': + if event['event'] == 'drawingFinished': + doMask = self._isMasking() + center = event['points'][0] + size = event['points'][1] + self._mask.updateEllipse(level, center[1], center[0], + size[1], size[0], doMask) + self._mask.commit() - elif (self._drawingMode == 'polygon' and - event['event'] == 'drawingFinished'): - doMask = self._isMasking() - vertices = event['points'] - vertices = vertices[:, (1, 0)] # (y, x) - self._mask.updatePolygon(level, vertices, doMask) - self._mask.commit() + elif self._drawingMode == 'polygon': + if event['event'] == 'drawingFinished': + doMask = self._isMasking() + vertices = event['points'] + vertices = vertices[:, (1, 0)] # (y, x) + self._mask.updatePolygon(level, vertices, doMask) + self._mask.commit() elif self._drawingMode == 'pencil': doMask = self._isMasking() @@ -536,6 +561,8 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget): self._lastPencilPos = None else: self._lastPencilPos = y, x + else: + _logger.error("Drawing mode %s unsupported", self._drawingMode) def _loadRangeFromColormapTriggered(self): """Set range from active scatter colormap range""" diff --git a/silx/gui/plot/ScatterView.py b/silx/gui/plot/ScatterView.py index ae79cf9..5fc66ef 100644 --- a/silx/gui/plot/ScatterView.py +++ b/silx/gui/plot/ScatterView.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2018 European Synchrotron Radiation Facility +# Copyright (c) 2018-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -353,3 +353,13 @@ class ScatterView(qt.QMainWindow): return self.getPlotWidget().resetZoom(*args, **kwargs) resetZoom.__doc__ = PlotWidget.resetZoom.__doc__ + + def getSelectionMask(self, *args, **kwargs): + return self.getMaskToolsWidget().getSelectionMask(*args, **kwargs) + + getSelectionMask.__doc__ = ScatterMaskToolsWidget.getSelectionMask.__doc__ + + def setSelectionMask(self, *args, **kwargs): + return self.getMaskToolsWidget().setSelectionMask(*args, **kwargs) + + setSelectionMask.__doc__ = ScatterMaskToolsWidget.setSelectionMask.__doc__ diff --git a/silx/gui/plot/StackView.py b/silx/gui/plot/StackView.py index 72b6cd4..2a3d7e8 100644 --- a/silx/gui/plot/StackView.py +++ b/silx/gui/plot/StackView.py @@ -89,14 +89,8 @@ from silx.utils.array_like import DatasetView, ListOfImages from silx.math import calibration from silx.utils.deprecation import deprecated_warning -try: - import h5py -except ImportError: - def is_dataset(obj): - return False - h5py = None -else: - from silx.io.utils import is_dataset +import h5py +from silx.io.utils import is_dataset _logger = logging.getLogger(__name__) diff --git a/silx/gui/plot/StatsWidget.py b/silx/gui/plot/StatsWidget.py index bb66613..4ba4fab 100644 --- a/silx/gui/plot/StatsWidget.py +++ b/silx/gui/plot/StatsWidget.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -31,552 +31,1266 @@ __license__ = "MIT" __date__ = "24/07/2018" -import functools +from collections import OrderedDict +from contextlib import contextmanager import logging +import weakref + import numpy -from collections import OrderedDict -import silx.utils.weakref from silx.gui import qt from silx.gui import icons -from silx.gui.plot.items.curve import Curve as CurveItem -from silx.gui.plot.items.histogram import Histogram as HistogramItem -from silx.gui.plot.items.image import ImageBase as ImageItem -from silx.gui.plot.items.scatter import Scatter as ScatterItem from silx.gui.plot import stats as statsmdl from silx.gui.widgets.TableWidget import TableWidget from silx.gui.plot.stats.statshandler import StatsHandler, StatFormatter +from silx.gui.plot.items.core import ItemChangedType +from silx.gui.widgets.FlowLayout import FlowLayout +from . import PlotWidget +from . import items as plotitems -logger = logging.getLogger(__name__) +_logger = logging.getLogger(__name__) -class StatsWidget(qt.QWidget): + +# Helper class to handle specific calls to PlotWidget and SceneWidget + +class _Wrapper(qt.QObject): + """Base class for connection with PlotWidget and SceneWidget. + + This class is used when no PlotWidget or SceneWidget is connected. + + :param plot: The plot to be used """ - Widget displaying a set of :class:`Stat` to be displayed on a - :class:`StatsTable` and to be apply on items contained in the :class:`Plot` - Also contains options to: - * compute statistics on all the data or on visible data only - * show statistics of all items or only the active one + sigItemAdded = qt.Signal(object) + """Signal emitted when a new item is added. - :param parent: Qt parent - :param plot: the plot containing items on which we want statistics. + It provides the added item. """ - sigVisibilityChanged = qt.Signal(bool) + sigItemRemoved = qt.Signal(object) + """Signal emitted when an item is (about to be) removed. - NUMBER_FORMAT = '{0:.3f}' + It provides the removed item. + """ - class OptionsWidget(qt.QToolBar): - - def __init__(self, parent=None): - qt.QToolBar.__init__(self, parent) - self.setIconSize(qt.QSize(16, 16)) - - action = qt.QAction(self) - action.setIcon(icons.getQIcon("stats-active-items")) - action.setText("Active items only") - action.setToolTip("Display stats for active items only.") - action.setCheckable(True) - action.setChecked(True) - self.__displayActiveItems = action - - action = qt.QAction(self) - action.setIcon(icons.getQIcon("stats-whole-items")) - action.setText("All items") - action.setToolTip("Display stats for all available items.") - action.setCheckable(True) - self.__displayWholeItems = action - - action = qt.QAction(self) - action.setIcon(icons.getQIcon("stats-visible-data")) - action.setText("Use the visible data range") - action.setToolTip("Use the visible data range.<br/>" - "If activated the data is filtered to only use" - "visible data of the plot." - "The filtering is a data sub-sampling." - "No interpolation is made to fit data to" - "boundaries.") - action.setCheckable(True) - self.__useVisibleData = action - - action = qt.QAction(self) - action.setIcon(icons.getQIcon("stats-whole-data")) - action.setText("Use the full data range") - action.setToolTip("Use the full data range.") - action.setCheckable(True) - action.setChecked(True) - self.__useWholeData = action - - self.addAction(self.__displayWholeItems) - self.addAction(self.__displayActiveItems) - self.addSeparator() - self.addAction(self.__useVisibleData) - self.addAction(self.__useWholeData) - - self.itemSelection = qt.QActionGroup(self) - self.itemSelection.setExclusive(True) - self.itemSelection.addAction(self.__displayActiveItems) - self.itemSelection.addAction(self.__displayWholeItems) - - self.dataRangeSelection = qt.QActionGroup(self) - self.dataRangeSelection.setExclusive(True) - self.dataRangeSelection.addAction(self.__useWholeData) - self.dataRangeSelection.addAction(self.__useVisibleData) - - def isActiveItemMode(self): - return self.itemSelection.checkedAction() is self.__displayActiveItems - - def isVisibleDataRangeMode(self): - return self.dataRangeSelection.checkedAction() is self.__useVisibleData + sigCurrentChanged = qt.Signal(object) + """Signal emitted when the current item has changed. - def __init__(self, parent=None, plot=None, stats=None): - qt.QWidget.__init__(self, parent) - self.setLayout(qt.QVBoxLayout()) - self.layout().setContentsMargins(0, 0, 0, 0) - self._options = self.OptionsWidget(parent=self) - self.layout().addWidget(self._options) - self._statsTable = StatsTable(parent=self, plot=plot) - self.setStats = self._statsTable.setStats - self.setStats(stats) + It provides the current item. + """ - self.layout().addWidget(self._statsTable) - self.setPlot = self._statsTable.setPlot + sigVisibleDataChanged = qt.Signal() + """Signal emitted when the visible data area has changed""" - self._options.itemSelection.triggered.connect( - self._optSelectionChanged) - self._options.dataRangeSelection.triggered.connect( - self._optDataRangeChanged) - self._optSelectionChanged() - self._optDataRangeChanged() + def __init__(self, plot=None): + super(_Wrapper, self).__init__(parent=None) + self._plotRef = None if plot is None else weakref.ref(plot) - self.setDisplayOnlyActiveItem = self._statsTable.setDisplayOnlyActiveItem - self.setStatsOnVisibleData = self._statsTable.setStatsOnVisibleData + def getPlot(self): + """Returns the plot attached to this widget""" + return None if self._plotRef is None else self._plotRef() - def showEvent(self, event): - self.sigVisibilityChanged.emit(True) - qt.QWidget.showEvent(self, event) + def getItems(self): + """Returns the list of items in the plot - def hideEvent(self, event): - self.sigVisibilityChanged.emit(False) - qt.QWidget.hideEvent(self, event) + :rtype: List[object] + """ + return () - def _optSelectionChanged(self, action=None): - self._statsTable.setDisplayOnlyActiveItem(self._options.isActiveItemMode()) + def getSelectedItems(self): + """Returns the list of selected items in the plot - def _optDataRangeChanged(self, action=None): - self._statsTable.setStatsOnVisibleData(self._options.isVisibleDataRangeMode()) + :rtype: List[object] + """ + return () + def setCurrentItem(self, item): + """Set the current/active item in the plot -class BasicStatsWidget(StatsWidget): + :param item: The plot item to set as active/current + """ + pass + + def getLabel(self, item): + """Returns the label of the given item. + + :param item: + :rtype: str + """ + return '' + + def getKind(self, item): + """Returns the kind of an item or None if not supported + + :param item: + :rtype: Union[str,None] + """ + return None + + +class _PlotWidgetWrapper(_Wrapper): + """Class handling PlotWidget specific calls and signal connections + + See :class:`._Wrapper` for documentation + + :param PlotWidget plot: """ - Widget defining a simple set of :class:`Stat` to be displayed on a - :class:`StatsWidget`. - :param parent: Qt parent - :param plot: the plot containing items on which we want statistics. + def __init__(self, plot): + assert isinstance(plot, PlotWidget) + super(_PlotWidgetWrapper, self).__init__(plot) + plot.sigItemAdded.connect(self.sigItemAdded.emit) + plot.sigItemAboutToBeRemoved.connect(self.sigItemRemoved.emit) + plot.sigActiveCurveChanged.connect(self._activeCurveChanged) + plot.sigActiveImageChanged.connect(self._activeImageChanged) + plot.sigActiveScatterChanged.connect(self._activeScatterChanged) + plot.sigPlotSignal.connect(self._limitsChanged) + + def _activeChanged(self, kind): + """Handle change of active curve/image/scatter""" + plot = self.getPlot() + if plot is not None: + item = plot._getActiveItem(kind=kind) + if item is None or self.getKind(item) is not None: + self.sigCurrentChanged.emit(item) + + def _activeCurveChanged(self, previous, current): + self._activeChanged(kind='curve') + + def _activeImageChanged(self, previous, current): + self._activeChanged(kind='image') + + def _activeScatterChanged(self, previous, current): + self._activeChanged(kind='scatter') + + def _limitsChanged(self, event): + """Handle change of plot area limits.""" + if event['event'] == 'limitsChanged': + self.sigVisibleDataChanged.emit() + + def getItems(self): + plot = self.getPlot() + return () if plot is None else plot._getItems() + + def getSelectedItems(self): + plot = self.getPlot() + items = [] + if plot is not None: + for kind in plot._ACTIVE_ITEM_KINDS: + item = plot._getActiveItem(kind=kind) + if item is not None: + items.append(item) + return tuple(items) + + def setCurrentItem(self, item): + plot = self.getPlot() + if plot is not None: + kind = self.getKind(item) + if kind in plot._ACTIVE_ITEM_KINDS: + if plot._getActiveItem(kind) != item: + plot._setActiveItem(kind, item.getLegend()) + + def getLabel(self, item): + return item.getLegend() + + def getKind(self, item): + if isinstance(item, plotitems.Curve): + return 'curve' + elif isinstance(item, plotitems.ImageData): + return 'image' + elif isinstance(item, plotitems.Scatter): + return 'scatter' + elif isinstance(item, plotitems.Histogram): + return 'histogram' + else: + return None + + +class _SceneWidgetWrapper(_Wrapper): + """Class handling SceneWidget specific calls and signal connections + + See :class:`._Wrapper` for documentation + + :param SceneWidget plot: """ - STATS = StatsHandler(( - (statsmdl.StatMin(), StatFormatter()), - statsmdl.StatCoordMin(), - (statsmdl.StatMax(), StatFormatter()), - statsmdl.StatCoordMax(), - (('std', numpy.std), StatFormatter()), - (('mean', numpy.mean), StatFormatter()), - statsmdl.StatCOM() - )) + def __init__(self, plot): + # Lazy-import to avoid circular imports + from ..plot3d.SceneWidget import SceneWidget - def __init__(self, parent=None, plot=None): - StatsWidget.__init__(self, parent=parent, plot=plot, stats=self.STATS) + assert isinstance(plot, SceneWidget) + super(_SceneWidgetWrapper, self).__init__(plot) + plot.getSceneGroup().sigItemAdded.connect(self.sigItemAdded) + plot.getSceneGroup().sigItemRemoved.connect(self.sigItemRemoved) + plot.selection().sigCurrentChanged.connect(self._currentChanged) + # sigVisibleDataChanged is never emitted + + def _currentChanged(self, current, previous): + self.sigCurrentChanged.emit(current) + + def getItems(self): + plot = self.getPlot() + return () if plot is None else tuple(plot.getSceneGroup().visit()) + + def getSelectedItems(self): + plot = self.getPlot() + return () if plot is None else (plot.selection().getCurrentItem(),) + def setCurrentItem(self, item): + plot = self.getPlot() + if plot is not None: + plot.selection().setCurrentItem(item) -class StatsTable(TableWidget): + def getLabel(self, item): + return item.getLabel() + + def getKind(self, item): + from ..plot3d import items as plot3ditems + + if isinstance(item, (plot3ditems.ImageData, + plot3ditems.ScalarField3D)): + return 'image' + elif isinstance(item, (plot3ditems.Scatter2D, + plot3ditems.Scatter3D)): + return 'scatter' + else: + return None + + +class _ScalarFieldViewWrapper(_Wrapper): + """Class handling ScalarFieldView specific calls and signal connections + + See :class:`._Wrapper` for documentation + + :param SceneWidget plot: """ - TableWidget displaying for each curves contained by the Plot some - information: - * legend - * minimal value - * maximal value - * standard deviation (std) + def __init__(self, plot): + # Lazy-import to avoid circular imports + from ..plot3d.ScalarFieldView import ScalarFieldView + from ..plot3d.items import ScalarField3D + + assert isinstance(plot, ScalarFieldView) + super(_ScalarFieldViewWrapper, self).__init__(plot) + self._item = ScalarField3D() + self._dataChanged() + plot.sigDataChanged.connect(self._dataChanged) + # sigItemAdded, sigItemRemoved, sigVisibleDataChanged are never emitted + + def _dataChanged(self): + plot = self.getPlot() + if plot is not None: + self._item.setData(plot.getData(copy=False), copy=False) + self.sigCurrentChanged.emit(self._item) - :param parent: The widget's parent. - :param plot: :class:`.PlotWidget` instance on which to operate + def getItems(self): + plot = self.getPlot() + return () if plot is None else (self._item,) + + def getSelectedItems(self): + return self.getItems() + + def setCurrentItem(self, item): + pass + + def getLabel(self, item): + return 'Data' + + def getKind(self, item): + return 'image' + + +class _Container(object): + """Class to contain a plot item. + + This is apparently needed for compatibility with PySide2, + + :param QObject obj: """ + def __init__(self, obj): + self._obj = obj - COMPATIBLE_KINDS = { - 'curve': CurveItem, - 'image': ImageItem, - 'scatter': ScatterItem, - 'histogram': HistogramItem - } + def __call__(self): + return self._obj - COMPATIBLE_ITEMS = tuple(COMPATIBLE_KINDS.values()) - def __init__(self, parent=None, plot=None): - TableWidget.__init__(self, parent) - """Next freeID for the curve""" - self.plot = None - self._displayOnlyActItem = False - self._statsOnVisibleData = False - self._lgdAndKindToItems = {} - """Associate to a tuple(legend, kind) the items legend""" - self.callbackImage = None - self.callbackScatter = None - self.callbackCurve = None - """Associate the curve legend to his first item""" +class _StatsWidgetBase(object): + """ + Base class for all widgets which want to display statistics + """ + def __init__(self, statsOnVisibleData, displayOnlyActItem): + self._displayOnlyActItem = displayOnlyActItem + self._statsOnVisibleData = statsOnVisibleData self._statsHandler = None - self._legendsSet = [] - """list of legends actually displayed""" - self._resetColumns() - self.setColumnCount(len(self._columns)) - self.setSelectionBehavior(qt.QAbstractItemView.SelectRows) - self.setPlot(plot) - self.setSortingEnabled(True) + self.__default_skipped_events = ( + ItemChangedType.ALPHA, + ItemChangedType.COLOR, + ItemChangedType.COLORMAP, + ItemChangedType.SYMBOL, + ItemChangedType.SYMBOL_SIZE, + ItemChangedType.LINE_WIDTH, + ItemChangedType.LINE_STYLE, + ItemChangedType.LINE_BG_COLOR, + ItemChangedType.FILL, + ItemChangedType.HIGHLIGHTED_COLOR, + ItemChangedType.HIGHLIGHTED_STYLE, + ItemChangedType.TEXT, + ItemChangedType.OVERLAY, + ItemChangedType.VISUALIZATION_MODE, + ) + + self._plotWrapper = _Wrapper() + self._dealWithPlotConnection(create=True) - def _resetColumns(self): - self._columns_index = OrderedDict([('legend', 0), ('kind', 1)]) - self._columns = self._columns_index.keys() - self.setColumnCount(len(self._columns)) + def setPlot(self, plot): + """Define the plot to interact with - def setStats(self, statsHandler): + :param Union[PlotWidget,SceneWidget,None] plot: + The plot containing the items on which statistics are applied """ + try: + import OpenGL + except ImportError: + has_opengl = False + else: + has_opengl = True + from ..plot3d.SceneWidget import SceneWidget # Lazy import + self._dealWithPlotConnection(create=False) + self.clear() + if plot is None: + self._plotWrapper = _Wrapper() + elif isinstance(plot, PlotWidget): + self._plotWrapper = _PlotWidgetWrapper(plot) + else: + if has_opengl is True: + if isinstance(plot, SceneWidget): + self._plotWrapper = _SceneWidgetWrapper(plot) + else: # Expect a ScalarFieldView + self._plotWrapper = _ScalarFieldViewWrapper(plot) + else: + _logger.warning('OpenGL not installed, %s not managed' % ('SceneWidget qnd ScalarFieldView')) + self._dealWithPlotConnection(create=True) + + def setStats(self, statsHandler): + """Set which stats to display and the associated formatting. - :param statsHandler: Set the statistics to be displayed and how to - format them using - :rtype: :class:`StatsHandler` + :param StatsHandler statsHandler: + Set the statistics to be displayed and how to format them using """ - _statsHandler = statsHandler if statsHandler is None: - _statsHandler = StatsHandler(statFormatters=()) - if isinstance(_statsHandler, (list, tuple)): - _statsHandler = StatsHandler(_statsHandler) - assert isinstance(_statsHandler, StatsHandler) - self._resetColumns() - self.clear() - - for statName, stat in list(_statsHandler.stats.items()): - assert isinstance(stat, statsmdl.StatBase) - self._columns_index[statName] = len(self._columns_index) - self._statsHandler = _statsHandler - self._columns = self._columns_index.keys() - self.setColumnCount(len(self._columns)) + statsHandler = StatsHandler(statFormatters=()) + elif isinstance(statsHandler, (list, tuple)): + statsHandler = StatsHandler(statsHandler) + assert isinstance(statsHandler, StatsHandler) - self._updateItemObserve() - self._updateAllStats() + self._statsHandler = statsHandler def getStatsHandler(self): + """Returns the :class:`StatsHandler` in use. + + :rtype: StatsHandler + """ return self._statsHandler - def _updateAllStats(self): - for (legend, kind) in self._lgdAndKindToItems: - self._updateStats(legend, kind) + def getPlot(self): + """Returns the plot attached to this widget - @staticmethod - def _getKind(myItem): - if isinstance(myItem, CurveItem): - return 'curve' - elif isinstance(myItem, ImageItem): - return 'image' - elif isinstance(myItem, ScatterItem): - return 'scatter' - elif isinstance(myItem, HistogramItem): - return 'histogram' + :rtype: Union[PlotWidget,SceneWidget,None] + """ + return self._plotWrapper.getPlot() + + def _dealWithPlotConnection(self, create=True): + """Manage connection to plot signals + + Note: connection on Item are managed by _addItem and _removeItem methods + """ + connections = [] # List of (signal, slot) to connect/disconnect + if self._statsOnVisibleData: + connections.append( + (self._plotWrapper.sigVisibleDataChanged, self._updateAllStats)) + + if self._displayOnlyActItem: + connections.append( + (self._plotWrapper.sigCurrentChanged, self._updateItemObserve)) else: - return None + connections += [ + (self._plotWrapper.sigItemAdded, self._addItem), + (self._plotWrapper.sigItemRemoved, self._removeItem), + (self._plotWrapper.sigCurrentChanged, self._plotCurrentChanged)] + + for signal, slot in connections: + if create: + signal.connect(slot) + else: + signal.disconnect(slot) - def setPlot(self, plot): + def _updateItemObserve(self, *args): + """Reload table depending on mode""" + raise NotImplementedError('Base class') + + def _updateStats(self, item): + """Update displayed information for given plot item + + :param item: The plot item + """ + raise NotImplementedError('Base class') + + def _updateAllStats(self): + """Update stats for all rows in the table""" + raise NotImplementedError('Base class') + + def setDisplayOnlyActiveItem(self, displayOnlyActItem): + """Toggle display off all items or only the active/selected one + + :param bool displayOnlyActItem: + True if we want to only show active item """ - Define the plot to interact with + self._displayOnlyActItem = displayOnlyActItem + + def setStatsOnVisibleData(self, b): + """Toggle computation of statistics on whole data or only visible ones. + + .. warning:: When visible data is activated we will process to a simple + filtering of visible data by the user. The filtering is a + simple data sub-sampling. No interpolation is made to fit + data to boundaries. - :param plot: the plot containing the items on which statistics are - applied - :rtype: :class:`.PlotWidget` + :param bool b: True if we want to apply statistics only on visible data """ - if self.plot: + if self._statsOnVisibleData != b: self._dealWithPlotConnection(create=False) - self.plot = plot - self.clear() - if self.plot: + self._statsOnVisibleData = b self._dealWithPlotConnection(create=True) - self._updateItemObserve() + self._updateAllStats() - def _updateItemObserve(self): - if self.plot: - self.clear() - if self._displayOnlyActItem is True: - activeCurve = self.plot.getActiveCurve(just_legend=False) - activeScatter = self.plot._getActiveItem(kind='scatter', - just_legend=False) - activeImage = self.plot.getActiveImage(just_legend=False) - if activeCurve: - self._addItem(activeCurve) - if activeImage: - self._addItem(activeImage) - if activeScatter: - self._addItem(activeScatter) - else: - [self._addItem(curve) for curve in self.plot.getAllCurves()] - [self._addItem(image) for image in self.plot.getAllImages()] - scatters = self.plot._getItems(kind='scatter', - just_legend=False, - withhidden=True) - [self._addItem(scatter) for scatter in scatters] - histograms = self.plot._getItems(kind='histogram', - just_legend=False, - withhidden=True) - [self._addItem(histogram) for histogram in histograms] + def _addItem(self, item): + """Add a plot item to the table - def _dealWithPlotConnection(self, create=True): + If item is not supported, it is ignored. + + :param item: The plot item + :returns: True if the item is added to the widget. + :rtype: bool """ - Manage connection to plot signals + raise NotImplementedError('Base class') - Note: connection on Item are managed by the _removeItem function + def _removeItem(self, item): + """Remove table items corresponding to given plot item from the table. + + :param item: The plot item """ - if self.plot is None: - return - if self._displayOnlyActItem: - if create is True: - if self.callbackImage is None: - self.callbackImage = functools.partial(self._activeItemChanged, 'image') - self.callbackScatter = functools.partial(self._activeItemChanged, 'scatter') - self.callbackCurve = functools.partial(self._activeItemChanged, 'curve') - self.plot.sigActiveImageChanged.connect(self.callbackImage) - self.plot.sigActiveScatterChanged.connect(self.callbackScatter) - self.plot.sigActiveCurveChanged.connect(self.callbackCurve) - else: - if self.callbackImage is not None: - self.plot.sigActiveImageChanged.disconnect(self.callbackImage) - self.plot.sigActiveScatterChanged.disconnect(self.callbackScatter) - self.plot.sigActiveCurveChanged.disconnect(self.callbackCurve) - self.callbackImage = None - self.callbackScatter = None - self.callbackCurve = None - else: - if create is True: - self.plot.sigContentChanged.connect(self._plotContentChanged) - else: - self.plot.sigContentChanged.disconnect(self._plotContentChanged) - if create is True: - self.plot.sigPlotSignal.connect(self._zoomPlotChanged) - else: - self.plot.sigPlotSignal.disconnect(self._zoomPlotChanged) + raise NotImplementedError('Base class') + + def _plotCurrentChanged(self, current): + """Handle change of current item and update selection in table + + :param current: + """ + raise NotImplementedError('Base class') def clear(self): + """clear GUI""" + pass + + def _skipPlotItemChangedEvent(self, event): """ - Clear all existing items + + :param ItemChangedtype event: event to filter or not + :return: True if we want to ignore this ItemChangedtype + :rtype: bool """ - lgdsAndKinds = list(self._lgdAndKindToItems.keys()) - for lgdAndKind in lgdsAndKinds: - self._removeItem(legend=lgdAndKind[0], kind=lgdAndKind[1]) - self._lgdAndKindToItems = {} - qt.QTableWidget.clear(self) + return event in self.__default_skipped_events + + +class StatsTable(_StatsWidgetBase, TableWidget): + """ + TableWidget displaying for each curves contained by the Plot some + information: + + * legend + * minimal value + * maximal value + * standard deviation (std) + + :param QWidget parent: The widget's parent. + :param Union[PlotWidget,SceneWidget] plot: + :class:`PlotWidget` or :class:`SceneWidget` instance on which to operate + """ + + _LEGEND_HEADER_DATA = 'legend' + _KIND_HEADER_DATA = 'kind' + + def __init__(self, parent=None, plot=None): + TableWidget.__init__(self, parent) + _StatsWidgetBase.__init__(self, statsOnVisibleData=False, + displayOnlyActItem=False) + + # Init for _displayOnlyActItem == False + assert self._displayOnlyActItem is False + self.setSelectionBehavior(qt.QAbstractItemView.SelectRows) + self.setSelectionMode(qt.QAbstractItemView.SingleSelection) + self.currentItemChanged.connect(self._currentItemChanged) + self.setRowCount(0) + self.setColumnCount(2) - # It have to called befor3e accessing to the header items - self.setHorizontalHeaderLabels(list(self._columns)) - - if self._statsHandler is not None: - for columnId, name in enumerate(self._columns): - item = self.horizontalHeaderItem(columnId) - if name in self._statsHandler.stats: - stat = self._statsHandler.stats[name] - text = stat.name[0].upper() + stat.name[1:] - if stat.description is not None: - tooltip = stat.description - else: - tooltip = "" - else: - text = name[0].upper() + name[1:] - tooltip = "" - item.setToolTip(tooltip) - item.setText(text) + # Init headers + headerItem = qt.QTableWidgetItem('Legend') + headerItem.setData(qt.Qt.UserRole, self._LEGEND_HEADER_DATA) + self.setHorizontalHeaderItem(0, headerItem) + headerItem = qt.QTableWidgetItem('Kind') + headerItem.setData(qt.Qt.UserRole, self._KIND_HEADER_DATA) + self.setHorizontalHeaderItem(1, headerItem) + + self.setSortingEnabled(True) + self.setPlot(plot) - if hasattr(self.horizontalHeader(), 'setSectionResizeMode'): # Qt5 - self.horizontalHeader().setSectionResizeMode(qt.QHeaderView.ResizeToContents) + @contextmanager + def _disableSorting(self): + """Context manager that disables table sorting + + Previous state is restored when leaving + """ + sorting = self.isSortingEnabled() + if sorting: + self.setSortingEnabled(False) + yield + if sorting: + self.setSortingEnabled(sorting) + + def setStats(self, statsHandler): + """Set which stats to display and the associated formatting. + + :param StatsHandler statsHandler: + Set the statistics to be displayed and how to format them using + """ + self._removeAllItems() + _StatsWidgetBase.setStats(self, statsHandler) + + self.setRowCount(0) + self.setColumnCount(len(self._statsHandler.stats) + 2) # + legend and kind + + for index, stat in enumerate(self._statsHandler.stats.values()): + headerItem = qt.QTableWidgetItem(stat.name.capitalize()) + headerItem.setData(qt.Qt.UserRole, stat.name) + if stat.description is not None: + headerItem.setToolTip(stat.description) + self.setHorizontalHeaderItem(2 + index, headerItem) + + horizontalHeader = self.horizontalHeader() + if hasattr(horizontalHeader, 'setSectionResizeMode'): # Qt5 + horizontalHeader.setSectionResizeMode(qt.QHeaderView.ResizeToContents) else: # Qt4 - self.horizontalHeader().setResizeMode(qt.QHeaderView.ResizeToContents) - self.setColumnHidden(self._columns_index['kind'], True) + horizontalHeader.setResizeMode(qt.QHeaderView.ResizeToContents) - def _addItem(self, item): - assert isinstance(item, self.COMPATIBLE_ITEMS) - if (item.getLegend(), self._getKind(item)) in self._lgdAndKindToItems: - self._updateStats(item.getLegend(), self._getKind(item)) - return + self._updateItemObserve() + + def setPlot(self, plot): + """Define the plot to interact with + + :param Union[PlotWidget,SceneWidget,None] plot: + The plot containing the items on which statistics are applied + """ + _StatsWidgetBase.setPlot(self, plot) + self._updateItemObserve() + + def clear(self): + """Define the plot to interact with + + :param Union[PlotWidget,SceneWidget,None] plot: + The plot containing the items on which statistics are applied + """ + self._removeAllItems() + + def _updateItemObserve(self, *args): + """Reload table depending on mode""" + self._removeAllItems() + + # Get selected or all items from the plot + if self._displayOnlyActItem: # Only selected + items = self._plotWrapper.getSelectedItems() + else: # All items + items = self._plotWrapper.getItems() + + # Add items to the plot + for item in items: + self._addItem(item) + + def _plotCurrentChanged(self, current): + """Handle change of current item and update selection in table - self.setRowCount(self.rowCount() + 1) - indexTable = self.rowCount() - 1 - kind = self._getKind(item) - - self._lgdAndKindToItems[(item.getLegend(), kind)] = {} - - # the get item will manage the item creation of not existing - _createItem = self._getItem - for itemName in self._columns: - _createItem(name=itemName, legend=item.getLegend(), kind=kind, - indexTable=indexTable) - - self._updateStats(legend=item.getLegend(), kind=kind) - - callback = functools.partial( - silx.utils.weakref.WeakMethodProxy(self._updateStats), - item.getLegend(), kind) - item.sigItemChanged.connect(callback) - self.setColumnHidden(self._columns_index['kind'], - item.getLegend() not in self._legendsSet) - self._legendsSet.append(item.getLegend()) - - def _getItem(self, name, legend, kind, indexTable): - if (legend, kind) not in self._lgdAndKindToItems: - self._lgdAndKindToItems[(legend, kind)] = {} - if not (name in self._lgdAndKindToItems[(legend, kind)] and - self._lgdAndKindToItems[(legend, kind)]): - if name in ('legend', 'kind'): - _item = qt.QTableWidgetItem(type=qt.QTableWidgetItem.Type) - if name == 'legend': - _item.setText(legend) + :param current: + """ + row = self._itemToRow(current) + if row is None: + if self.currentRow() >= 0: + self.setCurrentCell(-1, -1) + elif row != self.currentRow(): + self.setCurrentCell(row, 0) + + def _tableItemToItem(self, tableItem): + """Find the plot item corresponding to a table item + + :param QTableWidgetItem tableItem: + :rtype: QObject + """ + container = tableItem.data(qt.Qt.UserRole) + return container() + + def _itemToRow(self, item): + """Find the row corresponding to a plot item + + :param item: The plot item + :return: The corresponding row index + :rtype: Union[int,None] + """ + for row in range(self.rowCount()): + tableItem = self.item(row, 0) + if self._tableItemToItem(tableItem) == item: + return row + return None + + def _itemToTableItems(self, item): + """Find all table items corresponding to a plot item + + :param item: The plot item + :return: An ordered dict of column name to QTableWidgetItem mapping + for the given plot item. + :rtype: OrderedDict + """ + result = OrderedDict() + row = self._itemToRow(item) + if row is not None: + for column in range(self.columnCount()): + tableItem = self.item(row, column) + if self._tableItemToItem(tableItem) != item: + _logger.error("Table item/plot item mismatch") else: - assert name == 'kind' - _item.setText(kind) + header = self.horizontalHeaderItem(column) + name = header.data(qt.Qt.UserRole) + result[name] = tableItem + return result + + def _plotItemChanged(self, event): + """Handle modifications of the items. + + :param event: + """ + if self._skipPlotItemChangedEvent(event) is True: + return + else: + item = self.sender() + self._updateStats(item) + + def _addItem(self, item): + """Add a plot item to the table + + If item is not supported, it is ignored. + + :param item: The plot item + :returns: True if the item is added to the widget. + :rtype: bool + """ + if self._itemToRow(item) is not None: + _logger.info("Item already present in the table") + self._updateStats(item) + return True + + kind = self._plotWrapper.getKind(item) + if kind not in statsmdl.BASIC_COMPATIBLE_KINDS: + _logger.info("Item has not a supported type: %s", item) + return False + + # Prepare table items + tableItems = [ + qt.QTableWidgetItem(), # Legend + qt.QTableWidgetItem()] # Kind + + for column in range(2, self.columnCount()): + header = self.horizontalHeaderItem(column) + name = header.data(qt.Qt.UserRole) + + formatter = self._statsHandler.formatters[name] + if formatter: + tableItem = formatter.tabWidgetItemClass() else: - if self._statsHandler.formatters[name]: - _item = self._statsHandler.formatters[name].tabWidgetItemClass() - else: - _item = qt.QTableWidgetItem() - tooltip = self._statsHandler.stats[name].getToolTip(kind=kind) - if tooltip is not None: - _item.setToolTip(tooltip) + tableItem = qt.QTableWidgetItem() - _item.setFlags(qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable) - self.setItem(indexTable, self._columns_index[name], _item) - self._lgdAndKindToItems[(legend, kind)][name] = _item + tooltip = self._statsHandler.stats[name].getToolTip(kind=kind) + if tooltip is not None: + tableItem.setToolTip(tooltip) - return self._lgdAndKindToItems[(legend, kind)][name] + tableItems.append(tableItem) - def _removeItem(self, legend, kind): - if (legend, kind) not in self._lgdAndKindToItems or not self.plot: - return + # Disable sorting while adding table items + with self._disableSorting(): + # Add a row to the table + self.setRowCount(self.rowCount() + 1) - self.firstItem = self._lgdAndKindToItems[(legend, kind)]['legend'] - del self._lgdAndKindToItems[(legend, kind)] - self.removeRow(self.firstItem.row()) - self._legendsSet.remove(legend) - self.setColumnHidden(self._columns_index['kind'], - legend not in self._legendsSet) + # Add table items to the last row + row = self.rowCount() - 1 + for column, tableItem in enumerate(tableItems): + tableItem.setData(qt.Qt.UserRole, _Container(item)) + tableItem.setFlags( + qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable) + self.setItem(row, column, tableItem) - def _updateCurrentStats(self): - for lgdAndKind in self._lgdAndKindToItems: - self._updateStats(lgdAndKind[0], lgdAndKind[1]) + # Update table items content + self._updateStats(item) - def _updateStats(self, legend, kind, event=None): - if self._statsHandler is None: + # Listen for item changes + # Using queued connection to avoid issue with sender + # being that of the signal calling the signal + item.sigItemChanged.connect(self._plotItemChanged, + qt.Qt.QueuedConnection) + + return True + + def _removeItem(self, item): + """Remove table items corresponding to given plot item from the table. + + :param item: The plot item + """ + row = self._itemToRow(item) + if row is None: + kind = self._plotWrapper.getKind(item) + if kind in statsmdl.BASIC_COMPATIBLE_KINDS: + _logger.error("Removing item that is not in table: %s", str(item)) return + item.sigItemChanged.disconnect(self._plotItemChanged) + self.removeRow(row) + + def _removeAllItems(self): + """Remove content of the table""" + for row in range(self.rowCount()): + tableItem = self.item(row, 0) + item = self._tableItemToItem(tableItem) + item.sigItemChanged.disconnect(self._plotItemChanged) + self.clearContents() + self.setRowCount(0) - assert kind in ('curve', 'image', 'scatter', 'histogram') - if kind == 'curve': - item = self.plot.getCurve(legend) - elif kind == 'image': - item = self.plot.getImage(legend) - elif kind == 'scatter': - item = self.plot.getScatter(legend) - elif kind == 'histogram': - item = self.plot.getHistogram(legend) - else: - raise ValueError('kind not managed') + def _updateStats(self, item): + """Update displayed information for given plot item - if not item or (item.getLegend(), kind) not in self._lgdAndKindToItems: + :param item: The plot item + """ + if item is None: + return + plot = self.getPlot() + if plot is None: + _logger.info("Plot not available") return - assert isinstance(item, self.COMPATIBLE_ITEMS) - - statsValDict = self._statsHandler.calculate(item, self.plot, - self._statsOnVisibleData) - - lgdItem = self._lgdAndKindToItems[(item.getLegend(), kind)]['legend'] - assert lgdItem - rowStat = lgdItem.row() - - for statName, statVal in list(statsValDict.items()): - assert statName in self._lgdAndKindToItems[(item.getLegend(), kind)] - tableItem = self._getItem(name=statName, legend=item.getLegend(), - kind=kind, indexTable=rowStat) - tableItem.setText(str(statVal)) - - def currentChanged(self, current, previous): - if current.row() >= 0: - legendItem = self.item(current.row(), self._columns_index['legend']) - assert legendItem - kindItem = self.item(current.row(), self._columns_index['kind']) - kind = kindItem.text() - if kind == 'curve': - self.plot.setActiveCurve(legendItem.text()) - elif kind == 'image': - self.plot.setActiveImage(legendItem.text()) - elif kind == 'scatter': - self.plot._setActiveItem('scatter', legendItem.text()) - elif kind == 'histogram': - # active histogram not managed by the plot actually - pass - else: - raise ValueError('kind not managed') - qt.QTableWidget.currentChanged(self, current, previous) + row = self._itemToRow(item) + if row is None: + _logger.error("This item is not in the table: %s", str(item)) + return - def setDisplayOnlyActiveItem(self, displayOnlyActItem): + statsHandler = self.getStatsHandler() + if statsHandler is not None: + stats = statsHandler.calculate( + item, plot, self._statsOnVisibleData) + else: + stats = {} + + with self._disableSorting(): + for name, tableItem in self._itemToTableItems(item).items(): + if name == self._LEGEND_HEADER_DATA: + text = self._plotWrapper.getLabel(item) + tableItem.setText(text) + elif name == self._KIND_HEADER_DATA: + tableItem.setText(self._plotWrapper.getKind(item)) + else: + value = stats.get(name) + if value is None: + _logger.error("Value not found for: %s", name) + tableItem.setText('-') + else: + tableItem.setText(str(value)) + + def _updateAllStats(self): + """Update stats for all rows in the table""" + with self._disableSorting(): + for row in range(self.rowCount()): + tableItem = self.item(row, 0) + item = self._tableItemToItem(tableItem) + self._updateStats(item) + + def _currentItemChanged(self, current, previous): + """Handle change of selection in table and sync plot selection + + :param QTableWidgetItem current: + :param QTableWidgetItem previous: """ + if current and current.row() >= 0: + item = self._tableItemToItem(current) + self._plotWrapper.setCurrentItem(item) - :param bool displayOnlyActItem: True if we want to only show active - item + def setDisplayOnlyActiveItem(self, displayOnlyActItem): + """Toggle display off all items or only the active/selected one + + :param bool displayOnlyActItem: + True if we want to only show active item """ if self._displayOnlyActItem == displayOnlyActItem: return - self._displayOnlyActItem = displayOnlyActItem self._dealWithPlotConnection(create=False) + if not self._displayOnlyActItem: + self.currentItemChanged.disconnect(self._currentItemChanged) + + _StatsWidgetBase.setDisplayOnlyActiveItem(self, displayOnlyActItem) + self._updateItemObserve() self._dealWithPlotConnection(create=True) + if not self._displayOnlyActItem: + self.currentItemChanged.connect(self._currentItemChanged) + self.setSelectionMode(qt.QAbstractItemView.SingleSelection) + else: + self.setSelectionMode(qt.QAbstractItemView.NoSelection) + + +class _OptionsWidget(qt.QToolBar): + + def __init__(self, parent=None): + qt.QToolBar.__init__(self, parent) + self.setIconSize(qt.QSize(16, 16)) + + action = qt.QAction(self) + action.setIcon(icons.getQIcon("stats-active-items")) + action.setText("Active items only") + action.setToolTip("Display stats for active items only.") + action.setCheckable(True) + action.setChecked(True) + self.__displayActiveItems = action + + action = qt.QAction(self) + action.setIcon(icons.getQIcon("stats-whole-items")) + action.setText("All items") + action.setToolTip("Display stats for all available items.") + action.setCheckable(True) + self.__displayWholeItems = action + + action = qt.QAction(self) + action.setIcon(icons.getQIcon("stats-visible-data")) + action.setText("Use the visible data range") + action.setToolTip("Use the visible data range.<br/>" + "If activated the data is filtered to only use" + "visible data of the plot." + "The filtering is a data sub-sampling." + "No interpolation is made to fit data to" + "boundaries.") + action.setCheckable(True) + self.__useVisibleData = action + + action = qt.QAction(self) + action.setIcon(icons.getQIcon("stats-whole-data")) + action.setText("Use the full data range") + action.setToolTip("Use the full data range.") + action.setCheckable(True) + action.setChecked(True) + self.__useWholeData = action + + self.addAction(self.__displayWholeItems) + self.addAction(self.__displayActiveItems) + self.addSeparator() + self.addAction(self.__useVisibleData) + self.addAction(self.__useWholeData) + + self.itemSelection = qt.QActionGroup(self) + self.itemSelection.setExclusive(True) + self.itemSelection.addAction(self.__displayActiveItems) + self.itemSelection.addAction(self.__displayWholeItems) + + self.dataRangeSelection = qt.QActionGroup(self) + self.dataRangeSelection.setExclusive(True) + self.dataRangeSelection.addAction(self.__useWholeData) + self.dataRangeSelection.addAction(self.__useVisibleData) + + def isActiveItemMode(self): + return self.itemSelection.checkedAction() is self.__displayActiveItems + + def isVisibleDataRangeMode(self): + return self.dataRangeSelection.checkedAction() is self.__useVisibleData + + def setVisibleDataRangeModeEnabled(self, enabled): + """Enable/Disable the visible data range mode + + :param bool enabled: True to allow user to choose + stats on visible data + """ + self.__useVisibleData.setEnabled(enabled) + if not enabled: + self.__useWholeData.setChecked(True) + + +class StatsWidget(qt.QWidget): + """ + Widget displaying a set of :class:`Stat` to be displayed on a + :class:`StatsTable` and to be apply on items contained in the :class:`Plot` + Also contains options to: + + * compute statistics on all the data or on visible data only + * show statistics of all items or only the active one + + :param QWidget parent: Qt parent + :param Union[PlotWidget,SceneWidget] plot: + The plot containing items on which we want statistics. + :param StatsHandler stats: + Set the statistics to be displayed and how to format them using + """ + + sigVisibilityChanged = qt.Signal(bool) + """Signal emitted when the visibility of this widget changes. + + It Provides the visibility of the widget. + """ + + NUMBER_FORMAT = '{0:.3f}' + + def __init__(self, parent=None, plot=None, stats=None): + qt.QWidget.__init__(self, parent) + self.setLayout(qt.QVBoxLayout()) + self.layout().setContentsMargins(0, 0, 0, 0) + self._options = _OptionsWidget(parent=self) + self.layout().addWidget(self._options) + self._statsTable = StatsTable(parent=self, plot=plot) + self.setStats(stats) + + self.layout().addWidget(self._statsTable) + + self._options.itemSelection.triggered.connect( + self._optSelectionChanged) + self._options.dataRangeSelection.triggered.connect( + self._optDataRangeChanged) + self._optSelectionChanged() + self._optDataRangeChanged() + + def _getStatsTable(self): + """Returns the :class:`StatsTable` used by this widget. + + :rtype: StatsTable + """ + return self._statsTable + + def showEvent(self, event): + self.sigVisibilityChanged.emit(True) + qt.QWidget.showEvent(self, event) + + def hideEvent(self, event): + self.sigVisibilityChanged.emit(False) + qt.QWidget.hideEvent(self, event) + + def _optSelectionChanged(self, action=None): + self._getStatsTable().setDisplayOnlyActiveItem( + self._options.isActiveItemMode()) + + def _optDataRangeChanged(self, action=None): + self._getStatsTable().setStatsOnVisibleData( + self._options.isVisibleDataRangeMode()) + + # Proxy methods + + def setStats(self, statsHandler): + return self._getStatsTable().setStats(statsHandler=statsHandler) + + setStats.__doc__ = StatsTable.setStats.__doc__ + + def setPlot(self, plot): + self._options.setVisibleDataRangeModeEnabled( + plot is None or isinstance(plot, PlotWidget)) + return self._getStatsTable().setPlot(plot=plot) + + setPlot.__doc__ = StatsTable.setPlot.__doc__ + + def getPlot(self): + return self._getStatsTable().getPlot() + + getPlot.__doc__ = StatsTable.getPlot.__doc__ + + def setDisplayOnlyActiveItem(self, displayOnlyActItem): + return self._getStatsTable().setDisplayOnlyActiveItem( + displayOnlyActItem=displayOnlyActItem) + + setDisplayOnlyActiveItem.__doc__ = StatsTable.setDisplayOnlyActiveItem.__doc__ + def setStatsOnVisibleData(self, b): + return self._getStatsTable().setStatsOnVisibleData(b=b) + + setStatsOnVisibleData.__doc__ = StatsTable.setStatsOnVisibleData.__doc__ + + +DEFAULT_STATS = StatsHandler(( + (statsmdl.StatMin(), StatFormatter()), + statsmdl.StatCoordMin(), + (statsmdl.StatMax(), StatFormatter()), + statsmdl.StatCoordMax(), + statsmdl.StatCOM(), + (('mean', numpy.mean), StatFormatter()), + (('std', numpy.std), StatFormatter()), +)) + + +class BasicStatsWidget(StatsWidget): + """ + Widget defining a simple set of :class:`Stat` to be displayed on a + :class:`StatsWidget`. + + :param QWidget parent: Qt parent + :param PlotWidget plot: + The plot containing items on which we want statistics. + :param StatsHandler stats: + Set the statistics to be displayed and how to format them using + + .. snapshotqt:: img/BasicStatsWidget.png + :width: 300px + :align: center + + from silx.gui.plot import Plot1D + from silx.gui.plot.StatsWidget import BasicStatsWidget + + plot = Plot1D() + x = range(100) + y = x + plot.addCurve(x, y, legend='curve_0') + plot.setActiveCurve('curve_0') + + widget = BasicStatsWidget(plot=plot) + widget.show() + """ + def __init__(self, parent=None, plot=None): + StatsWidget.__init__(self, parent=parent, plot=plot, + stats=DEFAULT_STATS) + + +class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget): + """ + Widget made to display stats into a QLayout with for all stat a couple + (QLabel, QLineEdit) created. + The the layout can be defined prior of adding any statistic. + + :param QWidget parent: Qt parent + :param Union[PlotWidget,SceneWidget] plot: + The plot containing items on which we want statistics. + :param str kind: the kind of plotitems we want to display + :param StatsHandler stats: + Set the statistics to be displayed and how to format them using + :param bool statsOnVisibleData: compute statistics for the whole data or + only visible ones. + """ + + def __init__(self, parent=None, plot=None, kind='curve', stats=None, + statsOnVisibleData=False): + self._item_kind = kind + """The item displayed""" + self._statQlineEdit = {} + """list of legends actually displayed""" + self._n_statistics_per_line = 4 + """number of statistics displayed per line in the grid layout""" + qt.QWidget.__init__(self, parent) + _StatsWidgetBase.__init__(self, + statsOnVisibleData=statsOnVisibleData, + displayOnlyActItem=True) + self.setLayout(self._createLayout()) + self.setPlot(plot) + if stats is not None: + self.setStats(stats) + + def _addItemForStatistic(self, statistic): + assert isinstance(statistic, statsmdl.StatBase) + assert statistic.name in self._statsHandler.stats + + self.layout().setSpacing(2) + self.layout().setContentsMargins(2, 2, 2, 2) + + if isinstance(self.layout(), qt.QGridLayout): + parent = self + else: + widget = qt.QWidget(parent=self) + parent = widget + + qLabel = qt.QLabel(statistic.name + ':', parent=parent) + qLineEdit = qt.QLineEdit('', parent=parent) + qLineEdit.setReadOnly(True) + + self._addStatsWidgetsToLayout(qLabel=qLabel, qLineEdit=qLineEdit) + self._statQlineEdit[statistic.name] = qLineEdit + + def setPlot(self, plot): + """Define the plot to interact with + + :param Union[PlotWidget,SceneWidget,None] plot: + The plot containing the items on which statistics are applied """ - .. warning:: When visible data is activated we will process to a simple - filtering of visible data by the user. The filtering is a - simple data sub-sampling. No interpolation is made to fit - data to boundaries. + _StatsWidgetBase.setPlot(self, plot) + self._updateAllStats() - :param bool b: True if we want to apply statistics only on visible data + def _addStatsWidgetsToLayout(self, qLabel, qLineEdit): + raise NotImplementedError('Base class') + + def setStats(self, statsHandler): + """Set which stats to display and the associated formatting. + :param StatsHandler statsHandler: + Set the statistics to be displayed and how to format them using """ - if self._statsOnVisibleData != b: - self._statsOnVisibleData = b - self._updateCurrentStats() + _StatsWidgetBase.setStats(self, statsHandler) + for statName, stat in list(self._statsHandler.stats.items()): + self._addItemForStatistic(stat) + self._updateAllStats() def _activeItemChanged(self, kind, previous, current): - """Callback used when plotting only the active item""" - assert kind in ('curve', 'image', 'scatter', 'histogram') - self._updateItemObserve() + if kind == self._item_kind: + self._updateAllStats() - def _plotContentChanged(self, action, kind, legend): - """Callback used when plotting all the plot items""" - if kind not in ('curve', 'image', 'scatter', 'histogram'): - return - if kind == 'curve': - item = self.plot.getCurve(legend) - elif kind == 'image': - item = self.plot.getImage(legend) - elif kind == 'scatter': - item = self.plot.getScatter(legend) - elif kind == 'histogram': - item = self.plot.getHistogram(legend) - else: - raise ValueError('kind not managed') + def _updateAllStats(self): + plot = self.getPlot() + if plot is not None: + _items = self._plotWrapper.getSelectedItems() + def kind_filter(_item): + return self._plotWrapper.getKind(_item) == self.getKind() + + items = list(filter(kind_filter, _items)) + assert len(items) in (0, 1) + if len(items) is 1: + self._setItem(items[0]) + + def setKind(self, kind): + """Change the kind of active item to display + :param str kind: kind of item to display information for ('curve' ...) + """ + if self._item_kind != kind: + self._item_kind = kind + self._updateItemObserve() - if action == 'add': - if item is None: - raise ValueError('Item from legend "%s" do not exists' % legend) - self._addItem(item) - elif action == 'remove': - self._removeItem(legend, kind) + def getKind(self): + """ + :return: kind of item we want to compute statistic for + :rtype: str + """ + return self._item_kind + + def _setItem(self, item): + if item is None: + for stat_name, stat_widget in self._statQlineEdit.items(): + stat_widget.setText('') + elif (self._statsHandler is not None and len( + self._statsHandler.stats) > 0): + plot = self.getPlot() + if plot is not None: + statsValDict = self._statsHandler.calculate(item, + plot, + self._statsOnVisibleData) + for statName, statVal in list(statsValDict.items()): + self._statQlineEdit[statName].setText(statVal) + + def _updateItemObserve(self, *argv): + assert self._displayOnlyActItem + _items = self._plotWrapper.getSelectedItems() + def kind_filter(_item): + return self._plotWrapper.getKind(_item) == self.getKind() + items = list(filter(kind_filter, _items)) + assert len(items) in (0, 1) + _item = items[0] if len(items) is 1 else None + self._setItem(_item) + + def _createLayout(self): + """create an instance of the main QLayout""" + raise NotImplementedError('Base class') + + def _addItem(self, item): + raise NotImplementedError('Display only the active item') + + def _removeItem(self, item): + raise NotImplementedError('Display only the active item') + + def _plotCurrentChanged(selfself, current): + raise NotImplementedError('Display only the active item') + + +class BasicLineStatsWidget(_BaseLineStatsWidget): + """ + Widget defining a simple set of :class:`Stat` to be displayed on a + :class:`LineStatsWidget`. + + :param QWidget parent: Qt parent + :param Union[PlotWidget,SceneWidget] plot: + The plot containing items on which we want statistics. + :param str kind: the kind of plotitems we want to display + :param StatsHandler stats: + Set the statistics to be displayed and how to format them using + :param bool statsOnVisibleData: compute statistics for the whole data or + only visible ones. + """ + + def __init__(self, parent=None, plot=None, kind='curve', + stats=DEFAULT_STATS, statsOnVisibleData=False): + _BaseLineStatsWidget.__init__(self, parent=parent, kind=kind, + plot=plot, stats=stats, + statsOnVisibleData=statsOnVisibleData) + + def _createLayout(self): + return FlowLayout() + + def _addStatsWidgetsToLayout(self, qLabel, qLineEdit): + # create a mother widget to make sure both qLabel & qLineEdit will + # always be displayed side by side + widget = qt.QWidget(parent=self) + widget.setLayout(qt.QHBoxLayout()) + widget.layout().setSpacing(0) + widget.layout().setContentsMargins(0, 0, 0, 0) + + widget.layout().addWidget(qLabel) + widget.layout().addWidget(qLineEdit) + + self.layout().addWidget(widget) + + +class BasicGridStatsWidget(_BaseLineStatsWidget): + """ + pymca design like widget + + :param QWidget parent: Qt parent + :param Union[PlotWidget,SceneWidget] plot: + The plot containing items on which we want statistics. + :param StatsHandler stats: + Set the statistics to be displayed and how to format them using + :param str kind: the kind of plotitems we want to display + :param bool statsOnVisibleData: compute statistics for the whole data or + only visible ones. + :param int statsPerLine: number of statistic to be displayed per line + + .. snapshotqt:: img/BasicGridStatsWidget.png + :width: 600px + :align: center + + from silx.gui.plot import Plot1D + from silx.gui.plot.StatsWidget import BasicGridStatsWidget + + plot = Plot1D() + x = range(100) + y = x + plot.addCurve(x, y, legend='curve_0') + plot.setActiveCurve('curve_0') + + widget = BasicGridStatsWidget(plot=plot, kind='curve') + widget.show() + """ - def _zoomPlotChanged(self, event): - if self._statsOnVisibleData is True: - if 'event' in event and event['event'] == 'limitsChanged': - self._updateCurrentStats() + def __init__(self, parent=None, plot=None, kind='curve', + stats=DEFAULT_STATS, statsOnVisibleData=False, + statsPerLine=4): + _BaseLineStatsWidget.__init__(self, parent=parent, kind=kind, + plot=plot, stats=stats, + statsOnVisibleData=statsOnVisibleData) + self._n_statistics_per_line = statsPerLine + + def _addStatsWidgetsToLayout(self, qLabel, qLineEdit): + column = len(self._statQlineEdit) % self._n_statistics_per_line + row = len(self._statQlineEdit) // self._n_statistics_per_line + self.layout().addWidget(qLabel, row, column * 2) + self.layout().addWidget(qLineEdit, row, column * 2 + 1) + + def _createLayout(self): + return qt.QGridLayout() diff --git a/silx/gui/plot/_BaseMaskToolsWidget.py b/silx/gui/plot/_BaseMaskToolsWidget.py index e087354..0d11f17 100644 --- a/silx/gui/plot/_BaseMaskToolsWidget.py +++ b/silx/gui/plot/_BaseMaskToolsWidget.py @@ -29,7 +29,7 @@ from __future__ import division __authors__ = ["T. Vincent", "P. Knobel"] __license__ = "MIT" -__date__ = "29/08/2018" +__date__ = "15/02/2019" import os import weakref @@ -141,7 +141,7 @@ class BaseMask(qt.QObject): def commit(self): """Append the current mask to history if changed""" if (not self._history or self._redo or - not numpy.all(numpy.equal(self._mask, self._history[-1]))): + not numpy.array_equal(self._mask, self._history[-1])): if self._redo: self._redo = [] # Reset redo as a new action as been performed self.sigRedoable[bool].emit(False) @@ -325,7 +325,7 @@ class BaseMask(qt.QObject): raise NotImplementedError("To be implemented in subclass") def updateDisk(self, level, crow, ccol, radius, mask=True): - """Mask/Unmask data located inside a disk of the given mask level. + """Mask/Unmask data located inside a dick of the given mask level. :param int level: Mask level to update. :param crow: Disk center row/ordinate (y). @@ -335,6 +335,18 @@ class BaseMask(qt.QObject): """ raise NotImplementedError("To be implemented in subclass") + def updateEllipse(self, level, crow, ccol, radius_r, radius_c, mask=True): + """Mask/Unmask a disk of the given mask level. + + :param int level: Mask level to update. + :param int crow: Row of the center of the ellipse + :param int ccol: Column of the center of the ellipse + :param float radius_r: Radius of the ellipse in the row + :param float radius_c: Radius of the ellipse in the column + :param bool mask: True to mask (default), False to unmask. + """ + raise NotImplementedError("To be implemented in subclass") + def updateLine(self, level, row0, col0, row1, col1, width, mask=True): """Mask/Unmask a line of the given mask level. @@ -376,13 +388,11 @@ class BaseMaskToolsWidget(qt.QWidget): self._plotRef = weakref.ref(plot) self._maskName = '__MASK_TOOLS_%d' % id(self) # Legend of the mask - self._colormap = Colormap(name="", - normalization='linear', + self._colormap = Colormap(normalization='linear', vmin=0, - vmax=self._maxLevelNumber, - colors=None) + vmax=self._maxLevelNumber) self._defaultOverlayColor = rgba('gray') # Color of the mask - self._setMaskColors(1, 0.5) + self._setMaskColors(1, 0.5) # Set the colormap LUT if not isinstance(mask, BaseMask): raise TypeError("mask is not an instance of BaseMask") @@ -482,6 +492,7 @@ class BaseMaskToolsWidget(qt.QWidget): layout.addWidget(self._initMaskGroupBox()) layout.addWidget(self._initDrawGroupBox()) layout.addWidget(self._initThresholdGroupBox()) + layout.addWidget(self._initOtherToolsGroupBox()) layout.addStretch(1) self.setLayout(layout) @@ -617,6 +628,15 @@ class BaseMaskToolsWidget(qt.QWidget): self.rectAction.triggered.connect(self._activeRectMode) self.addAction(self.rectAction) + self.ellipseAction = qt.QAction( + icons.getQIcon('shape-ellipse'), 'Circle selection', None) + self.ellipseAction.setToolTip( + 'Rectangle selection tool: (Un)Mask a circle region <b>R</b>') + self.ellipseAction.setShortcut(qt.QKeySequence(qt.Qt.Key_R)) + self.ellipseAction.setCheckable(True) + self.ellipseAction.triggered.connect(self._activeEllipseMode) + self.addAction(self.ellipseAction) + self.polygonAction = qt.QAction( icons.getQIcon('shape-polygon'), 'Polygon selection', None) self.polygonAction.setShortcut(qt.QKeySequence(qt.Qt.Key_S)) @@ -640,10 +660,11 @@ class BaseMaskToolsWidget(qt.QWidget): self.drawActionGroup = qt.QActionGroup(self) self.drawActionGroup.setExclusive(True) self.drawActionGroup.addAction(self.rectAction) + self.drawActionGroup.addAction(self.ellipseAction) self.drawActionGroup.addAction(self.polygonAction) self.drawActionGroup.addAction(self.pencilAction) - actions = (self.browseAction, self.rectAction, + actions = (self.browseAction, self.rectAction, self.ellipseAction, self.polygonAction, self.pencilAction) drawButtons = [] for action in actions: @@ -711,36 +732,28 @@ class BaseMaskToolsWidget(qt.QWidget): def _initThresholdGroupBox(self): """Init thresholding widgets""" - layout = qt.QVBoxLayout() - - # Thresholing self.belowThresholdAction = qt.QAction( icons.getQIcon('plot-roi-below'), 'Mask below threshold', None) self.belowThresholdAction.setToolTip( 'Mask image where values are below given threshold') self.belowThresholdAction.setCheckable(True) - self.belowThresholdAction.triggered[bool].connect( - self._belowThresholdActionTriggered) + self.belowThresholdAction.setChecked(True) self.betweenThresholdAction = qt.QAction( icons.getQIcon('plot-roi-between'), 'Mask within range', None) self.betweenThresholdAction.setToolTip( 'Mask image where values are within given range') self.betweenThresholdAction.setCheckable(True) - self.betweenThresholdAction.triggered[bool].connect( - self._betweenThresholdActionTriggered) self.aboveThresholdAction = qt.QAction( icons.getQIcon('plot-roi-above'), 'Mask above threshold', None) self.aboveThresholdAction.setToolTip( 'Mask image where values are above given threshold') self.aboveThresholdAction.setCheckable(True) - self.aboveThresholdAction.triggered[bool].connect( - self._aboveThresholdActionTriggered) self.thresholdActionGroup = qt.QActionGroup(self) - self.thresholdActionGroup.setExclusive(False) + self.thresholdActionGroup.setExclusive(True) self.thresholdActionGroup.addAction(self.belowThresholdAction) self.thresholdActionGroup.addAction(self.betweenThresholdAction) self.thresholdActionGroup.addAction(self.aboveThresholdAction) @@ -770,41 +783,50 @@ class BaseMaskToolsWidget(qt.QWidget): loadColormapRangeBtn.setDefaultAction(self.loadColormapRangeAction) widgets.append(loadColormapRangeBtn) - container = self._hboxWidget(*widgets, stretch=False) - layout.addWidget(container) + toolBar = self._hboxWidget(*widgets, stretch=False) - form = qt.QFormLayout() + config = qt.QGridLayout() + config.setContentsMargins(0, 0, 0, 0) + self.minLineLabel = qt.QLabel("Min:", self) self.minLineEdit = FloatEdit(self, value=0) - self.minLineEdit.setEnabled(False) - form.addRow('Min:', self.minLineEdit) + config.addWidget(self.minLineLabel, 0, 0) + config.addWidget(self.minLineEdit, 0, 1) + self.maxLineLabel = qt.QLabel("Max:", self) self.maxLineEdit = FloatEdit(self, value=0) - self.maxLineEdit.setEnabled(False) - form.addRow('Max:', self.maxLineEdit) + config.addWidget(self.maxLineLabel, 1, 0) + config.addWidget(self.maxLineEdit, 1, 1) self.applyMaskBtn = qt.QPushButton('Apply mask') self.applyMaskBtn.clicked.connect(self._maskBtnClicked) - self.applyMaskBtn.setEnabled(False) - form.addRow(self.applyMaskBtn) - - self.maskNanBtn = qt.QPushButton('Mask not finite values') - self.maskNanBtn.setToolTip('Mask Not a Number and infinite values') - self.maskNanBtn.clicked.connect(self._maskNotFiniteBtnClicked) - form.addRow(self.maskNanBtn) - thresholdWidget = qt.QWidget() - thresholdWidget.setLayout(form) - layout.addWidget(thresholdWidget) - - layout.addStretch(1) + layout = qt.QVBoxLayout() + layout.addWidget(toolBar) + layout.addLayout(config) + layout.addWidget(self.applyMaskBtn) self.thresholdGroup = qt.QGroupBox('Threshold') self.thresholdGroup.setLayout(layout) + + # Init widget state + self._thresholdActionGroupTriggered(self.belowThresholdAction) return self.thresholdGroup # track widget visibility and plot active image changes + def _initOtherToolsGroupBox(self): + layout = qt.QVBoxLayout() + + self.maskNanBtn = qt.QPushButton('Mask not finite values') + self.maskNanBtn.setToolTip('Mask Not a Number and infinite values') + self.maskNanBtn.clicked.connect(self._maskNotFiniteBtnClicked) + layout.addWidget(self.maskNanBtn) + + self.otherToolGroup = qt.QGroupBox('Other tools') + self.otherToolGroup.setLayout(layout) + return self.otherToolGroup + def changeEvent(self, event): """Reset drawing action when disabling widget""" if (event.type() == qt.QEvent.EnabledChange and @@ -883,6 +905,7 @@ class BaseMaskToolsWidget(qt.QWidget): The index of the mask for which we want to change the color. If none set this color for all the masks """ + rgb = rgba(rgb)[0:3] if level is None: self._overlayColors[:] = rgb self._defaultColors[:] = False @@ -925,6 +948,8 @@ class BaseMaskToolsWidget(qt.QWidget): """ if self._drawingMode == 'rectangle': self._activeRectMode() + elif self._drawingMode == 'ellipse': + self._activeEllipseMode() elif self._drawingMode == 'polygon': self._activePolygonMode() elif self._drawingMode == 'pencil': @@ -971,6 +996,16 @@ class BaseMaskToolsWidget(qt.QWidget): 'draw', shape='rectangle', source=self, color=color) self._updateDrawingModeWidgets() + def _activeEllipseMode(self): + """Handle circle action mode triggering""" + self._releaseDrawingMode() + self._drawingMode = 'ellipse' + self.plot.sigPlotSignal.connect(self._plotDrawEvent) + color = self.getCurrentMaskColor() + self.plot.setInteractiveMode( + 'draw', shape='ellipse', source=self, color=color) + self._updateDrawingModeWidgets() + def _activePolygonMode(self): """Handle polygon action mode triggering""" self._releaseDrawingMode() @@ -1016,36 +1051,28 @@ class BaseMaskToolsWidget(qt.QWidget): return doMask # Handle threshold UI events - def _belowThresholdActionTriggered(self, triggered): - if triggered: - self.minLineEdit.setEnabled(True) - self.maxLineEdit.setEnabled(False) - self.applyMaskBtn.setEnabled(True) - - def _betweenThresholdActionTriggered(self, triggered): - if triggered: - self.minLineEdit.setEnabled(True) - self.maxLineEdit.setEnabled(True) - self.applyMaskBtn.setEnabled(True) - - def _aboveThresholdActionTriggered(self, triggered): - if triggered: - self.minLineEdit.setEnabled(False) - self.maxLineEdit.setEnabled(True) - self.applyMaskBtn.setEnabled(True) def _thresholdActionGroupTriggered(self, triggeredAction): """Threshold action group listener.""" - if triggeredAction.isChecked(): - # Uncheck other actions - for action in self.thresholdActionGroup.actions(): - if action is not triggeredAction and action.isChecked(): - action.setChecked(False) - else: - # Disable min/max edit - self.minLineEdit.setEnabled(False) - self.maxLineEdit.setEnabled(False) - self.applyMaskBtn.setEnabled(False) + if triggeredAction is self.belowThresholdAction: + self.minLineLabel.setVisible(True) + self.maxLineLabel.setVisible(False) + self.minLineEdit.setVisible(True) + self.maxLineEdit.setVisible(False) + self.applyMaskBtn.setText("Mask bellow") + elif triggeredAction is self.betweenThresholdAction: + self.minLineLabel.setVisible(True) + self.maxLineLabel.setVisible(True) + self.minLineEdit.setVisible(True) + self.maxLineEdit.setVisible(True) + self.applyMaskBtn.setText("Mask between") + elif triggeredAction is self.aboveThresholdAction: + self.minLineLabel.setVisible(False) + self.maxLineLabel.setVisible(True) + self.minLineEdit.setVisible(False) + self.maxLineEdit.setVisible(True) + self.applyMaskBtn.setText("Mask above") + self.applyMaskBtn.setToolTip(triggeredAction.toolTip()) def _maskBtnClicked(self): if self.belowThresholdAction.isChecked(): diff --git a/silx/gui/plot/_utils/dtime_ticklayout.py b/silx/gui/plot/_utils/dtime_ticklayout.py index 95fc235..23c9dce 100644 --- a/silx/gui/plot/_utils/dtime_ticklayout.py +++ b/silx/gui/plot/_utils/dtime_ticklayout.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# Copyright (c) 2014-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -32,6 +32,7 @@ __date__ = "04/04/2018" import datetime as dt +import enum import logging import math import time @@ -40,7 +41,6 @@ import dateutil.tz from dateutil.relativedelta import relativedelta -from silx.third_party import enum from .ticklayout import niceNumGeneric _logger = logging.getLogger(__name__) diff --git a/silx/gui/plot/actions/control.py b/silx/gui/plot/actions/control.py index 10df130..2d01ef1 100644 --- a/silx/gui/plot/actions/control.py +++ b/silx/gui/plot/actions/control.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2018 European Synchrotron Radiation Facility +# Copyright (c) 2004-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -303,9 +303,12 @@ class CurveStyleAction(PlotAction): currentState = (self.plot.isDefaultPlotLines(), self.plot.isDefaultPlotPoints()) - # line only, line and symbol, symbol only - states = (True, False), (True, True), (False, True) - newState = states[(states.index(currentState) + 1) % 3] + if currentState == (False, False): + newState = True, False + else: + # line only, line and symbol, symbol only + states = (True, False), (True, True), (False, True) + newState = states[(states.index(currentState) + 1) % 3] self.plot.setDefaultPlotLines(newState[0]) self.plot.setDefaultPlotPoints(newState[1]) diff --git a/silx/gui/plot/actions/io.py b/silx/gui/plot/actions/io.py index 97de527..09e4a99 100644 --- a/silx/gui/plot/actions/io.py +++ b/silx/gui/plot/actions/io.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2018 European Synchrotron Radiation Facility +# Copyright (c) 2004-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -502,7 +502,7 @@ class SaveAction(PlotAction): axes_errors=[xerror, yerror], title=plot.getGraphTitle()) - def setFileFilter(self, dataKind, nameFilter, func): + def setFileFilter(self, dataKind, nameFilter, func, index=None): """Set a name filter to add/replace a file format support :param str dataKind: @@ -513,10 +513,44 @@ class SaveAction(PlotAction): :param callable func: The function to call to perform saving. Expected signature is: bool func(PlotWidget plot, str filename, str nameFilter) + :param integer index: Index of the filter in the final list (or None) """ assert dataKind in ('all', 'curve', 'curves', 'image', 'scatter') + # first append or replace the new filter to prevent colissions self._filters[dataKind][nameFilter] = func + if index is None: + # we are already done + return + + # get the current ordered list of keys + keyList = list(self._filters[dataKind].keys()) + + # deal with negative indices + if index < 0: + index = len(keyList) + index + if index < 0: + index = 0 + + if index >= len(keyList): + # nothing to be done, already at the end + txt = 'Requested index %d impossible, already at the end' % index + _logger.info(txt) + return + + # get the new ordered list + oldIndex = keyList.index(nameFilter) + del keyList[oldIndex] + keyList.insert(index, nameFilter) + + # build the new filters + newFilters = OrderedDict() + for key in keyList: + newFilters[key] = self._filters[dataKind][key] + + # and update the filters + self._filters[dataKind] = newFilters + return def getFileFilters(self, dataKind): """Returns the nameFilter and associated function for a kind of data. diff --git a/silx/gui/plot/backends/BackendBase.py b/silx/gui/plot/backends/BackendBase.py index 7fb8be0..0514c85 100644 --- a/silx/gui/plot/backends/BackendBase.py +++ b/silx/gui/plot/backends/BackendBase.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2018 European Synchrotron Radiation Facility +# Copyright (c) 2004-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -31,7 +31,7 @@ This API is a simplified version of PyMca PlotBackend API. __authors__ = ["V.A. Sole", "T. Vincent"] __license__ = "MIT" -__date__ = "24/04/2018" +__date__ = "21/12/2018" import weakref from ... import qt @@ -170,7 +170,8 @@ class BackendBase(object): """ return legend - def addItem(self, x, y, legend, shape, color, fill, overlay, z): + def addItem(self, x, y, legend, shape, color, fill, overlay, z, + linestyle, linewidth, linebgcolor): """Add an item (i.e. a shape) to the plot. :param numpy.ndarray x: The X coords of the points of the shape @@ -182,6 +183,19 @@ class BackendBase(object): :param bool fill: True to fill the shape :param bool overlay: True if item is an overlay, False otherwise :param int z: Layer on which to draw the item + :param str linestyle: Style of the line. + Only relevant for line markers where X or Y is None. + Value in: + + - ' ' no line + - '-' solid line + - '--' dashed line + - '-.' dash-dot line + - ':' dotted line + :param float linewidth: Width of the line. + Only relevant for line markers where X or Y is None. + :param str linebgcolor: Background color of the line, e.g., 'blue', 'b', + '#FF0000'. It is used to draw dotted line using a second color. :returns: The handle used by the backend to univocally access the item """ return legend @@ -546,3 +560,20 @@ class BackendBase(object): This only check status set to axes from the public API """ return self._axesDisplayed + + def setForegroundColors(self, foregroundColor, gridColor): + """Set foreground and grid colors used to display this widget. + + :param List[float] foregroundColor: RGBA foreground color of the widget + :param List[float] gridColor: RGBA grid color of the data view + """ + pass + + def setBackgroundColors(self, backgroundColor, dataBackgroundColor): + """Set background colors used to display this widget. + + :param List[float] backgroundColor: RGBA background color of the widget + :param Union[Tuple[float],None] dataBackgroundColor: + RGBA background color of the data view + """ + pass diff --git a/silx/gui/plot/backends/BackendMatplotlib.py b/silx/gui/plot/backends/BackendMatplotlib.py index 3b1d6dd..726a839 100644 --- a/silx/gui/plot/backends/BackendMatplotlib.py +++ b/silx/gui/plot/backends/BackendMatplotlib.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2018 European Synchrotron Radiation Facility +# Copyright (c) 2004-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -28,7 +28,7 @@ from __future__ import division __authors__ = ["V.A. Sole", "T. Vincent, H. Payno"] __license__ = "MIT" -__date__ = "01/08/2018" +__date__ = "21/12/2018" import logging @@ -56,12 +56,26 @@ from matplotlib.collections import PathCollection, LineCollection from matplotlib.ticker import Formatter, ScalarFormatter, Locator -from ....third_party.modest_image import ModestImage from . import BackendBase from .._utils import FLOAT32_MINPOS from .._utils.dtime_ticklayout import calcTicks, bestFormatString, timestamp +_PATCH_LINESTYLE = { + "-": 'solid', + "--": 'dashed', + '-.': 'dashdot', + ':': 'dotted', + '': "solid", + None: "solid", +} +"""Patches do not uses the same matplotlib syntax""" + + +def normalize_linestyle(linestyle): + """Normalize known old-style linestyle, else return the provided value.""" + return _PATCH_LINESTYLE.get(linestyle, linestyle) + class NiceDateLocator(Locator): """ @@ -115,7 +129,6 @@ class NiceDateLocator(Locator): return ticks - class NiceAutoDateFormatter(Formatter): """ Matplotlib FuncFormatter that is linked to a NiceDateLocator and gives the @@ -139,7 +152,6 @@ class NiceAutoDateFormatter(Formatter): else: return bestFormatString(self.locator.spacing, self.locator.unit) - def __call__(self, x, pos=None): """Return the format for tick val *x* at position *pos* Expects x to be a POSIX timestamp (seconds since 1 Jan 1970) @@ -149,8 +161,6 @@ class NiceAutoDateFormatter(Formatter): return tickStr - - class _MarkerContainer(Container): """Marker artists container supporting draw/remove and text position update @@ -204,6 +214,57 @@ class _MarkerContainer(Container): self.text.set_x(xmax) +class _DoubleColoredLinePatch(matplotlib.patches.Patch): + """Matplotlib patch to display any patch using double color.""" + + def __init__(self, patch): + super(_DoubleColoredLinePatch, self).__init__() + self.__patch = patch + self.linebgcolor = None + + def __getattr__(self, name): + return getattr(self.__patch, name) + + def draw(self, renderer): + oldLineStype = self.__patch.get_linestyle() + if self.linebgcolor is not None and oldLineStype != "solid": + oldLineColor = self.__patch.get_edgecolor() + oldHatch = self.__patch.get_hatch() + self.__patch.set_linestyle("solid") + self.__patch.set_edgecolor(self.linebgcolor) + self.__patch.set_hatch(None) + self.__patch.draw(renderer) + self.__patch.set_linestyle(oldLineStype) + self.__patch.set_edgecolor(oldLineColor) + self.__patch.set_hatch(oldHatch) + self.__patch.draw(renderer) + + def set_transform(self, transform): + self.__patch.set_transform(transform) + + def get_path(self): + return self.__patch.get_path() + + def contains(self, mouseevent, radius=None): + return self.__patch.contains(mouseevent, radius) + + def contains_point(self, point, radius=None): + return self.__patch.contains_point(point, radius) + + +class Image(AxesImage): + """An AxesImage with a fast path for uint8 RGBA images""" + + def set_data(self, A): + A = numpy.array(A, copy=False) + if A.ndim != 3 or A.shape[2] != 4 or A.dtype != numpy.uint8: + super(Image, self).set_data(A) + else: + # Call AxesImage.set_data with small data to set attributes + super(Image, self).set_data(numpy.zeros((2, 2, 4), dtype=A.dtype)) + self._A = A # Override stored data + + class BackendMatplotlib(BackendBase.BackendBase): """Base class for Matplotlib backend without a FigureCanvas. @@ -231,6 +292,8 @@ class BackendMatplotlib(BackendBase.BackendBase): self.ax = self.fig.add_axes([.15, .15, .75, .75], label="left") self.ax2 = self.ax.twinx() self.ax2.set_label("right") + # Make sure background of Axes is displayed + self.ax2.patch.set_visible(True) # disable the use of offsets try: @@ -239,9 +302,9 @@ class BackendMatplotlib(BackendBase.BackendBase): self.ax2.get_yaxis().get_major_formatter().set_useOffset(False) self.ax2.get_xaxis().get_major_formatter().set_useOffset(False) except: - _logger.warning('Cannot disabled axes offsets in %s ' \ + _logger.warning('Cannot disabled axes offsets in %s ' % matplotlib.__version__) - + # critical for picking!!!! self.ax2.set_zorder(0) self.ax2.set_autoscaley_on(True) @@ -376,44 +439,13 @@ class BackendMatplotlib(BackendBase.BackendBase): picker = (selectable or draggable) - # Debian 7 specific support - # No transparent colormap with matplotlib < 1.2.0 - # Add support for transparent colormap for uint8 data with - # colormap with 256 colors, linear norm, [0, 255] range - if self._matplotlibVersion < _parse_version('1.2.0'): - if (len(data.shape) == 2 and colormap.getName() is None and - colormap.getColormapLUT() is not None): - colors = colormap.getColormapLUT() - if (colors.shape[-1] == 4 and - not numpy.all(numpy.equal(colors[3], 255))): - # This is a transparent colormap - if (colors.shape == (256, 4) and - colormap.getNormalization() == 'linear' and - not colormap.isAutoscale() and - colormap.getVMin() == 0 and - colormap.getVMax() == 255 and - data.dtype == numpy.uint8): - # Supported case, convert data to RGBA - data = colors[data.reshape(-1)].reshape( - data.shape + (4,)) - else: - _logger.warning( - 'matplotlib %s does not support transparent ' - 'colormap.', matplotlib.__version__) - - if ((height * width) > 5.0e5 and - origin == (0., 0.) and scale == (1., 1.)): - imageClass = ModestImage - else: - imageClass = AxesImage - # All image are shown as RGBA image - image = imageClass(self.ax, - label="__IMAGE__" + legend, - interpolation='nearest', - picker=picker, - zorder=z, - origin='lower') + image = Image(self.ax, + label="__IMAGE__" + legend, + interpolation='nearest', + picker=picker, + zorder=z, + origin='lower') if alpha < 1: image.set_alpha(alpha) @@ -438,40 +470,41 @@ class BackendMatplotlib(BackendBase.BackendBase): ystep = 1 if scale[1] >= 0. else -1 data = data[::ystep, ::xstep] - if self._matplotlibVersion < _parse_version('2.1'): - # matplotlib 1.4.2 do not support float128 - dtype = data.dtype - if dtype.kind == "f" and dtype.itemsize >= 16: - _logger.warning("Your matplotlib version do not support " - "float128. Data converted to float64.") - data = data.astype(numpy.float64) - if data.ndim == 2: # Data image, convert to RGBA image data = colormap.applyToData(data) image.set_data(data) - self.ax.add_artist(image) - return image - def addItem(self, x, y, legend, shape, color, fill, overlay, z): + def addItem(self, x, y, legend, shape, color, fill, overlay, z, + linestyle, linewidth, linebgcolor): + if (linebgcolor is not None and + shape not in ('rectangle', 'polygon', 'polylines')): + _logger.warning( + 'linebgcolor not implemented for %s with matplotlib backend', + shape) xView = numpy.array(x, copy=False) yView = numpy.array(y, copy=False) + linestyle = normalize_linestyle(linestyle) + if shape == "line": item = self.ax.plot(x, y, label=legend, color=color, - linestyle='-', marker=None)[0] + linestyle=linestyle, linewidth=linewidth, + marker=None)[0] elif shape == "hline": if hasattr(y, "__len__"): y = y[-1] - item = self.ax.axhline(y, label=legend, color=color) + item = self.ax.axhline(y, label=legend, color=color, + linestyle=linestyle, linewidth=linewidth) elif shape == "vline": if hasattr(x, "__len__"): x = x[-1] - item = self.ax.axvline(x, label=legend, color=color) + item = self.ax.axvline(x, label=legend, color=color, + linestyle=linestyle, linewidth=linewidth) elif shape == 'rectangle': xMin = numpy.nanmin(xView) @@ -484,10 +517,16 @@ class BackendMatplotlib(BackendBase.BackendBase): width=w, height=h, fill=False, - color=color) + color=color, + linestyle=linestyle, + linewidth=linewidth) if fill: item.set_hatch('.') + if linestyle != "solid" and linebgcolor is not None: + item = _DoubleColoredLinePatch(item) + item.linebgcolor = linebgcolor + self.ax.add_patch(item) elif shape in ('polygon', 'polylines'): @@ -500,10 +539,16 @@ class BackendMatplotlib(BackendBase.BackendBase): closed=closed, fill=False, label=legend, - color=color) + color=color, + linestyle=linestyle, + linewidth=linewidth) if fill and shape == 'polygon': item.set_hatch('/') + if linestyle != "solid" and linebgcolor is not None: + item = _DoubleColoredLinePatch(item) + item.linebgcolor = linebgcolor + self.ax.add_patch(item) else: @@ -908,8 +953,56 @@ class BackendMatplotlib(BackendBase.BackendBase): # remove external margins self.ax.set_position([0, 0, 1, 1]) self.ax2.set_position([0, 0, 1, 1]) + self._synchronizeBackgroundColors() + self._synchronizeForegroundColors() self._plot._setDirtyPlot() + def _synchronizeBackgroundColors(self): + backgroundColor = self._plot.getBackgroundColor().getRgbF() + + dataBackgroundColor = self._plot.getDataBackgroundColor() + if dataBackgroundColor.isValid(): + dataBackgroundColor = dataBackgroundColor.getRgbF() + else: + dataBackgroundColor = backgroundColor + + if self.ax2.axison: + self.fig.patch.set_facecolor(backgroundColor) + if self._matplotlibVersion < _parse_version('2'): + self.ax2.set_axis_bgcolor(dataBackgroundColor) + else: + self.ax2.set_facecolor(dataBackgroundColor) + else: + self.fig.patch.set_facecolor(dataBackgroundColor) + + def _synchronizeForegroundColors(self): + foregroundColor = self._plot.getForegroundColor().getRgbF() + + gridColor = self._plot.getGridColor() + if gridColor.isValid(): + gridColor = gridColor.getRgbF() + else: + gridColor = foregroundColor + + for axes in (self.ax, self.ax2): + if axes.axison: + axes.spines['bottom'].set_color(foregroundColor) + axes.spines['top'].set_color(foregroundColor) + axes.spines['right'].set_color(foregroundColor) + axes.spines['left'].set_color(foregroundColor) + axes.tick_params(axis='x', colors=foregroundColor) + axes.tick_params(axis='y', colors=foregroundColor) + axes.yaxis.label.set_color(foregroundColor) + axes.xaxis.label.set_color(foregroundColor) + axes.title.set_color(foregroundColor) + + for line in axes.get_xgridlines(): + line.set_color(gridColor) + + for line in axes.get_ygridlines(): + line.set_color(gridColor) + # axes.grid().set_markeredgecolor(gridColor) + class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib): """QWidget matplotlib backend using a QtAgg canvas. @@ -1137,3 +1230,9 @@ class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib): else: cursor = self._QT_CURSORS[cursor] FigureCanvasQTAgg.setCursor(self, qt.QCursor(cursor)) + + def setBackgroundColors(self, backgroundColor, dataBackgroundColor): + self._synchronizeBackgroundColors() + + def setForegroundColors(self, foregroundColor, gridColor): + self._synchronizeForegroundColors() diff --git a/silx/gui/plot/backends/BackendOpenGL.py b/silx/gui/plot/backends/BackendOpenGL.py index 9e2cb73..e33d03c 100644 --- a/silx/gui/plot/backends/BackendOpenGL.py +++ b/silx/gui/plot/backends/BackendOpenGL.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2018 European Synchrotron Radiation Facility +# Copyright (c) 2014-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -28,7 +28,7 @@ from __future__ import division __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "01/08/2018" +__date__ = "21/12/2018" from collections import OrderedDict, namedtuple from ctypes import c_void_p @@ -44,10 +44,11 @@ from ... import qt from ..._glutils import gl from ... import _glutils as glu from .glutils import ( + GLLines2D, GLPlotCurve2D, GLPlotColormap, GLPlotRGBAImage, GLPlotFrame2D, mat4Ortho, mat4Identity, LEFT, RIGHT, BOTTOM, TOP, - Text2D, Shape2D) + Text2D, FilledShape2D) from .glutils.PlotImageFile import saveImageToFile _logger = logging.getLogger(__name__) @@ -338,6 +339,9 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): f=f) BackendBase.BackendBase.__init__(self, plot, parent) + self._backgroundColor = 1., 1., 1., 1. + self._dataBackgroundColor = 1., 1., 1., 1. + self.matScreenProj = mat4Identity() self._progBase = glu.Program( @@ -357,6 +361,8 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): self._glGarbageCollector = [] self._plotFrame = GLPlotFrame2D( + foregroundColor=(0., 0., 0., 1.), + gridColor=(.7, .7, .7, 1.), margins={'left': 100, 'right': 50, 'top': 50, 'bottom': 50}) # Make postRedisplay asynchronous using Qt signal @@ -432,7 +438,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): def initializeGL(self): gl.testGL() - gl.glClearColor(1., 1., 1., 1.) gl.glClearStencil(0) gl.glEnable(gl.GL_BLEND) @@ -482,6 +487,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): self._plotFBOs[context] = plotFBOTex with plotFBOTex: + gl.glClearColor(*self._backgroundColor) gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT) self._renderPlotAreaGL() self._plotFrame.render() @@ -530,6 +536,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): item.discard() self._glGarbageCollector = [] + gl.glClearColor(*self._backgroundColor) gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT) # Check if window is large enough @@ -543,100 +550,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): glu.setGLContextGetter() _current_context = None - def _nonOrthoAxesLineMarkerPrimitives(self, marker, pixelOffset): - """Generates the vertices and label for a line marker. - - :param dict marker: Description of a line marker - :param int pixelOffset: Offset of text from borders in pixels - :return: Line vertices and Text label or None - :rtype: 2-tuple (2x2 numpy.array of float, Text2D) - """ - label, vertices = None, None - - xCoord, yCoord = marker['x'], marker['y'] - assert xCoord is None or yCoord is None # Specific to line markers - - # Get plot corners in data coords - plotLeft, plotTop, plotWidth, plotHeight = self.getPlotBoundsInPixels() - - corners = [(plotLeft, plotTop), - (plotLeft, plotTop + plotHeight), - (plotLeft + plotWidth, plotTop + plotHeight), - (plotLeft + plotWidth, plotTop)] - corners = numpy.array([self.pixelToData(x, y, axis='left', check=False) - for (x, y) in corners]) - - borders = { - 'right': (corners[3], corners[2]), - 'top': (corners[0], corners[3]), - 'bottom': (corners[2], corners[1]), - 'left': (corners[1], corners[0]) - } - - textLayouts = { # align, valign, offsets - 'right': (RIGHT, BOTTOM, (-1., -1.)), - 'top': (LEFT, TOP, (1., 1.)), - 'bottom': (LEFT, BOTTOM, (1., -1.)), - 'left': (LEFT, BOTTOM, (1., -1.)) - } - - if xCoord is None: # Horizontal line in data space - if marker['text'] is not None: - # Find intersection of hline with borders in data - # Order is important as it stops at first intersection - for border_name in ('right', 'top', 'bottom', 'left'): - (x0, y0), (x1, y1) = borders[border_name] - - if min(y0, y1) <= yCoord < max(y0, y1): - xIntersect = (yCoord - y0) * (x1 - x0) / (y1 - y0) + x0 - - # Add text label - pixelPos = self.dataToPixel( - xIntersect, yCoord, axis='left', check=False) - - align, valign, offsets = textLayouts[border_name] - - x = pixelPos[0] + offsets[0] * pixelOffset - y = pixelPos[1] + offsets[1] * pixelOffset - label = Text2D(marker['text'], x, y, - color=marker['color'], - bgColor=(1., 1., 1., 0.5), - align=align, valign=valign) - break # Stop at first intersection - - xMin, xMax = corners[:, 0].min(), corners[:, 0].max() - vertices = numpy.array( - ((xMin, yCoord), (xMax, yCoord)), dtype=numpy.float32) - - else: # yCoord is None: vertical line in data space - if marker['text'] is not None: - # Find intersection of hline with borders in data - # Order is important as it stops at first intersection - for border_name in ('top', 'bottom', 'right', 'left'): - (x0, y0), (x1, y1) = borders[border_name] - if min(x0, x1) <= xCoord < max(x0, x1): - yIntersect = (xCoord - x0) * (y1 - y0) / (x1 - x0) + y0 - - # Add text label - pixelPos = self.dataToPixel( - xCoord, yIntersect, axis='left', check=False) - - align, valign, offsets = textLayouts[border_name] - - x = pixelPos[0] + offsets[0] * pixelOffset - y = pixelPos[1] + offsets[1] * pixelOffset - label = Text2D(marker['text'], x, y, - color=marker['color'], - bgColor=(1., 1., 1., 0.5), - align=align, valign=valign) - break # Stop at first intersection - - yMin, yMax = corners[:, 1].min(), corners[:, 1].max() - vertices = numpy.array( - ((xCoord, yMin), (xCoord, yMax)), dtype=numpy.float32) - - return vertices, label - def _renderMarkersGL(self): if len(self._markers) == 0: return @@ -651,16 +564,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1]) - # Prepare vertical and horizontal markers rendering - self._progBase.use() - gl.glUniformMatrix4fv( - self._progBase.uniforms['matrix'], 1, gl.GL_TRUE, - self.matScreenProj.astype(numpy.float32)) - gl.glUniform2i(self._progBase.uniforms['isLog'], False, False) - gl.glUniform1i(self._progBase.uniforms['hatchStep'], 0) - gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.) - posAttrib = self._progBase.attributes['position'] - labels = [] pixelOffset = 3 @@ -677,59 +580,43 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): continue if xCoord is None or yCoord is None: - if not self.isDefaultBaseVectors(): # Non-orthogonal axes - vertices, label = self._nonOrthoAxesLineMarkerPrimitives( - marker, pixelOffset) - if label is not None: - labels.append(label) + pixelPos = self.dataToPixel( + xCoord, yCoord, axis='left', check=False) - else: # Orthogonal axes - pixelPos = self.dataToPixel( - xCoord, yCoord, axis='left', check=False) - - if xCoord is None: # Horizontal line in data space - if marker['text'] is not None: - x = self._plotFrame.size[0] - \ - self._plotFrame.margins.right - pixelOffset - y = pixelPos[1] - pixelOffset - label = Text2D(marker['text'], x, y, - color=marker['color'], - bgColor=(1., 1., 1., 0.5), - align=RIGHT, valign=BOTTOM) - labels.append(label) - - width = self._plotFrame.size[0] - vertices = numpy.array(((0, pixelPos[1]), - (width, pixelPos[1])), - dtype=numpy.float32) - - else: # yCoord is None: vertical line in data space - if marker['text'] is not None: - x = pixelPos[0] + pixelOffset - y = self._plotFrame.margins.top + pixelOffset - label = Text2D(marker['text'], x, y, - color=marker['color'], - bgColor=(1., 1., 1., 0.5), - align=LEFT, valign=TOP) - labels.append(label) - - height = self._plotFrame.size[1] - vertices = numpy.array(((pixelPos[0], 0), - (pixelPos[0], height)), - dtype=numpy.float32) + if xCoord is None: # Horizontal line in data space + if marker['text'] is not None: + x = self._plotFrame.size[0] - \ + self._plotFrame.margins.right - pixelOffset + y = pixelPos[1] - pixelOffset + label = Text2D(marker['text'], x, y, + color=marker['color'], + bgColor=(1., 1., 1., 0.5), + align=RIGHT, valign=BOTTOM) + labels.append(label) - self._progBase.use() - gl.glUniform4f(self._progBase.uniforms['color'], - *marker['color']) + width = self._plotFrame.size[0] + lines = GLLines2D((0, width), (pixelPos[1], pixelPos[1]), + style=marker['linestyle'], + color=marker['color'], + width=marker['linewidth']) + lines.render(self.matScreenProj) + + else: # yCoord is None: vertical line in data space + if marker['text'] is not None: + x = pixelPos[0] + pixelOffset + y = self._plotFrame.margins.top + pixelOffset + label = Text2D(marker['text'], x, y, + color=marker['color'], + bgColor=(1., 1., 1., 0.5), + align=LEFT, valign=TOP) + labels.append(label) - gl.glEnableVertexAttribArray(posAttrib) - gl.glVertexAttribPointer(posAttrib, - 2, - gl.GL_FLOAT, - gl.GL_FALSE, - 0, vertices) - gl.glLineWidth(1) - gl.glDrawArrays(gl.GL_LINES, 0, len(vertices)) + height = self._plotFrame.size[1] + lines = GLLines2D((pixelPos[0], pixelPos[0]), (0, height), + style=marker['linestyle'], + color=marker['color'], + width=marker['linewidth']) + lines.render(self.matScreenProj) else: pixelPos = self.dataToPixel( @@ -820,13 +707,17 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): def _renderPlotAreaGL(self): plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:] - self._plotFrame.renderGrid() - gl.glScissor(self._plotFrame.margins.left, self._plotFrame.margins.bottom, plotWidth, plotHeight) gl.glEnable(gl.GL_SCISSOR_TEST) + if self._dataBackgroundColor != self._backgroundColor: + gl.glClearColor(*self._dataBackgroundColor) + gl.glClear(gl.GL_COLOR_BUFFER_BIT) + + self._plotFrame.renderGrid() + # Matrix trBounds = self._plotFrame.transformedDataRanges if trBounds.x[0] == trBounds.x[1] or \ @@ -853,32 +744,61 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): # Render Items gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1]) - self._progBase.use() - gl.glUniformMatrix4fv(self._progBase.uniforms['matrix'], 1, gl.GL_TRUE, - self.matScreenProj.astype(numpy.float32)) - gl.glUniform2i(self._progBase.uniforms['isLog'], False, False) - gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.) - for item in self._items.values(): if ((isXLog and numpy.min(item['x']) < FLOAT32_MINPOS) or (isYLog and numpy.min(item['y']) < FLOAT32_MINPOS)): # Ignore items <= 0. on log axes continue - closed = item['shape'] != 'polylines' - points = [self.dataToPixel(x, y, axis='left', check=False) - for (x, y) in zip(item['x'], item['y'])] - shape2D = Shape2D(points, - fill=item['fill'], - fillColor=item['color'], - stroke=True, - strokeColor=item['color'], - strokeClosed=closed) + if item['shape'] == 'hline': + width = self._plotFrame.size[0] + _, yPixel = self.dataToPixel( + None, item['y'], axis='left', check=False) + points = numpy.array(((0., yPixel), (width, yPixel)), + dtype=numpy.float32) - posAttrib = self._progBase.attributes['position'] - colorUnif = self._progBase.uniforms['color'] - hatchStepUnif = self._progBase.uniforms['hatchStep'] - shape2D.render(posAttrib, colorUnif, hatchStepUnif) + elif item['shape'] == 'vline': + xPixel, _ = self.dataToPixel( + item['x'], None, axis='left', check=False) + height = self._plotFrame.size[1] + points = numpy.array(((xPixel, 0), (xPixel, height)), + dtype=numpy.float32) + + else: + points = numpy.array([ + self.dataToPixel(x, y, axis='left', check=False) + for (x, y) in zip(item['x'], item['y'])]) + + # Draw the fill + if (item['fill'] is not None and + item['shape'] not in ('hline', 'vline')): + self._progBase.use() + gl.glUniformMatrix4fv( + self._progBase.uniforms['matrix'], 1, gl.GL_TRUE, + self.matScreenProj.astype(numpy.float32)) + gl.glUniform2i(self._progBase.uniforms['isLog'], False, False) + gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.) + + shape2D = FilledShape2D( + points, style=item['fill'], color=item['color']) + shape2D.render( + posAttrib=self._progBase.attributes['position'], + colorUnif=self._progBase.uniforms['color'], + hatchStepUnif=self._progBase.uniforms['hatchStep']) + + # Draw the stroke + if item['linestyle'] not in ('', ' ', None): + if item['shape'] != 'polylines': + # close the polyline + points = numpy.append(points, + numpy.atleast_2d(points[0]), axis=0) + + lines = GLLines2D(points[:, 0], points[:, 1], + style=item['linestyle'], + color=item['color'], + dash2ndColor=item['linebgcolor'], + width=item['linewidth']) + lines.render(self.matScreenProj) gl.glDisable(gl.GL_SCISSOR_TEST) @@ -1123,7 +1043,8 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): return legend, 'image' - def addItem(self, x, y, legend, shape, color, fill, overlay, z): + def addItem(self, x, y, legend, shape, color, fill, overlay, z, + linestyle, linewidth, linebgcolor): # TODO handle overlay if shape not in ('polygon', 'rectangle', 'line', 'vline', 'hline', 'polylines'): @@ -1154,7 +1075,10 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): 'color': colors.rgba(color), 'fill': 'hatch' if fill else None, 'x': x, - 'y': y + 'y': y, + 'linestyle': linestyle, + 'linewidth': linewidth, + 'linebgcolor': linebgcolor, } return legend, 'item' @@ -1166,10 +1090,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): if symbol is None: symbol = '+' - if linestyle != '-' or linewidth != 1: - _logger.warning( - 'OpenGL backend does not support marker line style and width.') - behaviors = set() if selectable: behaviors.add('selectable') @@ -1191,6 +1111,8 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): 'behaviors': behaviors, 'constraint': constraint if isConstraint else None, 'symbol': symbol, + 'linestyle': linestyle, + 'linewidth': linewidth, } return legend, 'marker' @@ -1441,37 +1363,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): if label: _logger.warning('Right axis label not implemented') - # Non orthogonal axes - - def setBaseVectors(self, x=(1., 0.), y=(0., 1.)): - """Set base vectors. - - Useful for non-orthogonal axes. - If an axis is in log scale, skew is applied to log transformed values. - - Base vector does not work well with log axes, to investi - """ - if x != (1., 0.) and y != (0., 1.): - if self._plotFrame.xAxis.isLog: - _logger.warning("setBaseVectors disables X axis logarithmic.") - self.setXAxisLogarithmic(False) - if self._plotFrame.yAxis.isLog: - _logger.warning("setBaseVectors disables Y axis logarithmic.") - self.setYAxisLogarithmic(False) - - if self.isKeepDataAspectRatio(): - _logger.warning("setBaseVectors disables keepDataAspectRatio.") - self.keepDataAspectRatio(False) - - self._plotFrame.baseVectors = x, y - - def getBaseVectors(self): - return self._plotFrame.baseVectors - - def isDefaultBaseVectors(self): - return self._plotFrame.baseVectors == \ - self._plotFrame.DEFAULT_BASE_VECTORS - # Graph limits def _setDataRanges(self, xlim=None, ylim=None, y2lim=None): @@ -1486,26 +1377,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): # Update axes range with a clipped range if too wide self._plotFrame.setDataRanges(xlim, ylim, y2lim) - if not self.isDefaultBaseVectors(): - # Update axes range with axes bounds in data coords - plotLeft, plotTop, plotWidth, plotHeight = \ - self.getPlotBoundsInPixels() - - self._plotFrame.xAxis.dataRange = sorted([ - self.pixelToData(x, y, axis='left', check=False)[0] - for (x, y) in ((plotLeft, plotTop + plotHeight), - (plotLeft + plotWidth, plotTop + plotHeight))]) - - self._plotFrame.yAxis.dataRange = sorted([ - self.pixelToData(x, y, axis='left', check=False)[1] - for (x, y) in ((plotLeft, plotTop + plotHeight), - (plotLeft, plotTop))]) - - self._plotFrame.y2Axis.dataRange = sorted([ - self.pixelToData(x, y, axis='right', check=False)[1] - for (x, y) in ((plotLeft + plotWidth, plotTop + plotHeight), - (plotLeft + plotWidth, plotTop))]) - def _ensureAspectRatio(self, keepDim=None): """Update plot bounds in order to keep aspect ratio. @@ -1619,11 +1490,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): _logger.warning( "KeepDataAspectRatio is ignored with log axes") - if flag and not self.isDefaultBaseVectors(): - _logger.warning( - "setXAxisLogarithmic ignored because baseVectors are set") - return - self._plotFrame.xAxis.isLog = flag def setYAxisLogarithmic(self, flag): @@ -1633,11 +1499,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): _logger.warning( "KeepDataAspectRatio is ignored with log axes") - if flag and not self.isDefaultBaseVectors(): - _logger.warning( - "setYAxisLogarithmic ignored because baseVectors are set") - return - self._plotFrame.yAxis.isLog = flag self._plotFrame.y2Axis.isLog = flag @@ -1658,9 +1519,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): if flag and (self._plotFrame.xAxis.isLog or self._plotFrame.yAxis.isLog): _logger.warning("KeepDataAspectRatio is ignored with log axes") - if flag and not self.isDefaultBaseVectors(): - _logger.warning( - "keepDataAspectRatio ignored because baseVectors are set") self._keepDataAspectRatio = flag @@ -1723,3 +1581,11 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): def setAxesDisplayed(self, displayed): BackendBase.BackendBase.setAxesDisplayed(self, displayed) self._plotFrame.displayed = displayed + + def setForegroundColors(self, foregroundColor, gridColor): + self._plotFrame.foregroundColor = foregroundColor + self._plotFrame.gridColor = gridColor + + def setBackgroundColors(self, backgroundColor, dataBackgroundColor): + self._backgroundColor = backgroundColor + self._dataBackgroundColor = dataBackgroundColor diff --git a/silx/gui/plot/backends/glutils/GLPlotCurve.py b/silx/gui/plot/backends/glutils/GLPlotCurve.py index 12b6bbe..5f8d652 100644 --- a/silx/gui/plot/backends/glutils/GLPlotCurve.py +++ b/silx/gui/plot/backends/glutils/GLPlotCurve.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2018 European Synchrotron Radiation Facility +# Copyright (c) 2014-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -42,7 +42,7 @@ import numpy from silx.math.combo import min_max from ...._glutils import gl -from ...._glutils import Program, vertexBuffer +from ...._glutils import Program, vertexBuffer, VertexBufferAttrib from .GLSupport import buildFillMaskIndices, mat4Identity, mat4Translate @@ -245,7 +245,7 @@ class _Fill2D(object): SOLID, DASHED, DASHDOT, DOTTED = '-', '--', '-.', ':' -class _Lines2D(object): +class GLLines2D(object): """Object rendering curve as a polyline :param xVboData: X coordinates VBO @@ -323,6 +323,7 @@ class _Lines2D(object): /* Dashes: [0, x], [y, z] Dash period: w */ uniform vec4 dash; + uniform vec4 dash2ndColor; varying float vDist; varying vec4 vColor; @@ -330,25 +331,52 @@ class _Lines2D(object): void main(void) { float dist = mod(vDist, dash.w); if ((dist > dash.x && dist < dash.y) || dist > dash.z) { - discard; + if (dash2ndColor.a == 0.) { + discard; // Discard full transparent bg color + } else { + gl_FragColor = dash2ndColor; + } + } else { + gl_FragColor = vColor; } - gl_FragColor = vColor; } """, attrib0='xPos') def __init__(self, xVboData=None, yVboData=None, colorVboData=None, distVboData=None, - style=SOLID, color=(0., 0., 0., 1.), - width=1, dashPeriod=20, drawMode=None, + style=SOLID, color=(0., 0., 0., 1.), dash2ndColor=None, + width=1, dashPeriod=10., drawMode=None, offset=(0., 0.)): + if (xVboData is not None and + not isinstance(xVboData, VertexBufferAttrib)): + xVboData = numpy.array(xVboData, copy=False, dtype=numpy.float32) self.xVboData = xVboData + + if (yVboData is not None and + not isinstance(yVboData, VertexBufferAttrib)): + yVboData = numpy.array(yVboData, copy=False, dtype=numpy.float32) self.yVboData = yVboData + + # Compute distances if not given while providing numpy array coordinates + if (isinstance(self.xVboData, numpy.ndarray) and + isinstance(self.yVboData, numpy.ndarray) and + distVboData is None): + distVboData = distancesFromArrays(self.xVboData, self.yVboData) + + if (distVboData is not None and + not isinstance(distVboData, VertexBufferAttrib)): + distVboData = numpy.array( + distVboData, copy=False, dtype=numpy.float32) self.distVboData = distVboData + + if colorVboData is not None: + assert isinstance(colorVboData, VertexBufferAttrib) self.colorVboData = colorVboData self.useColorVboData = colorVboData is not None self.color = color + self.dash2ndColor = dash2ndColor self.width = width self._style = None self.style = style @@ -396,29 +424,46 @@ class _Lines2D(object): gl.glUniform2f(program.uniforms['halfViewportSize'], 0.5 * viewWidth, 0.5 * viewHeight) + dashPeriod = self.dashPeriod * self.width if self.style == DOTTED: - dash = (0.1 * self.dashPeriod, - 0.6 * self.dashPeriod, - 0.7 * self.dashPeriod, - self.dashPeriod) + dash = (0.2 * dashPeriod, + 0.5 * dashPeriod, + 0.7 * dashPeriod, + dashPeriod) elif self.style == DASHDOT: - dash = (0.3 * self.dashPeriod, - 0.5 * self.dashPeriod, - 0.6 * self.dashPeriod, - self.dashPeriod) + dash = (0.3 * dashPeriod, + 0.5 * dashPeriod, + 0.6 * dashPeriod, + dashPeriod) else: - dash = (0.5 * self.dashPeriod, - self.dashPeriod, - self.dashPeriod, - self.dashPeriod) + dash = (0.5 * dashPeriod, + dashPeriod, + dashPeriod, + dashPeriod) gl.glUniform4f(program.uniforms['dash'], *dash) + if self.dash2ndColor is None: + # Use fully transparent color which gets discarded in shader + dash2ndColor = (0., 0., 0., 0.) + else: + dash2ndColor = self.dash2ndColor + gl.glUniform4f(program.uniforms['dash2ndColor'], *dash2ndColor) + distAttrib = program.attributes['distance'] gl.glEnableVertexAttribArray(distAttrib) - self.distVboData.setVertexAttrib(distAttrib) + if isinstance(self.distVboData, VertexBufferAttrib): + self.distVboData.setVertexAttrib(distAttrib) + else: + gl.glVertexAttribPointer(distAttrib, + 1, + gl.GL_FLOAT, + False, + 0, + self.distVboData) - gl.glEnable(gl.GL_LINE_SMOOTH) + if self.width != 1: + gl.glEnable(gl.GL_LINE_SMOOTH) matrix = numpy.dot(matrix, mat4Translate(*self.offset)).astype(numpy.float32) @@ -435,11 +480,27 @@ class _Lines2D(object): xPosAttrib = program.attributes['xPos'] gl.glEnableVertexAttribArray(xPosAttrib) - self.xVboData.setVertexAttrib(xPosAttrib) + if isinstance(self.xVboData, VertexBufferAttrib): + self.xVboData.setVertexAttrib(xPosAttrib) + else: + gl.glVertexAttribPointer(xPosAttrib, + 1, + gl.GL_FLOAT, + False, + 0, + self.xVboData) yPosAttrib = program.attributes['yPos'] gl.glEnableVertexAttribArray(yPosAttrib) - self.yVboData.setVertexAttrib(yPosAttrib) + if isinstance(self.yVboData, VertexBufferAttrib): + self.yVboData.setVertexAttrib(yPosAttrib) + else: + gl.glVertexAttribPointer(yPosAttrib, + 1, + gl.GL_FLOAT, + False, + 0, + self.yVboData) gl.glLineWidth(self.width) gl.glDrawArrays(self._drawMode, 0, self.xVboData.size) @@ -447,7 +508,7 @@ class _Lines2D(object): gl.glDisable(gl.GL_LINE_SMOOTH) -def _distancesFromArrays(xData, yData): +def distancesFromArrays(xData, yData): """Returns distances between each points :param numpy.ndarray xData: X coordinate of points @@ -711,7 +772,7 @@ class _ErrorBars(object): This is using its own VBO as opposed to fill/points/lines. There is no picking on error bars. - It uses 2 vertices per error bars and uses :class:`_Lines2D` to + It uses 2 vertices per error bars and uses :class:`GLLines2D` to render error bars and :class:`_Points2D` to render the ends. :param numpy.ndarray xData: X coordinates of the data. @@ -753,7 +814,7 @@ class _ErrorBars(object): self._xData, self._yData = None, None self._xError, self._yError = None, None - self._lines = _Lines2D( + self._lines = GLLines2D( None, None, color=color, drawMode=gl.GL_LINES, offset=offset) self._xErrPoints = _Points2D( None, None, color=color, marker=V_LINE, offset=offset) @@ -957,7 +1018,7 @@ class GLPlotCurve2D(object): self.xMin, self.yMin, offset=self.offset) - self.lines = _Lines2D() + self.lines = GLLines2D() self.lines.style = lineStyle self.lines.color = lineColor self.lines.width = lineWidth @@ -999,7 +1060,7 @@ class GLPlotCurve2D(object): @classmethod def init(cls): """OpenGL context initialization""" - _Lines2D.init() + GLLines2D.init() _Points2D.init() def prepare(self): @@ -1007,7 +1068,7 @@ class GLPlotCurve2D(object): if self.xVboData is None: xAttrib, yAttrib, cAttrib, dAttrib = None, None, None, None if self.lineStyle in (DASHED, DASHDOT, DOTTED): - dists = _distancesFromArrays(self.xData, self.yData) + dists = distancesFromArrays(self.xData, self.yData) if self.colorData is None: xAttrib, yAttrib, dAttrib = vertexBuffer( (self.xData, self.yData, dists)) diff --git a/silx/gui/plot/backends/glutils/GLPlotFrame.py b/silx/gui/plot/backends/glutils/GLPlotFrame.py index 4ad1547..43f6e10 100644 --- a/silx/gui/plot/backends/glutils/GLPlotFrame.py +++ b/silx/gui/plot/backends/glutils/GLPlotFrame.py @@ -63,6 +63,7 @@ class PlotAxis(object): def __init__(self, plot, tickLength=(0., 0.), + foregroundColor=(0., 0., 0., 1.0), labelAlign=CENTER, labelVAlign=CENTER, titleAlign=CENTER, titleVAlign=CENTER, titleRotate=0, titleOffset=(0., 0.)): @@ -78,6 +79,7 @@ class PlotAxis(object): self._title = '' self._tickLength = tickLength + self._foregroundColor = foregroundColor self._labelAlign = labelAlign self._labelVAlign = labelVAlign self._titleAlign = titleAlign @@ -169,6 +171,20 @@ class PlotAxis(object): plot._dirty() @property + def foregroundColor(self): + """Color used for frame and labels""" + return self._foregroundColor + + @foregroundColor.setter + def foregroundColor(self, color): + """Color used for frame and labels""" + assert len(color) == 4, \ + "foregroundColor must have length 4, got {}".format(len(self._foregroundColor)) + if self._foregroundColor != color: + self._foregroundColor = color + self._dirtyTicks() + + @property def ticks(self): """Ticks as tuples: ((x, y) in display, dataPos, textLabel).""" if self._ticks is None: @@ -192,6 +208,7 @@ class PlotAxis(object): tickScale = 1. label = Text2D(text=text, + color=self._foregroundColor, x=xPixel - xTickLength, y=yPixel - yTickLength, align=self._labelAlign, @@ -223,6 +240,7 @@ class PlotAxis(object): # yOffset -= 3 * yTickLength axisTitle = Text2D(text=self.title, + color=self._foregroundColor, x=xAxisCenter + xOffset, y=yAxisCenter + yOffset, align=self._titleAlign, @@ -373,15 +391,21 @@ class GLPlotFrame(object): # Margins used when plot frame is not displayed _NoDisplayMargins = _Margins(0, 0, 0, 0) - def __init__(self, margins): + def __init__(self, margins, foregroundColor, gridColor): """ :param margins: The margins around plot area for axis and labels. :type margins: dict with 'left', 'right', 'top', 'bottom' keys and values as ints. + :param foregroundColor: color used for the frame and labels. + :type foregroundColor: tuple with RGBA values ranging from 0.0 to 1.0 + :param gridColor: color used for grid lines. + :type gridColor: tuple RGBA with RGBA values ranging from 0.0 to 1.0 """ self._renderResources = None self._margins = self._Margins(**margins) + self._foregroundColor = foregroundColor + self._gridColor = gridColor self.axes = [] # List of PlotAxis to be updated by subclasses @@ -401,6 +425,36 @@ class GLPlotFrame(object): GRID_ALL_TICKS = (GRID_MAIN_TICKS + GRID_SUB_TICKS) @property + def foregroundColor(self): + """Color used for frame and labels""" + return self._foregroundColor + + @foregroundColor.setter + def foregroundColor(self, color): + """Color used for frame and labels""" + assert len(color) == 4, \ + "foregroundColor must have length 4, got {}".format(len(self._foregroundColor)) + if self._foregroundColor != color: + self._foregroundColor = color + for axis in self.axes: + axis.foregroundColor = color + self._dirty() + + @property + def gridColor(self): + """Color used for frame and labels""" + return self._gridColor + + @gridColor.setter + def gridColor(self, color): + """Color used for frame and labels""" + assert len(color) == 4, \ + "gridColor must have length 4, got {}".format(len(self._gridColor)) + if self._gridColor != color: + self._gridColor = color + self._dirty() + + @property def displayed(self): """Whether axes and their labels are displayed or not (bool)""" return self._displayed @@ -522,6 +576,7 @@ class GLPlotFrame(object): self.margins.right) // 2 yTitle = self.margins.top - self._TICK_LENGTH_IN_PIXELS labels.append(Text2D(text=self.title, + color=self._foregroundColor, x=xTitle, y=yTitle, align=CENTER, @@ -556,7 +611,7 @@ class GLPlotFrame(object): gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matProj.astype(numpy.float32)) - gl.glUniform4f(prog.uniforms['color'], 0., 0., 0., 1.) + gl.glUniform4f(prog.uniforms['color'], *self._foregroundColor) gl.glUniform1f(prog.uniforms['tickFactor'], 0.) gl.glEnableVertexAttribArray(prog.attributes['position']) @@ -590,7 +645,7 @@ class GLPlotFrame(object): gl.glLineWidth(self._LINE_WIDTH) gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matProj.astype(numpy.float32)) - gl.glUniform4f(prog.uniforms['color'], 0.7, 0.7, 0.7, 1.) + gl.glUniform4f(prog.uniforms['color'], *self._gridColor) gl.glUniform1f(prog.uniforms['tickFactor'], 0.) # 1/2.) # 1/tickLen gl.glEnableVertexAttribArray(prog.attributes['position']) @@ -606,15 +661,21 @@ class GLPlotFrame(object): # GLPlotFrame2D ############################################################### class GLPlotFrame2D(GLPlotFrame): - def __init__(self, margins): + def __init__(self, margins, foregroundColor, gridColor): """ :param margins: The margins around plot area for axis and labels. :type margins: dict with 'left', 'right', 'top', 'bottom' keys and values as ints. + :param foregroundColor: color used for the frame and labels. + :type foregroundColor: tuple with RGBA values ranging from 0.0 to 1.0 + :param gridColor: color used for grid lines. + :type gridColor: tuple RGBA with RGBA values ranging from 0.0 to 1.0 + """ - super(GLPlotFrame2D, self).__init__(margins) + super(GLPlotFrame2D, self).__init__(margins, foregroundColor, gridColor) self.axes.append(PlotAxis(self, tickLength=(0., -5.), + foregroundColor=self._foregroundColor, labelAlign=CENTER, labelVAlign=TOP, titleAlign=CENTER, titleVAlign=TOP, titleRotate=0, @@ -624,6 +685,7 @@ class GLPlotFrame2D(GLPlotFrame): self.axes.append(PlotAxis(self, tickLength=(5., 0.), + foregroundColor=self._foregroundColor, labelAlign=RIGHT, labelVAlign=CENTER, titleAlign=CENTER, titleVAlign=BOTTOM, titleRotate=ROTATE_270, @@ -632,6 +694,7 @@ class GLPlotFrame2D(GLPlotFrame): self._y2Axis = PlotAxis(self, tickLength=(-5., 0.), + foregroundColor=self._foregroundColor, labelAlign=LEFT, labelVAlign=CENTER, titleAlign=CENTER, titleVAlign=TOP, titleRotate=ROTATE_270, @@ -825,23 +888,6 @@ class GLPlotFrame2D(GLPlotFrame): _logger.info('yMax: warning log10(%f)', y2Max) y2Max = 0. - # Non-orthogonal axes - if self.baseVectors != self.DEFAULT_BASE_VECTORS: - (xx, xy), (yx, yy) = self.baseVectors - skew_mat = numpy.array(((xx, yx), (xy, yy))) - - corners = [(xMin, yMin), (xMin, yMax), - (xMax, yMin), (xMax, yMax), - (xMin, y2Min), (xMin, y2Max), - (xMax, y2Min), (xMax, y2Max)] - - corners = numpy.array( - [numpy.dot(skew_mat, corner) for corner in corners], - dtype=numpy.float32) - xMin, xMax = corners[:, 0].min(), corners[:, 0].max() - yMin, yMax = corners[0:4, 1].min(), corners[0:4, 1].max() - y2Min, y2Max = corners[4:, 1].min(), corners[4:, 1].max() - self._transformedDataRanges = self._DataRanges( (xMin, xMax), (yMin, yMax), (y2Min, y2Max)) @@ -861,16 +907,6 @@ class GLPlotFrame2D(GLPlotFrame): mat = mat4Ortho(xMin, xMax, yMax, yMin, 1, -1) else: mat = mat4Ortho(xMin, xMax, yMin, yMax, 1, -1) - - # Non-orthogonal axes - if self.baseVectors != self.DEFAULT_BASE_VECTORS: - (xx, xy), (yx, yy) = self.baseVectors - mat = numpy.dot(mat, numpy.array(( - (xx, yx, 0., 0.), - (xy, yy, 0., 0.), - (0., 0., 1., 0.), - (0., 0., 0., 1.)), dtype=numpy.float64)) - self._transformedDataProjMat = mat return self._transformedDataProjMat @@ -890,16 +926,6 @@ class GLPlotFrame2D(GLPlotFrame): mat = mat4Ortho(xMin, xMax, y2Max, y2Min, 1, -1) else: mat = mat4Ortho(xMin, xMax, y2Min, y2Max, 1, -1) - - # Non-orthogonal axes - if self.baseVectors != self.DEFAULT_BASE_VECTORS: - (xx, xy), (yx, yy) = self.baseVectors - mat = numpy.dot(mat, numpy.matrix(( - (xx, yx, 0., 0.), - (xy, yy, 0., 0.), - (0., 0., 1., 0.), - (0., 0., 0., 1.)), dtype=numpy.float64)) - self._transformedDataY2ProjMat = mat return self._transformedDataY2ProjMat @@ -1114,3 +1140,17 @@ class GLPlotFrame2D(GLPlotFrame): vertices = numpy.append(vertices, extraVertices, axis=0) self._renderResources = (vertices, gridVertices, labels) + + @property + def foregroundColor(self): + """Color used for frame and labels""" + return self._foregroundColor + + @foregroundColor.setter + def foregroundColor(self, color): + """Color used for frame and labels""" + assert len(color) == 4, \ + "foregroundColor must have length 4, got {}".format(len(self._foregroundColor)) + if self._foregroundColor != color: + self._y2Axis.foregroundColor = color + GLPlotFrame.foregroundColor.fset(self, color) # call parent property diff --git a/silx/gui/plot/backends/glutils/GLSupport.py b/silx/gui/plot/backends/glutils/GLSupport.py index 18c5eb7..da6dffa 100644 --- a/silx/gui/plot/backends/glutils/GLSupport.py +++ b/silx/gui/plot/backends/glutils/GLSupport.py @@ -60,16 +60,12 @@ def buildFillMaskIndices(nIndices, dtype=None): return indices -class Shape2D(object): +class FilledShape2D(object): _NO_HATCH = 0 _HATCH_STEP = 20 - def __init__(self, points, fill='solid', stroke=True, - fillColor=(0., 0., 0., 1.), strokeColor=(0., 0., 0., 1.), - strokeClosed=True): + def __init__(self, points, style='solid', color=(0., 0., 0., 1.)): self.vertices = numpy.array(points, dtype=numpy.float32, copy=False) - self.strokeClosed = strokeClosed - self._indices = buildFillMaskIndices(len(self.vertices)) tVertex = numpy.transpose(self.vertices) @@ -81,28 +77,16 @@ class Shape2D(object): self._xMin, self._xMax = xMin, xMax self._yMin, self._yMax = yMin, yMax - self.fill = fill - self.fillColor = fillColor - self.stroke = stroke - self.strokeColor = strokeColor - - @property - def xMin(self): - return self._xMin - - @property - def xMax(self): - return self._xMax - - @property - def yMin(self): - return self._yMin + self.style = style + self.color = color - @property - def yMax(self): - return self._yMax + def render(self, posAttrib, colorUnif, hatchStepUnif): + assert self.style in ('hatch', 'solid') + gl.glUniform4f(colorUnif, *self.color) + step = self._HATCH_STEP if self.style == 'hatch' else self._NO_HATCH + gl.glUniform1i(hatchStepUnif, step) - def prepareFillMask(self, posAttrib): + # Prepare fill mask gl.glEnableVertexAttribArray(posAttrib) gl.glVertexAttribPointer(posAttrib, 2, @@ -126,9 +110,6 @@ class Shape2D(object): gl.glColorMask(gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE) gl.glDepthMask(gl.GL_TRUE) - def renderFill(self, posAttrib): - self.prepareFillMask(posAttrib) - gl.glVertexAttribPointer(posAttrib, 2, gl.GL_FLOAT, @@ -138,30 +119,6 @@ class Shape2D(object): gl.glDisable(gl.GL_STENCIL_TEST) - def renderStroke(self, posAttrib): - gl.glEnableVertexAttribArray(posAttrib) - gl.glVertexAttribPointer(posAttrib, - 2, - gl.GL_FLOAT, - gl.GL_FALSE, - 0, self.vertices) - gl.glLineWidth(1) - drawMode = gl.GL_LINE_LOOP if self.strokeClosed else gl.GL_LINE_STRIP - gl.glDrawArrays(drawMode, 0, len(self.vertices)) - - def render(self, posAttrib, colorUnif, hatchStepUnif): - assert self.fill in ['hatch', 'solid', None] - if self.fill is not None: - gl.glUniform4f(colorUnif, *self.fillColor) - step = self._HATCH_STEP if self.fill == 'hatch' else self._NO_HATCH - gl.glUniform1i(hatchStepUnif, step) - self.renderFill(posAttrib) - - if self.stroke: - gl.glUniform4f(colorUnif, *self.strokeColor) - gl.glUniform1i(hatchStepUnif, self._NO_HATCH) - self.renderStroke(posAttrib) - # matrix ###################################################################### diff --git a/silx/gui/plot/items/__init__.py b/silx/gui/plot/items/__init__.py index e7957ac..f829f78 100644 --- a/silx/gui/plot/items/__init__.py +++ b/silx/gui/plot/items/__init__.py @@ -36,7 +36,7 @@ from .core import (Item, LabelsMixIn, DraggableMixIn, ColormapMixIn, # noqa SymbolMixIn, ColorMixIn, YAxisMixIn, FillMixIn, # noqa AlphaMixIn, LineMixIn, ItemChangedType) # noqa from .complex import ImageComplexData # noqa -from .curve import Curve # noqa +from .curve import Curve, CurveStyle # noqa from .histogram import Histogram # noqa from .image import ImageBase, ImageData, ImageRgba, MaskImageData # noqa from .shape import Shape # noqa diff --git a/silx/gui/plot/items/axis.py b/silx/gui/plot/items/axis.py index 3d9fe14..8ea5c7a 100644 --- a/silx/gui/plot/items/axis.py +++ b/silx/gui/plot/items/axis.py @@ -27,16 +27,16 @@ __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "06/12/2017" +__date__ = "22/11/2018" import datetime as dt +import enum import logging import dateutil.tz from ... import qt -from silx.third_party import enum _logger = logging.getLogger(__name__) @@ -448,6 +448,8 @@ class YAxis(Axis): False for Y axis going from bottom to top """ flag = bool(flag) + if self.isInverted() == flag: + return self._getBackend().setYAxisInverted(flag) self._getPlot()._setDirtyPlot() self.sigInvertedChanged.emit(flag) diff --git a/silx/gui/plot/items/complex.py b/silx/gui/plot/items/complex.py index 535b0a9..7fffd77 100644 --- a/silx/gui/plot/items/complex.py +++ b/silx/gui/plot/items/complex.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -33,9 +33,9 @@ __date__ = "14/06/2018" import logging -import numpy +import enum -from silx.third_party import enum +import numpy from ...colors import Colormap from .core import ColormapMixIn, ItemChangedType @@ -137,7 +137,6 @@ class ImageComplexData(ImageBase, ColormapMixIn): name='hsv', vmin=-numpy.pi, vmax=numpy.pi) - phaseColormap.setEditable(False) self._colormaps = { # Default colormaps for all modes self.Mode.ABSOLUTE: colormap, @@ -180,7 +179,6 @@ class ImageComplexData(ImageBase, ColormapMixIn): colormap=colormap, alpha=self.getAlpha()) - def setVisualizationMode(self, mode): """Set the visualization mode to use. diff --git a/silx/gui/plot/items/core.py b/silx/gui/plot/items/core.py index e000751..bf3b719 100644 --- a/silx/gui/plot/items/core.py +++ b/silx/gui/plot/items/core.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -27,20 +27,23 @@ __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "14/06/2018" +__date__ = "29/01/2019" import collections from copy import deepcopy import logging +import enum import warnings import weakref + import numpy -from silx.third_party import six, enum +import six from ... import qt from ... import colors from ...colors import Colormap +from silx import config _logger = logging.getLogger(__name__) @@ -82,6 +85,9 @@ class ItemChangedType(enum.Enum): COLOR = 'colorChanged' """Item's color changed flag.""" + LINE_BG_COLOR = 'lineBgColorChanged' + """Item's line background color changed flag.""" + YAXIS = 'yAxisChanged' """Item's Y axis binding changed flag.""" @@ -411,10 +417,12 @@ class ColormapMixIn(ItemMixInBase): return self._colormap def setColormap(self, colormap): - """Set the colormap of this image + """Set the colormap of this item :param silx.gui.colors.Colormap colormap: colormap description """ + if self._colormap is colormap: + return if isinstance(colormap, dict): colormap = Colormap._fromDict(colormap) @@ -433,10 +441,10 @@ class ColormapMixIn(ItemMixInBase): class SymbolMixIn(ItemMixInBase): """Mix-in class for items with symbol type""" - _DEFAULT_SYMBOL = '' + _DEFAULT_SYMBOL = None """Default marker of the item""" - _DEFAULT_SYMBOL_SIZE = 6.0 + _DEFAULT_SYMBOL_SIZE = config.DEFAULT_PLOT_SYMBOL_SIZE """Default marker size of the item""" _SUPPORTED_SYMBOLS = collections.OrderedDict(( @@ -451,8 +459,15 @@ class SymbolMixIn(ItemMixInBase): """Dict of supported symbols""" def __init__(self): - self._symbol = self._DEFAULT_SYMBOL - self._symbol_size = self._DEFAULT_SYMBOL_SIZE + if self._DEFAULT_SYMBOL is None: # Use default from config + self._symbol = config.DEFAULT_PLOT_SYMBOL + else: + self._symbol = self._DEFAULT_SYMBOL + + if self._DEFAULT_SYMBOL_SIZE is None: # Use default from config + self._symbol_size = config.DEFAULT_PLOT_SYMBOL_SIZE + else: + self._symbol_size = self._DEFAULT_SYMBOL_SIZE @classmethod def getSupportedSymbols(cls): @@ -892,14 +907,14 @@ class Points(Item, SymbolMixIn, AlphaMixIn): # use the getData class method because instance method can be # overloaded to return additional arrays data = Points.getData(self, copy=False, - displayed=True) + displayed=True) if len(data) == 5: # hack to avoid duplicating caching mechanism in Scatter # (happens when cached data is used, caching done using # Scatter._logFilterData) - x, y, xerror, yerror = data[0], data[1], data[3], data[4] + x, y, _xerror, _yerror = data[0], data[1], data[3], data[4] else: - x, y, xerror, yerror = data + x, y, _xerror, _yerror = data self._boundsCache[(xPositive, yPositive)] = ( numpy.nanmin(x), diff --git a/silx/gui/plot/items/curve.py b/silx/gui/plot/items/curve.py index 80d9dea..79def55 100644 --- a/silx/gui/plot/items/curve.py +++ b/silx/gui/plot/items/curve.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -31,9 +31,10 @@ __date__ = "24/04/2018" import logging + import numpy +import six -from silx.third_party import six from ....utils.deprecation import deprecated from ... import colors from .core import (Points, LabelsMixIn, ColorMixIn, YAxisMixIn, diff --git a/silx/gui/plot/items/histogram.py b/silx/gui/plot/items/histogram.py index 389e8a6..a1d6586 100644 --- a/silx/gui/plot/items/histogram.py +++ b/silx/gui/plot/items/histogram.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -197,13 +197,13 @@ class Histogram(Item, AlphaMixIn, ColorMixIn, FillMixIn, values[clipped_values] = numpy.nan - if xPositive or yPositive: + if yPositive: return (numpy.nanmin(edges), numpy.nanmax(edges), numpy.nanmin(values), numpy.nanmax(values)) - else: # No log scale, include 0 in bounds + else: # No log scale on y axis, include 0 in bounds return (numpy.nanmin(edges), numpy.nanmax(edges), min(0, numpy.nanmin(values)), diff --git a/silx/gui/plot/items/roi.py b/silx/gui/plot/items/roi.py index f55ef91..0169439 100644 --- a/silx/gui/plot/items/roi.py +++ b/silx/gui/plot/items/roi.py @@ -65,7 +65,7 @@ class RegionOfInterest(qt.QObject): # Avoid circular dependancy from ..tools import roi as roi_tools assert parent is None or isinstance(parent, roi_tools.RegionOfInterestManager) - super(RegionOfInterest, self).__init__(parent) + qt.QObject.__init__(self, parent) self._color = rgba('red') self._items = WeakList() self._editAnchors = WeakList() @@ -108,7 +108,7 @@ class RegionOfInterest(qt.QObject): return qt.QColor.fromRgbF(*self._color) def _getAnchorColor(self, color): - """Returns the anchor color from the base ROI color + """Returns the anchor color from the base ROI color :param Union[numpy.array,Tuple,List]: color :rtype: Union[numpy.array,Tuple,List] @@ -209,7 +209,7 @@ class RegionOfInterest(qt.QObject): def setFirstShapePoints(self, points): """"Initialize the ROI using the points from the first interaction. - This interaction is constains by the plot API and only supports few + This interaction is constrained by the plot API and only supports few shapes. """ points = self._createControlPointsFromFirstShape(points) @@ -410,6 +410,13 @@ class RegionOfInterest(qt.QObject): plot._remove(item) self._labelItem = None + def _updated(self, event=None, checkVisibility=True): + """Implement Item mix-in update method by updating the plot items + + See :class:`~silx.gui.plot.items.Item._updated` + """ + self._createPlotItems() + def __str__(self): """Returns parameters of the ROI as a string.""" points = self._getControlPoints() @@ -417,7 +424,7 @@ class RegionOfInterest(qt.QObject): return "%s(%s)" % (self.__class__.__name__, params) -class PointROI(RegionOfInterest): +class PointROI(RegionOfInterest, items.SymbolMixIn): """A ROI identifying a point in a 2D plot.""" _kind = "Point" @@ -426,6 +433,10 @@ class PointROI(RegionOfInterest): _plotShape = "point" """Plot shape which is used for the first interaction""" + def __init__(self, parent=None): + items.SymbolMixIn.__init__(self) + RegionOfInterest.__init__(self, parent=parent) + def getPosition(self): """Returns the position of this ROI @@ -458,6 +469,8 @@ class PointROI(RegionOfInterest): marker.setPosition(points[0][0], points[0][1]) marker.setText(self.getLabel()) marker.setColor(rgba(self.getColor())) + marker.setSymbol(self.getSymbol()) + marker.setSymbolSize(self.getSymbolSize()) marker._setDraggable(False) return [marker] @@ -466,6 +479,8 @@ class PointROI(RegionOfInterest): marker.setPosition(points[0][0], points[0][1]) marker.setText(self.getLabel()) marker._setDraggable(self.isEditable()) + marker.setSymbol(self.getSymbol()) + marker.setSymbolSize(self.getSymbolSize()) return [marker] def __str__(self): @@ -474,7 +489,7 @@ class PointROI(RegionOfInterest): return "%s(%s)" % (self.__class__.__name__, params) -class LineROI(RegionOfInterest): +class LineROI(RegionOfInterest, items.LineMixIn): """A ROI identifying a line in a 2D plot. This ROI provides 1 anchor for each boundary of the line, plus an center @@ -487,6 +502,10 @@ class LineROI(RegionOfInterest): _plotShape = "line" """Plot shape which is used for the first interaction""" + def __init__(self, parent=None): + items.LineMixIn.__init__(self) + RegionOfInterest.__init__(self, parent=parent) + def _createControlPointsFromFirstShape(self, points): center = numpy.mean(points, axis=0) controlPoints = numpy.array([points[0], points[1], center]) @@ -535,6 +554,8 @@ class LineROI(RegionOfInterest): item.setColor(rgba(self.getColor())) item.setFill(False) item.setOverlay(True) + item.setLineStyle(self.getLineStyle()) + item.setLineWidth(self.getLineWidth()) return [item] def _createAnchorItems(self, points): @@ -582,7 +603,7 @@ class LineROI(RegionOfInterest): return "%s(%s)" % (self.__class__.__name__, params) -class HorizontalLineROI(RegionOfInterest): +class HorizontalLineROI(RegionOfInterest, items.LineMixIn): """A ROI identifying an horizontal line in a 2D plot.""" _kind = "HLine" @@ -591,6 +612,10 @@ class HorizontalLineROI(RegionOfInterest): _plotShape = "hline" """Plot shape which is used for the first interaction""" + def __init__(self, parent=None): + items.LineMixIn.__init__(self) + RegionOfInterest.__init__(self, parent=parent) + def _createControlPointsFromFirstShape(self, points): points = numpy.array([(float('nan'), points[0, 1])], dtype=numpy.float64) @@ -636,6 +661,8 @@ class HorizontalLineROI(RegionOfInterest): marker.setText(self.getLabel()) marker.setColor(rgba(self.getColor())) marker._setDraggable(False) + marker.setLineWidth(self.getLineWidth()) + marker.setLineStyle(self.getLineStyle()) return [marker] def _createAnchorItems(self, points): @@ -643,6 +670,8 @@ class HorizontalLineROI(RegionOfInterest): marker.setPosition(points[0][0], points[0][1]) marker.setText(self.getLabel()) marker._setDraggable(self.isEditable()) + marker.setLineWidth(self.getLineWidth()) + marker.setLineStyle(self.getLineStyle()) return [marker] def __str__(self): @@ -651,7 +680,7 @@ class HorizontalLineROI(RegionOfInterest): return "%s(%s)" % (self.__class__.__name__, params) -class VerticalLineROI(RegionOfInterest): +class VerticalLineROI(RegionOfInterest, items.LineMixIn): """A ROI identifying a vertical line in a 2D plot.""" _kind = "VLine" @@ -660,6 +689,10 @@ class VerticalLineROI(RegionOfInterest): _plotShape = "vline" """Plot shape which is used for the first interaction""" + def __init__(self, parent=None): + items.LineMixIn.__init__(self) + RegionOfInterest.__init__(self, parent=parent) + def _createControlPointsFromFirstShape(self, points): points = numpy.array([(points[0, 0], float('nan'))], dtype=numpy.float64) @@ -705,6 +738,8 @@ class VerticalLineROI(RegionOfInterest): marker.setText(self.getLabel()) marker.setColor(rgba(self.getColor())) marker._setDraggable(False) + marker.setLineWidth(self.getLineWidth()) + marker.setLineStyle(self.getLineStyle()) return [marker] def _createAnchorItems(self, points): @@ -712,6 +747,8 @@ class VerticalLineROI(RegionOfInterest): marker.setPosition(points[0][0], points[0][1]) marker.setText(self.getLabel()) marker._setDraggable(self.isEditable()) + marker.setLineWidth(self.getLineWidth()) + marker.setLineStyle(self.getLineStyle()) return [marker] def __str__(self): @@ -720,7 +757,7 @@ class VerticalLineROI(RegionOfInterest): return "%s(%s)" % (self.__class__.__name__, params) -class RectangleROI(RegionOfInterest): +class RectangleROI(RegionOfInterest, items.LineMixIn): """A ROI identifying a rectangle in a 2D plot. This ROI provides 1 anchor for each corner, plus an anchor in the @@ -733,6 +770,10 @@ class RectangleROI(RegionOfInterest): _plotShape = "rectangle" """Plot shape which is used for the first interaction""" + def __init__(self, parent=None): + items.LineMixIn.__init__(self) + RegionOfInterest.__init__(self, parent=parent) + def _createControlPointsFromFirstShape(self, points): point0 = points[0] point1 = points[1] @@ -838,6 +879,8 @@ class RectangleROI(RegionOfInterest): item.setColor(rgba(self.getColor())) item.setFill(False) item.setOverlay(True) + item.setLineStyle(self.getLineStyle()) + item.setLineWidth(self.getLineWidth()) return [item] def _createAnchorItems(self, points): @@ -894,7 +937,7 @@ class RectangleROI(RegionOfInterest): return "%s(%s)" % (self.__class__.__name__, params) -class PolygonROI(RegionOfInterest): +class PolygonROI(RegionOfInterest, items.LineMixIn): """A ROI identifying a closed polygon in a 2D plot. This ROI provides 1 anchor for each point of the polygon. @@ -906,6 +949,10 @@ class PolygonROI(RegionOfInterest): _plotShape = "polygon" """Plot shape which is used for the first interaction""" + def __init__(self, parent=None): + items.LineMixIn.__init__(self) + RegionOfInterest.__init__(self, parent=parent) + def getPoints(self): """Returns the list of the points of this polygon. @@ -948,6 +995,8 @@ class PolygonROI(RegionOfInterest): item.setColor(rgba(self.getColor())) item.setFill(False) item.setOverlay(True) + item.setLineStyle(self.getLineStyle()) + item.setLineWidth(self.getLineWidth()) return [item] def _createAnchorItems(self, points): @@ -967,7 +1016,7 @@ class PolygonROI(RegionOfInterest): return "%s(%s)" % (self.__class__.__name__, params) -class ArcROI(RegionOfInterest): +class ArcROI(RegionOfInterest, items.LineMixIn): """A ROI identifying an arc of a circle with a width. This ROI provides 3 anchors to control the curvature, 1 anchor to control @@ -986,6 +1035,7 @@ class ArcROI(RegionOfInterest): 'startAngle', 'endAngle']) def __init__(self, parent=None): + items.LineMixIn.__init__(self) RegionOfInterest.__init__(self, parent=parent) self._geometry = None @@ -1357,6 +1407,8 @@ class ArcROI(RegionOfInterest): item.setColor(rgba(self.getColor())) item.setFill(False) item.setOverlay(True) + item.setLineStyle(self.getLineStyle()) + item.setLineWidth(self.getLineWidth()) return [item] def _createAnchorItems(self, points): diff --git a/silx/gui/plot/items/scatter.py b/silx/gui/plot/items/scatter.py index acc74b4..707dd3d 100644 --- a/silx/gui/plot/items/scatter.py +++ b/silx/gui/plot/items/scatter.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -46,9 +46,6 @@ class Scatter(Points, ColormapMixIn): _DEFAULT_SELECTABLE = True """Default selectable state for scatter plots""" - _DEFAULT_SYMBOL = 'o' - """Default symbol of the scatter plots""" - def __init__(self): Points.__init__(self) ColormapMixIn.__init__(self) diff --git a/silx/gui/plot/items/shape.py b/silx/gui/plot/items/shape.py index 65b26a1..9fc1306 100644 --- a/silx/gui/plot/items/shape.py +++ b/silx/gui/plot/items/shape.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -27,14 +27,16 @@ __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "17/05/2017" +__date__ = "21/12/2018" import logging import numpy +import six -from .core import (Item, ColorMixIn, FillMixIn, ItemChangedType) +from ... import colors +from .core import Item, ColorMixIn, FillMixIn, ItemChangedType, LineMixIn _logger = logging.getLogger(__name__) @@ -42,7 +44,7 @@ _logger = logging.getLogger(__name__) # TODO probably make one class for each kind of shape # TODO check fill:polygon/polyline + fill = duplicated -class Shape(Item, ColorMixIn, FillMixIn): +class Shape(Item, ColorMixIn, FillMixIn, LineMixIn): """Description of a shape item :param str type_: The type of shape in: @@ -53,10 +55,12 @@ class Shape(Item, ColorMixIn, FillMixIn): Item.__init__(self) ColorMixIn.__init__(self) FillMixIn.__init__(self) + LineMixIn.__init__(self) self._overlay = False assert type_ in ('hline', 'polygon', 'rectangle', 'vline', 'polylines') self._type = type_ self._points = () + self._lineBgColor = None self._handle = None @@ -71,7 +75,10 @@ class Shape(Item, ColorMixIn, FillMixIn): color=self.getColor(), fill=self.isFill(), overlay=self.isOverlay(), - z=self.getZValue()) + z=self.getZValue(), + linestyle=self.getLineStyle(), + linewidth=self.getLineWidth(), + linebgcolor=self.getLineBgColor()) def isOverlay(self): """Return true if shape is drawn as an overlay @@ -119,3 +126,31 @@ class Shape(Item, ColorMixIn, FillMixIn): """ self._points = numpy.array(points, copy=copy) self._updated(ItemChangedType.DATA) + + def getLineBgColor(self): + """Returns the RGBA color of the item + :rtype: 4-tuple of float in [0, 1] or array of colors + """ + return self._lineBgColor + + def setLineBgColor(self, color, copy=True): + """Set item color + :param color: color(s) to be used + :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or + one of the predefined color names defined in colors.py + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + """ + if color is not None: + if isinstance(color, six.string_types): + color = colors.rgba(color) + else: + color = numpy.array(color, copy=copy) + # TODO more checks + improve color array support + if color.ndim == 1: # Single RGBA color + color = colors.rgba(color) + else: # Array of colors + assert color.ndim == 2 + + self._lineBgColor = color + self._updated(ItemChangedType.LINE_BG_COLOR) diff --git a/silx/gui/plot/matplotlib/Colormap.py b/silx/gui/plot/matplotlib/Colormap.py index 772a473..38f3b55 100644 --- a/silx/gui/plot/matplotlib/Colormap.py +++ b/silx/gui/plot/matplotlib/Colormap.py @@ -29,7 +29,13 @@ from matplotlib.colors import ListedColormap import matplotlib.colors import matplotlib.cm import silx.resources -from silx.utils.deprecation import deprecated +from silx.utils.deprecation import deprecated, deprecated_warning + + +deprecated_warning(type_='module', + name=__file__, + replacement='silx.gui.colors.Colormap', + since_version='0.10.0') _logger = logging.getLogger(__name__) @@ -46,25 +52,30 @@ _CMAPS = {} @property +@deprecated(since_version='0.10.0') def magma(): return getColormap('magma') @property +@deprecated(since_version='0.10.0') def inferno(): return getColormap('inferno') @property +@deprecated(since_version='0.10.0') def plasma(): return getColormap('plasma') @property +@deprecated(since_version='0.10.0') def viridis(): return getColormap('viridis') +@deprecated(since_version='0.10.0') def getColormap(name): """Returns matplotlib colormap corresponding to given name @@ -143,6 +154,7 @@ def getColormap(name): return matplotlib.cm.get_cmap(name) +@deprecated(since_version='0.10.0') def getScalarMappable(colormap, data=None): """Returns matplotlib ScalarMappable corresponding to colormap @@ -223,6 +235,8 @@ def applyColormapToData(data, colormap): return rgbaImage +@deprecated(replacement='silx.colors.Colormap.getSupportedColormaps', + since_version='0.10.0') def getSupportedColormaps(): """Get the supported colormap names as a tuple of str. """ diff --git a/silx/gui/plot/stats/stats.py b/silx/gui/plot/stats/stats.py index a753989..ad61536 100644 --- a/silx/gui/plot/stats/stats.py +++ b/silx/gui/plot/stats/stats.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -30,15 +30,15 @@ __license__ = "MIT" __date__ = "06/06/2018" -import numpy -from silx.gui.plot.items.curve import Curve as CurveItem -from silx.gui.plot.items.image import ImageBase as ImageItem -from silx.gui.plot.items.scatter import Scatter as ScatterItem -from silx.gui.plot.items.histogram import Histogram as HistogramItem -from silx.math.combo import min_max from collections import OrderedDict import logging +import numpy + +from .. import items +from ....math.combo import min_max + + logger = logging.getLogger(__name__) @@ -62,7 +62,7 @@ class Stats(OrderedDict): def calculate(self, item, plot, onlimits): """ - Call all :class:`Stat` object registred and return the result of the + Call all :class:`Stat` object registered and return the result of the computation. :param item: the item for which we want statistics @@ -72,17 +72,29 @@ class Stats(OrderedDict): :return dict: dictionary with :class:`Stat` name as ket and result of the calculation as value """ - res = {} - if isinstance(item, CurveItem): + context = None + # Check for PlotWidget items + if isinstance(item, items.Curve): context = _CurveContext(item, plot, onlimits) - elif isinstance(item, ImageItem): + elif isinstance(item, items.ImageData): context = _ImageContext(item, plot, onlimits) - elif isinstance(item, ScatterItem): + elif isinstance(item, items.Scatter): context = _ScatterContext(item, plot, onlimits) - elif isinstance(item, HistogramItem): + elif isinstance(item, items.Histogram): context = _HistogramContext(item, plot, onlimits) else: - raise ValueError('Item type not managed') + # Check for SceneWidget items + from ...plot3d import items as items3d # Lazy import + + if isinstance(item, (items3d.Scatter2D, items3d.Scatter3D)): + context = _plot3DScatterContext(item, plot, onlimits) + elif isinstance(item, (items3d.ImageData, items3d.ScalarField3D)): + context = _plot3DArrayContext(item, plot, onlimits) + + if context is None: + raise ValueError('Item type not managed') + + res = {} for statName, stat in list(self.items()): if context.kind not in stat.compatibleKinds: logger.debug('kind %s not managed by statistic %s' @@ -124,12 +136,54 @@ class _StatsContext(object): self.min = None self.max = None self.data = None + self.values = None + """The array of data""" + + self.axes = None + """A list of array of position on each axis. + + If the signal is an array, + then each axis has the length of that dimension, + and the order is (z, y, x) (i.e., as the array shape). + If the signal is not an array, + then each axis has the same length as the signal, + and the order is (x, y, z). + """ + self.createContext(item, plot, onlimits) def createContext(self, item, plot, onlimits): raise NotImplementedError("Base class") + def isStructuredData(self): + """Returns True if data as an array-like structure. + + :rtype: bool + """ + if self.values is None or self.axes is None: + return False + + if numpy.prod([len(axis) for axis in self.axes]) == self.values.size: + return True + else: + # Make sure there is the right number of value in axes + for axis in self.axes: + assert len(axis) == self.values.size + return False + + def isScalarData(self): + """Returns True if data is a scalar. + + :rtype: bool + """ + if self.values is None or self.axes is None: + return False + if self.isStructuredData(): + return len(self.axes) == self.values.ndim + else: + return self.values.ndim == 1 + class _CurveContext(_StatsContext): """ @@ -149,8 +203,9 @@ class _CurveContext(_StatsContext): if onlimits: minX, maxX = plot.getXAxis().getLimits() - yData = yData[(minX <= xData) & (xData <= maxX)] - xData = xData[(minX <= xData) & (xData <= maxX)] + mask = (minX <= xData) & (xData <= maxX) + yData = yData[mask] + xData = xData[mask] self.xData = xData self.yData = yData @@ -160,11 +215,12 @@ class _CurveContext(_StatsContext): self.min, self.max = None, None self.data = (xData, yData) self.values = yData + self.axes = (xData,) class _HistogramContext(_StatsContext): """ - StatsContext for :class:`Curve` + StatsContext for :class:`Histogram` :param item: the item for which we want to compute the context :param plot: the plot containing the item @@ -176,12 +232,13 @@ class _HistogramContext(_StatsContext): plot=plot, onlimits=onlimits) def createContext(self, item, plot, onlimits): - xData, edges = item.getData(copy=True)[0:2] - yData = item._revertComputeEdges(x=edges, histogramType=item.getAlignment()) + yData, edges = item.getData(copy=True)[0:2] + xData = item._revertComputeEdges(x=edges, histogramType=item.getAlignment()) if onlimits: minX, maxX = plot.getXAxis().getLimits() - yData = yData[(minX <= xData) & (xData <= maxX)] - xData = xData[(minX <= xData) & (xData <= maxX)] + mask = (minX <= xData) & (xData <= maxX) + yData = yData[mask] + xData = xData[mask] self.xData = xData self.yData = yData @@ -191,11 +248,13 @@ class _HistogramContext(_StatsContext): self.min, self.max = None, None self.data = (xData, yData) self.values = yData + self.axes = (xData,) class _ScatterContext(_StatsContext): - """ - StatsContext for :class:`Scatter` + """StatsContext scatter plots. + + It supports :class:`~silx.gui.plot.items.Scatter`. :param item: the item for which we want to compute the context :param plot: the plot containing the item @@ -207,11 +266,14 @@ class _ScatterContext(_StatsContext): onlimits=onlimits) def createContext(self, item, plot, onlimits): - xData, yData, valueData, xerror, yerror = item.getData(copy=True) - assert plot + valueData = item.getValueData(copy=True) + xData = item.getXData(copy=True) + yData = item.getYData(copy=True) + if onlimits: minX, maxX = plot.getXAxis().getLimits() minY, maxY = plot.getYAxis().getLimits() + # filter on X axis valueData = valueData[(minX <= xData) & (xData <= maxX)] yData = yData[(minX <= xData) & (xData <= maxX)] @@ -220,17 +282,20 @@ class _ScatterContext(_StatsContext): valueData = valueData[(minY <= yData) & (yData <= maxY)] xData = xData[(minY <= yData) & (yData <= maxY)] yData = yData[(minY <= yData) & (yData <= maxY)] + if len(valueData) > 0: self.min, self.max = min_max(valueData) else: self.min, self.max = None, None self.data = (xData, yData, valueData) self.values = valueData + self.axes = (xData, yData) class _ImageContext(_StatsContext): - """ - StatsContext for :class:`ImageBase` + """StatsContext for images. + + It supports :class:`~silx.gui.plot.items.ImageData`. :param item: the item for which we want to compute the context :param plot: the plot containing the item @@ -244,7 +309,8 @@ class _ImageContext(_StatsContext): def createContext(self, item, plot, onlimits): self.origin = item.getOrigin() self.scale = item.getScale() - self.data = item.getData() + + self.data = item.getData(copy=True) if onlimits: minX, maxX = plot.getXAxis().getLimits() @@ -259,25 +325,88 @@ class _ImageContext(_StatsContext): YMinBound = max(YMinBound, 0) if XMaxBound <= XMinBound or YMaxBound <= YMinBound: - return self.noDataSelected() - data = item.getData() - self.data = data[YMinBound:YMaxBound + 1, XMinBound:XMaxBound + 1] - else: - self.data = item.getData() - + self.data = None + else: + self.data = self.data[YMinBound:YMaxBound + 1, + XMinBound:XMaxBound + 1] if self.data.size > 0: self.min, self.max = min_max(self.data) else: self.min, self.max = None, None self.values = self.data + if self.values is not None: + self.axes = (self.origin[1] + self.scale[1] * numpy.arange(self.data.shape[0]), + self.origin[0] + self.scale[0] * numpy.arange(self.data.shape[1])) + + +class _plot3DScatterContext(_StatsContext): + """StatsContext for 3D scatter plots. + + It supports :class:`~silx.gui.plot3d.items.Scatter2D` and + :class:`~silx.gui.plot3d.items.Scatter3D`. + + :param item: the item for which we want to compute the context + :param plot: the plot containing the item + :param bool onlimits: True if we want to apply statistic only on + visible data. + """ + def __init__(self, item, plot, onlimits): + _StatsContext.__init__(self, kind='scatter', item=item, plot=plot, + onlimits=onlimits) + + def createContext(self, item, plot, onlimits): + if onlimits: + raise RuntimeError("Unsupported plot %s" % str(plot)) + + values = item.getValueData(copy=False) + + if values is not None and len(values) > 0: + self.values = values + axes = [item.getXData(copy=False), item.getYData(copy=False)] + if self.values.ndim == 3: + axes.append(item.getZData(copy=False)) + self.axes = tuple(axes) + + self.min, self.max = min_max(self.values) + else: + self.values = None + self.axes = None + self.min, self.max = None, None + + +class _plot3DArrayContext(_StatsContext): + """StatsContext for 3D scalar field and data image. + + It supports :class:`~silx.gui.plot3d.items.ScalarField3D` and + :class:`~silx.gui.plot3d.items.ImageData`. + + :param item: the item for which we want to compute the context + :param plot: the plot containing the item + :param bool onlimits: True if we want to apply statistic only on + visible data. + """ + def __init__(self, item, plot, onlimits): + _StatsContext.__init__(self, kind='image', item=item, plot=plot, + onlimits=onlimits) + + def createContext(self, item, plot, onlimits): + if onlimits: + raise RuntimeError("Unsupported plot %s" % str(plot)) + + values = item.getData(copy=False) + + if values is not None and len(values) > 0: + self.values = values + self.axes = tuple([numpy.arange(size) for size in self.values.shape]) + self.min, self.max = min_max(self.values) + else: + self.values = None + self.axes = None + self.min, self.max = None, None + -BASIC_COMPATIBLE_KINDS = { - 'curve': CurveItem, - 'image': ImageItem, - 'scatter': ScatterItem, - 'histogram': HistogramItem, -} +BASIC_COMPATIBLE_KINDS = 'curve', 'image', 'scatter', 'histogram' class StatBase(object): @@ -285,9 +414,8 @@ class StatBase(object): Base class for defining a statistic. :param str name: the name of the statistic. Must be unique. - :param compatibleKinds: the kind of items (curve, scatter...) for which - the statistic apply. - :rtype: List or tuple + :param List[str] compatibleKinds: + The kind of items (curve, scatter...) for which the statistic apply. """ def __init__(self, name, compatibleKinds=BASIC_COMPATIBLE_KINDS, description=None): self.name = name @@ -298,7 +426,7 @@ class StatBase(object): """ compute the statistic for the given :class:`StatsContext` - :param context: + :param _StatsContext context: :return dict: key is stat name, statistic computed is the dict value """ raise NotImplementedError('Base class') @@ -307,7 +435,7 @@ class StatBase(object): """ If necessary add a tooltip for a stat kind - :param str kinf: the kind of item the statistic is compute for. + :param str kind: the kind of item the statistic is compute for. :return: tooltip or None if no tooltip """ return None @@ -329,17 +457,18 @@ class Stat(StatBase): self._fct = fct def calculate(self, context): - if context.kind in self.compatibleKinds: - return self._fct(context.values) + if context.values is not None: + if context.kind in self.compatibleKinds: + return self._fct(context.values) + else: + raise ValueError('Kind %s not managed by %s' + '' % (context.kind, self.name)) else: - raise ValueError('Kind %s not managed by %s' - '' % (context.kind, self.name)) + return None class StatMin(StatBase): - """ - Compute the minimal value on data - """ + """Compute the minimal value on data""" def __init__(self): StatBase.__init__(self, name='min') @@ -348,9 +477,7 @@ class StatMin(StatBase): class StatMax(StatBase): - """ - Compute the maximal value on data - """ + """Compute the maximal value on data""" def __init__(self): StatBase.__init__(self, name='max') @@ -359,9 +486,7 @@ class StatMax(StatBase): class StatDelta(StatBase): - """ - Compute the delta between minimal and maximal on data - """ + """Compute the delta between minimal and maximal on data""" def __init__(self): StatBase.__init__(self, name='delta') @@ -369,123 +494,84 @@ class StatDelta(StatBase): return context.max - context.min -class StatCoordMin(StatBase): - """ - Compute the first coordinates of the data minimal value - """ +class _StatCoord(StatBase): + """Base class for argmin and argmax stats""" + + def _indexToCoordinates(self, context, index): + """Returns the coordinates of data point at given index + + If data is an array, coordinates are in reverse order from data shape. + + :param _StatsContext context: + :param int index: Index in the flattened data array + :rtype: List[int] + """ + if context.isStructuredData(): + coordinates = [] + for axis in reversed(context.axes): + coordinates.append(axis[index % len(axis)]) + index = index // len(axis) + return tuple(coordinates) + else: + return tuple(axis[index] for axis in context.axes) + + +class StatCoordMin(_StatCoord): + """Compute the coordinates of the first minimum value of the data""" def __init__(self): - StatBase.__init__(self, name='coords min') + _StatCoord.__init__(self, name='coords min') def calculate(self, context): - if context.kind in ('curve', 'histogram'): - return context.xData[numpy.argmin(context.yData)] - elif context.kind == 'scatter': - xData, yData, valueData = context.data - return (xData[numpy.argmin(valueData)], - yData[numpy.argmin(valueData)]) - elif context.kind == 'image': - scaleX, scaleY = context.scale - originX, originY = context.origin - index1D = numpy.argmin(context.data) - ySize = (context.data.shape[1]) - x = index1D % context.data.shape[1] - y = (index1D - x) / ySize - x = x * scaleX + originX - y = y * scaleY + originY - return (x, y) - else: - raise ValueError('kind not managed') + if context.values is None or not context.isScalarData(): + return None + + index = numpy.argmin(context.values) + return self._indexToCoordinates(context, index) def getToolTip(self, kind): - if kind in ('scatter', 'image'): - return '(x, y)' - else: - return None + return "Coordinates of the first minimum value of the data" -class StatCoordMax(StatBase): - """ - Compute the first coordinates of the data minimal value - """ + +class StatCoordMax(_StatCoord): + """Compute the coordinates of the first maximum value of the data""" def __init__(self): - StatBase.__init__(self, name='coords max') + _StatCoord.__init__(self, name='coords max') def calculate(self, context): - if context.kind in ('curve', 'histogram'): - return context.xData[numpy.argmax(context.yData)] - elif context.kind == 'scatter': - xData, yData, valueData = context.data - return (xData[numpy.argmax(valueData)], - yData[numpy.argmax(valueData)]) - elif context.kind == 'image': - scaleX, scaleY = context.scale - originX, originY = context.origin - index1D = numpy.argmax(context.data) - ySize = (context.data.shape[1]) - x = index1D % context.data.shape[1] - y = (index1D - x) / ySize - x = x * scaleX + originX - y = y * scaleY + originY - return (x, y) - else: - raise ValueError('kind not managed') + if context.values is None or not context.isScalarData(): + return None + + index = numpy.argmax(context.values) + return self._indexToCoordinates(context, index) def getToolTip(self, kind): - if kind in ('scatter', 'image'): - return '(x, y)' - else: - return None + return "Coordinates of the first maximum value of the data" + class StatCOM(StatBase): - """ - Compute data center of mass - """ + """Compute data center of mass""" def __init__(self): StatBase.__init__(self, name='COM', description='Center of mass') def calculate(self, context): - if context.kind in ('curve', 'histogram'): - xData, yData = context.data - deno = numpy.sum(yData).astype(numpy.float32) - if deno == 0.: - return numpy.nan - else: - return numpy.sum(xData * yData).astype(numpy.float32) / deno - elif context.kind == 'scatter': - xData, yData, values = context.data - deno = numpy.sum(values).astype(numpy.float32) - if deno == 0.: - return numpy.nan, numpy.nan - else: - xcom = numpy.sum(xData * values).astype(numpy.float32) / deno - ycom = numpy.sum(yData * values).astype(numpy.float32) / deno - return (xcom, ycom) - elif context.kind == 'image': - yData = numpy.sum(context.data, axis=1) - xData = numpy.sum(context.data, axis=0) - dataXRange = range(context.data.shape[1]) - dataYRange = range(context.data.shape[0]) - xScale, yScale = context.scale - xOrigin, yOrigin = context.origin - - denoY = numpy.sum(yData) - if denoY == 0.: - ycom = numpy.nan - else: - ycom = numpy.sum(yData * dataYRange) / denoY - ycom = ycom * yScale + yOrigin + if context.values is None or not context.isScalarData(): + return None - denoX = numpy.sum(xData) - if denoX == 0.: - xcom = numpy.nan - else: - xcom = numpy.sum(xData * dataXRange) / denoX - xcom = xcom * xScale + xOrigin - return (xcom, ycom) + values = numpy.array(context.values, dtype=numpy.float64) + sum_ = numpy.sum(values) + if sum_ == 0.: + return (numpy.nan,) * len(context.axes) + + if context.isStructuredData(): + centerofmass = [] + for index, axis in enumerate(context.axes): + axes = tuple([i for i in range(len(context.axes)) if i != index]) + centerofmass.append( + numpy.sum(axis * numpy.sum(values, axis=axes)) / sum_) + return tuple(reversed(centerofmass)) else: - raise ValueError('kind not managed') + return tuple( + numpy.sum(axis * values) / sum_ for axis in context.axes) def getToolTip(self, kind): - if kind in ('scatter', 'image'): - return '(x, y)' - else: - return None + return "Compute the center of mass of the dataset" diff --git a/silx/gui/plot/stats/statshandler.py b/silx/gui/plot/stats/statshandler.py index 0a62b31..f69daff 100644 --- a/silx/gui/plot/stats/statshandler.py +++ b/silx/gui/plot/stats/statshandler.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -45,7 +45,14 @@ class _FloatItem(qt.QTableWidgetItem): qt.QTableWidgetItem.__init__(self, type=type) def __lt__(self, other): - return float(self.text()) < float(other.text()) + self_values = self.text().lstrip('(').rstrip(')').split(',') + other_values = other.text().lstrip('(').rstrip(')').split(',') + for self_value, other_value in zip(self_values, other_values): + f_self_value = float(self_value) + f_other_value = float(other_value) + if f_self_value != f_other_value: + return f_self_value < f_other_value + return False class StatFormatter(object): @@ -89,10 +96,60 @@ class StatsHandler(object): self.stats = statsmdl.Stats() self.formatters = {} for elmt in statFormatters: - helper = _StatHelper(elmt) - self.add(stat=helper.stat, formatter=helper.statFormatter) + stat, formatter = self._processStatArgument(elmt) + self.add(stat=stat, formatter=formatter) + + @staticmethod + def _processStatArgument(arg): + """Process an element of the init arguments + + :param arg: The argument to process + :return: Corresponding (StatBase, StatFormatter) + """ + stat, formatter = None, None + + if isinstance(arg, statsmdl.StatBase): + stat = arg + else: + assert len(arg) > 0 + if isinstance(arg[0], statsmdl.StatBase): + stat = arg[0] + if len(arg) > 2: + raise ValueError('To many argument with %s. At most one ' + 'argument can be associated with the ' + 'BaseStat (the `StatFormatter`') + if len(arg) == 2: + assert arg[1] is None or isinstance(arg[1], (StatFormatter, str)) + formatter = arg[1] + else: + if isinstance(arg[0], tuple): + if len(arg) > 1: + formatter = arg[1] + arg = arg[0] + + if type(arg[0]) is not str: + raise ValueError('first element of the tuple should be a string' + ' or a StatBase instance') + if len(arg) == 1: + raise ValueError('A function should be associated with the' + 'stat name') + if len(arg) > 3: + raise ValueError('Two much argument given for defining statistic.' + 'Take at most three arguments (name, function, ' + 'kinds)') + if len(arg) == 2: + stat = statsmdl.Stat(name=arg[0], fct=arg[1]) + else: + stat = statsmdl.Stat(name=arg[0], fct=arg[1], kinds=arg[2]) + + return stat, formatter def add(self, stat, formatter=None): + """Add a stat to the list. + + :param StatBase stat: + :param Union[None,StatFormatter] formatter: + """ assert isinstance(stat, statsmdl.StatBase) self.stats.add(stat) _formatter = formatter @@ -101,9 +158,9 @@ class StatsHandler(object): self.formatters[stat.name] = _formatter def format(self, name, val): - """ - Apply the format for the `name` statistic and the given value - :param name: the name of the associated statistic + """Apply the format for the `name` statistic and the given value + + :param str name: the name of the associated statistic :param val: value before formatting :return: formatted value """ @@ -123,7 +180,7 @@ class StatsHandler(object): def calculate(self, item, plot, onlimits): """ - compute all statistic registred and return the list of formatted + compute all statistic registered and return the list of formatted statistics result. :param item: item for which we want to compute statistics @@ -137,54 +194,3 @@ class StatsHandler(object): for resName, resValue in list(res.items()): res[resName] = self.format(resName, res[resName]) return res - - -class _StatHelper(object): - """ - Helper class to generated the requested StatBase instance and the - associated StatFormatter - """ - def __init__(self, arg): - self.statFormatter = None - self.stat = None - - if isinstance(arg, statsmdl.StatBase): - self.stat = arg - else: - assert len(arg) > 0 - if isinstance(arg[0], statsmdl.StatBase): - self.dealWithStatAndFormatter(arg) - else: - _arg = arg - if isinstance(arg[0], tuple): - _arg = arg[0] - if len(arg) > 1: - self.statFormatter = arg[1] - self.createStatInstanceAndFormatter(_arg) - - def dealWithStatAndFormatter(self, arg): - assert isinstance(arg[0], statsmdl.StatBase) - self.stat = arg[0] - if len(arg) > 2: - raise ValueError('To many argument with %s. At most one ' - 'argument can be associated with the ' - 'BaseStat (the `StatFormatter`') - if len(arg) is 2: - assert isinstance(arg[1], (StatFormatter, type(None), str)) - self.statFormatter = arg[1] - - def createStatInstanceAndFormatter(self, arg): - if type(arg[0]) is not str: - raise ValueError('first element of the tuple should be a string' - ' or a StatBase instance') - if len(arg) is 1: - raise ValueError('A function should be associated with the' - 'stat name') - if len(arg) > 3: - raise ValueError('Two much argument given for defining statistic.' - 'Take at most three arguments (name, function, ' - 'kinds)') - if len(arg) is 2: - self.stat = statsmdl.Stat(name=arg[0], fct=arg[1]) - else: - self.stat = statsmdl.Stat(name=arg[0], fct=arg[1], kinds=arg[2]) diff --git a/silx/gui/plot/test/testCurvesROIWidget.py b/silx/gui/plot/test/testCurvesROIWidget.py index 0704779..5bcabd8 100644 --- a/silx/gui/plot/test/testCurvesROIWidget.py +++ b/silx/gui/plot/test/testCurvesROIWidget.py @@ -36,7 +36,7 @@ from collections import OrderedDict import numpy from silx.gui import qt from silx.test.utils import temp_dir -from silx.gui.utils.testutils import TestCaseQt +from silx.gui.utils.testutils import TestCaseQt, SignalListener from silx.gui.plot import PlotWindow, CurvesROIWidget @@ -52,7 +52,8 @@ class TestCurvesROIWidget(TestCaseQt): self.plot.show() self.qWaitForWindowExposed(self.plot) - self.widget = CurvesROIWidget.CurvesROIDockWidget(plot=self.plot, name='TEST') + self.widget = self.plot.getCurvesRoiDockWidget() + self.widget.show() self.qWaitForWindowExposed(self.widget) @@ -67,10 +68,6 @@ class TestCurvesROIWidget(TestCaseQt): super(TestCurvesROIWidget, self).tearDown() - def testEmptyPlot(self): - """Empty plot, display ROI widget""" - pass - def testWithCurves(self): """Plot with curves: test all ROI widget buttons""" for offset in range(2): @@ -80,13 +77,16 @@ class TestCurvesROIWidget(TestCaseQt): # Add two ROI self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton) + self.qWait(200) self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton) + self.qWait(200) # Change active curve self.plot.setActiveCurve(str(1)) # Delete a ROI self.mouseClick(self.widget.roiWidget.delButton, qt.Qt.LeftButton) + self.qWait(200) with temp_dir() as tmpDir: self.tmpFile = os.path.join(tmpDir, 'test.ini') @@ -94,30 +94,42 @@ class TestCurvesROIWidget(TestCaseQt): # Save ROIs self.widget.roiWidget.save(self.tmpFile) self.assertTrue(os.path.isfile(self.tmpFile)) + self.assertTrue(len(self.widget.getRois()) is 2) # Reset ROIs self.mouseClick(self.widget.roiWidget.resetButton, qt.Qt.LeftButton) + self.qWait(200) + rois = self.widget.getRois() + self.assertTrue(len(rois) is 1) + print(rois) + roiID = list(rois.keys())[0] + self.assertTrue(rois[roiID].getName() == 'ICR') # Load ROIs self.widget.roiWidget.load(self.tmpFile) + self.assertTrue(len(self.widget.getRois()) is 2) del self.tmpFile def testMiddleMarker(self): """Test with middle marker enabled""" - self.widget.roiWidget.setMiddleROIMarkerFlag(True) + self.widget.roiWidget.roiTable.setMiddleROIMarkerFlag(True) # Add a ROI self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton) - xleftMarker = self.plot._getMarker(legend='ROI min').getXPosition() - xMiddleMarker = self.plot._getMarker(legend='ROI middle').getXPosition() - xRightMarker = self.plot._getMarker(legend='ROI max').getXPosition() - self.assertAlmostEqual(xMiddleMarker, - xleftMarker + (xRightMarker - xleftMarker) / 2.) - - def testCalculation(self): + for roiID in self.widget.roiWidget.roiTable._markersHandler._roiMarkerHandlers: + handler = self.widget.roiWidget.roiTable._markersHandler._roiMarkerHandlers[roiID] + assert handler.getMarker('min') + xleftMarker = handler.getMarker('min').getXPosition() + xMiddleMarker = handler.getMarker('middle').getXPosition() + xRightMarker = handler.getMarker('max').getXPosition() + thValue = xleftMarker + (xRightMarker - xleftMarker) / 2. + self.assertAlmostEqual(xMiddleMarker, thValue) + + def testAreaCalculation(self): + """Test result of area calculation""" x = numpy.arange(100.) y = numpy.arange(100.) @@ -129,30 +141,60 @@ class TestCurvesROIWidget(TestCaseQt): self.plot.setActiveCurve("positive") # Add two ROIs - ddict = {} - ddict["positive"] = {"from": 10, "to": 20, "type":"X"} - ddict["negative"] = {"from": -20, "to": -10, "type":"X"} - self.widget.roiWidget.setRois(ddict) + roi_neg = CurvesROIWidget.ROI(name='negative', fromdata=-20, + todata=-10, type_='X') + roi_pos = CurvesROIWidget.ROI(name='positive', fromdata=10, + todata=20, type_='X') + + self.widget.roiWidget.setRois((roi_pos, roi_neg)) + + posCurve = self.plot.getCurve('positive') + negCurve = self.plot.getCurve('negative') + + self.assertEqual(roi_pos.computeRawAndNetArea(posCurve), + (numpy.trapz(y=[10, 20], x=[10, 20]), + 0.0)) + self.assertEqual(roi_pos.computeRawAndNetArea(negCurve), + (0.0, 0.0)) + self.assertEqual(roi_neg.computeRawAndNetArea(posCurve), + ((0.0), 0.0)) + self.assertEqual(roi_neg.computeRawAndNetArea(negCurve), + ((-150.0), 0.0)) + + def testCountsCalculation(self): + """Test result of count calculation""" + x = numpy.arange(100.) + y = numpy.arange(100.) - # And calculate the expected output - self.widget.calculateROIs() + # Add two curves + self.plot.addCurve(x, y, legend="positive") + self.plot.addCurve(-x, y, legend="negative") + + # Make sure there is an active curve and it is the positive one + self.plot.setActiveCurve("positive") - output = self.widget.roiWidget.getRois() - self.assertEqual(output["positive"]["rawcounts"], - y[ddict["positive"]["from"]:ddict["positive"]["to"]+1].sum(), - "Calculation failed on positive X coordinates") + # Add two ROIs + roi_neg = CurvesROIWidget.ROI(name='negative', fromdata=-20, + todata=-10, type_='X') + roi_pos = CurvesROIWidget.ROI(name='positive', fromdata=10, + todata=20, type_='X') + + self.widget.roiWidget.setRois((roi_pos, roi_neg)) - # Set the curve with negative X coordinates as active - self.plot.setActiveCurve("negative") + posCurve = self.plot.getCurve('positive') + negCurve = self.plot.getCurve('negative') - # the ROIs should have been automatically updated - output = self.widget.roiWidget.getRois() - selection = numpy.nonzero((-x >= output["negative"]["from"]) & \ - (-x <= output["negative"]["to"]))[0] - self.assertEqual(output["negative"]["rawcounts"], - y[selection].sum(), "Calculation failed on negative X coordinates") + self.assertEqual(roi_pos.computeRawAndNetCounts(posCurve), + (y[10:21].sum(), 0.0)) + self.assertEqual(roi_pos.computeRawAndNetCounts(negCurve), + (0.0, 0.0)) + self.assertEqual(roi_neg.computeRawAndNetCounts(posCurve), + ((0.0), 0.0)) + self.assertEqual(roi_neg.computeRawAndNetCounts(negCurve), + (y[10:21].sum(), 0.0)) def testDeferedInit(self): + """Test behavior of the deferedInit""" x = numpy.arange(100.) y = numpy.arange(100.) self.plot.addCurve(x=x, y=y, legend="name", replace="True") @@ -164,12 +206,123 @@ class TestCurvesROIWidget(TestCaseQt): ]) roiWidget = self.plot.getCurvesRoiDockWidget().roiWidget - self.assertFalse(roiWidget._isInit) self.plot.getCurvesRoiDockWidget().setRois(roisDefs) self.assertTrue(len(roiWidget.getRois()) is len(roisDefs)) self.plot.getCurvesRoiDockWidget().setVisible(True) self.assertTrue(len(roiWidget.getRois()) is len(roisDefs)) + def testDictCompatibility(self): + """Test that ROI api is valid with dict and not information is lost""" + roiDict = {'from': 20, 'to': 200, 'type': 'energy', 'comment': 'no', + 'name': 'myROI', 'calibration': [1, 2, 3]} + roi = CurvesROIWidget.ROI._fromDict(roiDict) + self.assertTrue(roi.toDict() == roiDict) + + def testShowAllROI(self): + """Test the show allROI action""" + x = numpy.arange(100.) + y = numpy.arange(100.) + self.plot.addCurve(x=x, y=y, legend="name", replace="True") + + roisDefsDict = { + "range1": {"from": 20, "to": 200,"type": "energy"}, + "range2": {"from": 300, "to": 500, "type": "energy"} + } + + roisDefsObj = ( + CurvesROIWidget.ROI(name='range3', fromdata=20, todata=200, + type_='energy'), + CurvesROIWidget.ROI(name='range4', fromdata=300, todata=500, + type_='energy') + ) + self.widget.roiWidget.showAllMarkers(True) + roiWidget = self.plot.getCurvesRoiDockWidget().roiWidget + roiWidget.setRois(roisDefsDict) + self.assertTrue(len(self.plot._getAllMarkers()) is 2*3) + + markersHandler = self.widget.roiWidget.roiTable._markersHandler + roiWidget.showAllMarkers(True) + ICRROI = markersHandler.getVisibleRois() + self.assertTrue(len(ICRROI) is 2) + + roiWidget.showAllMarkers(False) + ICRROI = markersHandler.getVisibleRois() + self.assertTrue(len(ICRROI) is 1) + + roiWidget.setRois(roisDefsObj) + self.qapp.processEvents() + self.assertTrue(len(self.plot._getAllMarkers()) is 2*3) + + markersHandler = self.widget.roiWidget.roiTable._markersHandler + roiWidget.showAllMarkers(True) + ICRROI = markersHandler.getVisibleRois() + self.assertTrue(len(ICRROI) is 2) + + roiWidget.showAllMarkers(False) + ICRROI = markersHandler.getVisibleRois() + self.assertTrue(len(ICRROI) is 1) + + def testRoiEdition(self): + """Make sure if the ROI object is edited the ROITable will be updated + """ + roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5) + self.widget.roiWidget.setRois((roi, )) + + x = (0, 1, 1, 2, 2, 3) + y = (1, 1, 2, 2, 1, 1) + self.plot.addCurve(x=x, y=y, legend='linearCurve') + self.plot.setActiveCurve(legend='linearCurve') + self.widget.calculateROIs() + + roiTable = self.widget.roiWidget.roiTable + indexesColumns = CurvesROIWidget.ROITable.COLUMNS_INDEX + itemRawCounts = roiTable.item(0, indexesColumns['Raw Counts']) + itemNetCounts = roiTable.item(0, indexesColumns['Net Counts']) + + self.assertTrue(itemRawCounts.text() == '8.0') + self.assertTrue(itemNetCounts.text() == '2.0') + + itemRawArea = roiTable.item(0, indexesColumns['Raw Area']) + itemNetArea = roiTable.item(0, indexesColumns['Net Area']) + + self.assertTrue(itemRawArea.text() == '4.0') + self.assertTrue(itemNetArea.text() == '1.0') + + roi.setTo(2) + itemRawArea = roiTable.item(0, indexesColumns['Raw Area']) + self.assertTrue(itemRawArea.text() == '3.0') + roi.setFrom(1) + itemRawArea = roiTable.item(0, indexesColumns['Raw Area']) + self.assertTrue(itemRawArea.text() == '2.0') + + def testRemoveActiveROI(self): + """Test widget behavior when removing the active ROI""" + roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5) + self.widget.roiWidget.setRois((roi,)) + + self.widget.roiWidget.roiTable.setActiveRoi(None) + self.assertTrue(len(self.widget.roiWidget.roiTable.selectedItems()) is 0) + self.widget.roiWidget.setRois((roi,)) + self.plot.setActiveCurve(legend='linearCurve') + self.widget.calculateROIs() + + def testEmitCurrentROI(self): + """Test behavior of the CurvesROIWidget.sigROISignal""" + roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5) + self.widget.roiWidget.setRois((roi,)) + signalListener = SignalListener() + self.widget.roiWidget.sigROISignal.connect(signalListener.partial()) + self.widget.show() + self.qapp.processEvents() + self.assertTrue(signalListener.callCount() is 0) + self.assertTrue(self.widget.roiWidget.roiTable.activeRoi is roi) + roi.setFrom(0.0) + self.qapp.processEvents() + self.assertTrue(signalListener.callCount() is 0) + roi.setFrom(0.3) + self.qapp.processEvents() + self.assertTrue(signalListener.callCount() is 1) + def suite(): test_suite = unittest.TestSuite() diff --git a/silx/gui/plot/test/testMaskToolsWidget.py b/silx/gui/plot/test/testMaskToolsWidget.py index 6912ea3..a05c1be 100644 --- a/silx/gui/plot/test/testMaskToolsWidget.py +++ b/silx/gui/plot/test/testMaskToolsWidget.py @@ -42,10 +42,7 @@ from silx.gui.utils.testutils import getQToolButtonFromAction from silx.gui.plot import PlotWindow, MaskToolsWidget from .utils import PlotWidgetTestCase -try: - import fabio -except ImportError: - fabio = None +import fabio _logger = logging.getLogger(__name__) @@ -254,8 +251,6 @@ class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase): self.__loadSave("npy") def testLoadSaveFit2D(self): - if fabio is None: - self.skipTest("Fabio is missing") self.__loadSave("msk") def testSigMaskChangedEmitted(self): diff --git a/silx/gui/plot/test/testPlotWidget.py b/silx/gui/plot/test/testPlotWidget.py index 857b9bc..9d7c093 100644 --- a/silx/gui/plot/test/testPlotWidget.py +++ b/silx/gui/plot/test/testPlotWidget.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -26,7 +26,7 @@ __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "21/09/2018" +__date__ = "03/01/2019" import unittest @@ -36,8 +36,6 @@ import numpy from silx.utils.testutils import ParametricTestCase, parameterize from silx.gui.utils.testutils import SignalListener from silx.gui.utils.testutils import TestCaseQt -from silx.utils import testutils -from silx.utils import deprecation from silx.test.utils import test_options @@ -184,6 +182,39 @@ class TestPlotWidget(PlotWidgetTestCase, ParametricTestCase): self.assertTrue(numpy.all(numpy.equal(items[4].getPosition()[0], marker_x))) self.assertEqual(items[5].getType(), 'rectangle') + def testBackGroundColors(self): + self.plot.setVisible(True) + self.qWaitForWindowExposed(self.plot) + self.qapp.processEvents() + + # Custom the full background + color = self.plot.getBackgroundColor() + self.assertTrue(color.isValid()) + self.assertEqual(color, qt.QColor(255, 255, 255)) + self.plot.setBackgroundColor("red") + color = self.plot.getBackgroundColor() + self.assertTrue(color.isValid()) + self.qapp.processEvents() + + # Custom the data background + color = self.plot.getDataBackgroundColor() + self.assertFalse(color.isValid()) + self.plot.setDataBackgroundColor("red") + color = self.plot.getDataBackgroundColor() + self.assertTrue(color.isValid()) + self.qapp.processEvents() + + # Back to default + self.plot.setBackgroundColor('white') + self.plot.setDataBackgroundColor(None) + color = self.plot.getBackgroundColor() + self.assertTrue(color.isValid()) + self.assertEqual(color, qt.QColor(255, 255, 255)) + color = self.plot.getDataBackgroundColor() + self.assertFalse(color.isValid()) + self.qapp.processEvents() + + class TestPlotImage(PlotWidgetTestCase, ParametricTestCase): """Basic tests for addImage""" @@ -881,17 +912,12 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase): if getter is not None: self.assertEqual(getter(), expected) - @testutils.test_logging(deprecation.depreclog.name) def testOldPlotAxis_Logarithmic(self): """Test silx API prior to silx 0.6""" x = self.plot.getXAxis() y = self.plot.getYAxis() yright = self.plot.getYAxis(axis="right") - listener = SignalListener() - self.plot.sigSetXAxisLogarithmic.connect(listener.partial("x")) - self.plot.sigSetYAxisLogarithmic.connect(listener.partial("y")) - self.assertEqual(x.getScale(), x.LINEAR) self.assertEqual(y.getScale(), x.LINEAR) self.assertEqual(yright.getScale(), x.LINEAR) @@ -902,7 +928,6 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase): self.assertEqual(yright.getScale(), x.LINEAR) self.assertEqual(self.plot.isXAxisLogarithmic(), True) self.assertEqual(self.plot.isYAxisLogarithmic(), False) - self.assertEqual(listener.arguments(callIndex=-1), ("x", True)) self.plot.setYAxisLogarithmic(True) self.assertEqual(x.getScale(), x.LOGARITHMIC) @@ -910,7 +935,6 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase): self.assertEqual(yright.getScale(), x.LOGARITHMIC) self.assertEqual(self.plot.isXAxisLogarithmic(), True) self.assertEqual(self.plot.isYAxisLogarithmic(), True) - self.assertEqual(listener.arguments(callIndex=-1), ("y", True)) yright.setScale(yright.LINEAR) self.assertEqual(x.getScale(), x.LOGARITHMIC) @@ -918,19 +942,13 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase): self.assertEqual(yright.getScale(), x.LINEAR) self.assertEqual(self.plot.isXAxisLogarithmic(), True) self.assertEqual(self.plot.isYAxisLogarithmic(), False) - self.assertEqual(listener.arguments(callIndex=-1), ("y", False)) - @testutils.test_logging(deprecation.depreclog.name) def testOldPlotAxis_AutoScale(self): """Test silx API prior to silx 0.6""" x = self.plot.getXAxis() y = self.plot.getYAxis() yright = self.plot.getYAxis(axis="right") - listener = SignalListener() - self.plot.sigSetXAxisAutoScale.connect(listener.partial("x")) - self.plot.sigSetYAxisAutoScale.connect(listener.partial("y")) - self.assertEqual(x.isAutoScale(), True) self.assertEqual(y.isAutoScale(), True) self.assertEqual(yright.isAutoScale(), True) @@ -941,7 +959,6 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase): self.assertEqual(yright.isAutoScale(), True) self.assertEqual(self.plot.isXAxisAutoScale(), False) self.assertEqual(self.plot.isYAxisAutoScale(), True) - self.assertEqual(listener.arguments(callIndex=-1), ("x", False)) self.plot.setYAxisAutoScale(False) self.assertEqual(x.isAutoScale(), False) @@ -949,7 +966,6 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase): self.assertEqual(yright.isAutoScale(), False) self.assertEqual(self.plot.isXAxisAutoScale(), False) self.assertEqual(self.plot.isYAxisAutoScale(), False) - self.assertEqual(listener.arguments(callIndex=-1), ("y", False)) yright.setAutoScale(True) self.assertEqual(x.isAutoScale(), False) @@ -957,18 +973,13 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase): self.assertEqual(yright.isAutoScale(), True) self.assertEqual(self.plot.isXAxisAutoScale(), False) self.assertEqual(self.plot.isYAxisAutoScale(), True) - self.assertEqual(listener.arguments(callIndex=-1), ("y", True)) - @testutils.test_logging(deprecation.depreclog.name) def testOldPlotAxis_Inverted(self): """Test silx API prior to silx 0.6""" x = self.plot.getXAxis() y = self.plot.getYAxis() yright = self.plot.getYAxis(axis="right") - listener = SignalListener() - self.plot.sigSetYAxisInverted.connect(listener.partial("y")) - self.assertEqual(x.isInverted(), False) self.assertEqual(y.isInverted(), False) self.assertEqual(yright.isInverted(), False) @@ -978,14 +989,12 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase): self.assertEqual(y.isInverted(), True) self.assertEqual(yright.isInverted(), True) self.assertEqual(self.plot.isYAxisInverted(), True) - self.assertEqual(listener.arguments(callIndex=-1), ("y", True)) yright.setInverted(False) self.assertEqual(x.isInverted(), False) self.assertEqual(y.isInverted(), False) self.assertEqual(yright.isInverted(), False) self.assertEqual(self.plot.isYAxisInverted(), False) - self.assertEqual(listener.arguments(callIndex=-1), ("y", False)) def testLogXWithData(self): self.plot.setGraphTitle('Curve X: Log Y: Linear') diff --git a/silx/gui/plot/test/testSaveAction.py b/silx/gui/plot/test/testSaveAction.py index 85669bf..0eb129d 100644 --- a/silx/gui/plot/test/testSaveAction.py +++ b/silx/gui/plot/test/testSaveAction.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -106,12 +106,30 @@ class TestSaveActionExtension(PlotWidgetTestCase): self.assertEqual(saveAction.getFileFilters('all')[nameFilter], self._dummySaveFunction) + # Add a new file filter at a particular position + nameFilter = 'Dummy file2 (*.dummy)' + saveAction.setFileFilter('all', nameFilter, + self._dummySaveFunction, index=3) + self.assertTrue(nameFilter in saveAction.getFileFilters('all')) + filters = saveAction.getFileFilters('all') + self.assertEqual(filters[nameFilter], self._dummySaveFunction) + self.assertEqual(list(filters.keys()).index(nameFilter),3) + # Update an existing file filter nameFilter = SaveAction.IMAGE_FILTER_EDF saveAction.setFileFilter('image', nameFilter, self._dummySaveFunction) self.assertEqual(saveAction.getFileFilters('image')[nameFilter], self._dummySaveFunction) + # Change the position of an existing file filter + nameFilter = 'Dummy file2 (*.dummy)' + oldIndex = list(saveAction.getFileFilters('all')).index(nameFilter) + newIndex = oldIndex - 1 + saveAction.setFileFilter('all', nameFilter, + self._dummySaveFunction, index=newIndex) + filters = saveAction.getFileFilters('all') + self.assertEqual(filters[nameFilter], self._dummySaveFunction) + self.assertEqual(list(filters.keys()).index(nameFilter), newIndex) def suite(): test_suite = unittest.TestSuite() diff --git a/silx/gui/plot/test/testScatterMaskToolsWidget.py b/silx/gui/plot/test/testScatterMaskToolsWidget.py index a446911..171ec42 100644 --- a/silx/gui/plot/test/testScatterMaskToolsWidget.py +++ b/silx/gui/plot/test/testScatterMaskToolsWidget.py @@ -42,10 +42,7 @@ from silx.gui.utils.testutils import getQToolButtonFromAction from silx.gui.plot import PlotWindow, ScatterMaskToolsWidget from .utils import PlotWidgetTestCase -try: - import fabio -except ImportError: - fabio = None +import fabio _logger = logging.getLogger(__name__) diff --git a/silx/gui/plot/test/testStats.py b/silx/gui/plot/test/testStats.py index faedcff..7fbc247 100644 --- a/silx/gui/plot/test/testStats.py +++ b/silx/gui/plot/test/testStats.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -112,34 +112,34 @@ class TestStats(TestCaseQt): """Test result for simple stats on a curve""" _stats = self.getBasicStats() xData = yData = numpy.array(range(20)) - self.assertTrue(_stats['min'].calculate(self.curveContext) == 0) - self.assertTrue(_stats['max'].calculate(self.curveContext) == 19) - self.assertTrue(_stats['minCoords'].calculate(self.curveContext) == [0]) - self.assertTrue(_stats['maxCoords'].calculate(self.curveContext) == [19]) - self.assertTrue(_stats['std'].calculate(self.curveContext) == numpy.std(yData)) - self.assertTrue(_stats['mean'].calculate(self.curveContext) == numpy.mean(yData)) + self.assertEqual(_stats['min'].calculate(self.curveContext), 0) + self.assertEqual(_stats['max'].calculate(self.curveContext), 19) + self.assertEqual(_stats['minCoords'].calculate(self.curveContext), (0,)) + self.assertEqual(_stats['maxCoords'].calculate(self.curveContext), (19,)) + self.assertEqual(_stats['std'].calculate(self.curveContext), numpy.std(yData)) + self.assertEqual(_stats['mean'].calculate(self.curveContext), numpy.mean(yData)) com = numpy.sum(xData * yData) / numpy.sum(yData) - self.assertTrue(_stats['com'].calculate(self.curveContext) == com) + self.assertEqual(_stats['com'].calculate(self.curveContext), com) def testBasicStatsImage(self): """Test result for simple stats on an image""" _stats = self.getBasicStats() - self.assertTrue(_stats['min'].calculate(self.imageContext) == 0) - self.assertTrue(_stats['max'].calculate(self.imageContext) == 128 * 32 - 1) - self.assertTrue(_stats['minCoords'].calculate(self.imageContext) == (0, 0)) - self.assertTrue(_stats['maxCoords'].calculate(self.imageContext) == (127, 31)) - self.assertTrue(_stats['std'].calculate(self.imageContext) == numpy.std(self.imageData)) - self.assertTrue(_stats['mean'].calculate(self.imageContext) == numpy.mean(self.imageData)) - - yData = numpy.sum(self.imageData, axis=1) - xData = numpy.sum(self.imageData, axis=0) + self.assertEqual(_stats['min'].calculate(self.imageContext), 0) + self.assertEqual(_stats['max'].calculate(self.imageContext), 128 * 32 - 1) + self.assertEqual(_stats['minCoords'].calculate(self.imageContext), (0, 0)) + self.assertEqual(_stats['maxCoords'].calculate(self.imageContext), (127, 31)) + self.assertEqual(_stats['std'].calculate(self.imageContext), numpy.std(self.imageData)) + self.assertEqual(_stats['mean'].calculate(self.imageContext), numpy.mean(self.imageData)) + + yData = numpy.sum(self.imageData.astype(numpy.float64), axis=1) + xData = numpy.sum(self.imageData.astype(numpy.float64), axis=0) dataXRange = range(self.imageData.shape[1]) dataYRange = range(self.imageData.shape[0]) ycom = numpy.sum(yData*dataYRange) / numpy.sum(yData) xcom = numpy.sum(xData*dataXRange) / numpy.sum(xData) - self.assertTrue(_stats['com'].calculate(self.imageContext) == (xcom, ycom)) + self.assertEqual(_stats['com'].calculate(self.imageContext), (xcom, ycom)) def testStatsImageAdv(self): """Test that scale and origin are taking into account for images""" @@ -153,52 +153,46 @@ class TestStats(TestCaseQt): onlimits=False ) _stats = self.getBasicStats() - self.assertTrue(_stats['min'].calculate(image2Context) == 0) - self.assertTrue( - _stats['max'].calculate(image2Context) == 128 * 32 - 1) - self.assertTrue( - _stats['minCoords'].calculate(image2Context) == (100, 10)) - self.assertTrue( - _stats['maxCoords'].calculate(image2Context) == (127*2. + 100, - 31 * 0.5 + 10) - ) - self.assertTrue( - _stats['std'].calculate(image2Context) == numpy.std( - self.imageData)) - self.assertTrue( - _stats['mean'].calculate(image2Context) == numpy.mean( - self.imageData)) + self.assertEqual(_stats['min'].calculate(image2Context), 0) + self.assertEqual( + _stats['max'].calculate(image2Context), 128 * 32 - 1) + self.assertEqual( + _stats['minCoords'].calculate(image2Context), (100, 10)) + self.assertEqual( + _stats['maxCoords'].calculate(image2Context), (127*2. + 100, + 31 * 0.5 + 10)) + self.assertEqual(_stats['std'].calculate(image2Context), + numpy.std(self.imageData)) + self.assertEqual(_stats['mean'].calculate(image2Context), + numpy.mean(self.imageData)) yData = numpy.sum(self.imageData, axis=1) xData = numpy.sum(self.imageData, axis=0) - dataXRange = range(self.imageData.shape[1]) - dataYRange = range(self.imageData.shape[0]) + dataXRange = numpy.arange(self.imageData.shape[1], dtype=numpy.float64) + dataYRange = numpy.arange(self.imageData.shape[0], dtype=numpy.float64) ycom = numpy.sum(yData * dataYRange) / numpy.sum(yData) ycom = (ycom * 0.5) + 10 xcom = numpy.sum(xData * dataXRange) / numpy.sum(xData) xcom = (xcom * 2.) + 100 - self.assertTrue( - _stats['com'].calculate(image2Context) == (xcom, ycom)) + self.assertTrue(numpy.allclose( + _stats['com'].calculate(image2Context), (xcom, ycom))) def testBasicStatsScatter(self): """Test result for simple stats on a scatter""" _stats = self.getBasicStats() - self.assertTrue(_stats['min'].calculate(self.scatterContext) == 5) - self.assertTrue(_stats['max'].calculate(self.scatterContext) == 90) - self.assertTrue(_stats['minCoords'].calculate(self.scatterContext) == (0, 2)) - self.assertTrue(_stats['maxCoords'].calculate(self.scatterContext) == (50, 69)) - self.assertTrue(_stats['std'].calculate(self.scatterContext) == numpy.std(self.valuesScatterData)) - self.assertTrue(_stats['mean'].calculate(self.scatterContext) == numpy.mean(self.valuesScatterData)) - - comx = numpy.sum(self.xScatterData * self.valuesScatterData).astype(numpy.float32) / numpy.sum( - self.valuesScatterData).astype(numpy.float32) - comy = numpy.sum(self.yScatterData * self.valuesScatterData).astype(numpy.float32) / numpy.sum( - self.valuesScatterData).astype(numpy.float32) - self.assertTrue(numpy.all( - numpy.equal(_stats['com'].calculate(self.scatterContext), - (comx, comy))) - ) + self.assertEqual(_stats['min'].calculate(self.scatterContext), 5) + self.assertEqual(_stats['max'].calculate(self.scatterContext), 90) + self.assertEqual(_stats['minCoords'].calculate(self.scatterContext), (0, 2)) + self.assertEqual(_stats['maxCoords'].calculate(self.scatterContext), (50, 69)) + self.assertEqual(_stats['std'].calculate(self.scatterContext), numpy.std(self.valuesScatterData)) + self.assertEqual(_stats['mean'].calculate(self.scatterContext), numpy.mean(self.valuesScatterData)) + + data = self.valuesScatterData.astype(numpy.float64) + comx = numpy.sum(self.xScatterData * data) / numpy.sum(data) + comy = numpy.sum(self.yScatterData * data) / numpy.sum(data) + self.assertEqual(_stats['com'].calculate(self.scatterContext), + (comx, comy)) def testKindNotManagedByStat(self): """Make sure an exception is raised if we try to execute calculate @@ -227,21 +221,21 @@ class TestStats(TestCaseQt): item=self.plot1d.getCurve('curve0'), plot=self.plot1d, onlimits=True) - self.assertTrue(stat.calculate(curveContextOnLimits) == 2) + self.assertEqual(stat.calculate(curveContextOnLimits), 2) self.plot2d.getXAxis().setLimitsConstraints(minPos=32) imageContextOnLimits = stats._ImageContext( item=self.plot2d.getImage('test image'), plot=self.plot2d, onlimits=True) - self.assertTrue(stat.calculate(imageContextOnLimits) == 32) + self.assertEqual(stat.calculate(imageContextOnLimits), 32) self.scatterPlot.getXAxis().setLimitsConstraints(minPos=40) scatterContextOnLimits = stats._ScatterContext( item=self.scatterPlot.getScatter('scatter plot'), plot=self.scatterPlot, onlimits=True) - self.assertTrue(stat.calculate(scatterContextOnLimits) == 20) + self.assertEqual(stat.calculate(scatterContextOnLimits), 20) class TestStatsFormatter(TestCaseQt): @@ -267,15 +261,15 @@ class TestStatsFormatter(TestCaseQt): """Make sure a formatter with no formatter definition will return a simple cast to str""" emptyFormatter = statshandler.StatFormatter() - self.assertTrue( - emptyFormatter.format(self.stat.calculate(self.curveContext)) == '0.000') + self.assertEqual( + emptyFormatter.format(self.stat.calculate(self.curveContext)), '0.000') def testSettedFormatter(self): """Make sure a formatter with no formatter definition will return a simple cast to str""" formatter= statshandler.StatFormatter(formatter='{0:.3f}') - self.assertTrue( - formatter.format(self.stat.calculate(self.curveContext)) == '0.000') + self.assertEqual( + formatter.format(self.stat.calculate(self.curveContext)), '0.000') class TestStatsHandler(unittest.TestCase): @@ -309,9 +303,9 @@ class TestStatsHandler(unittest.TestCase): res = handler0.calculate(item=self.curveItem, plot=self.plot1d, onlimits=False) self.assertTrue('min' in res) - self.assertTrue(res['min'] == '0') + self.assertEqual(res['min'], '0') self.assertTrue('max' in res) - self.assertTrue(res['max'] == '19') + self.assertEqual(res['max'], '19') handler1 = statshandler.StatsHandler( ( @@ -323,9 +317,9 @@ class TestStatsHandler(unittest.TestCase): res = handler1.calculate(item=self.curveItem, plot=self.plot1d, onlimits=False) self.assertTrue('min' in res) - self.assertTrue(res['min'] == '0') + self.assertEqual(res['min'], '0') self.assertTrue('max' in res) - self.assertTrue(res['max'] == '19.000') + self.assertEqual(res['max'], '19.000') handler2 = statshandler.StatsHandler( ( @@ -336,9 +330,9 @@ class TestStatsHandler(unittest.TestCase): res = handler2.calculate(item=self.curveItem, plot=self.plot1d, onlimits=False) self.assertTrue('min' in res) - self.assertTrue(res['min'] == '0') + self.assertEqual(res['min'], '0') self.assertTrue('max' in res) - self.assertTrue(res['max'] == '19.000') + self.assertEqual(res['max'], '19.000') handler3 = statshandler.StatsHandler(( (('amin', numpy.argmin), statshandler.StatFormatter()), @@ -348,9 +342,9 @@ class TestStatsHandler(unittest.TestCase): res = handler3.calculate(item=self.curveItem, plot=self.plot1d, onlimits=False) self.assertTrue('amin' in res) - self.assertTrue(res['amin'] == '0.000') + self.assertEqual(res['amin'], '0.000') self.assertTrue('amax' in res) - self.assertTrue(res['amax'] == '19') + self.assertEqual(res['amax'], '19') with self.assertRaises(ValueError): statshandler.StatsHandler(('name')) @@ -395,47 +389,49 @@ class TestStatsWidgetWithCurves(TestCaseQt): def testInit(self): """Make sure all the curves are registred on initialization""" - self.assertTrue(self.widget.rowCount() is 3) + self.assertEqual(self.widget.rowCount(), 3) def testRemoveCurve(self): """Make sure the Curves stats take into account the curve removal from plot""" self.plot.removeCurve('curve2') - self.assertTrue(self.widget.rowCount() is 2) + self.assertEqual(self.widget.rowCount(), 2) for iRow in range(2): self.assertTrue(self.widget.item(iRow, 0).text() in ('curve0', 'curve1')) self.plot.removeCurve('curve0') - self.assertTrue(self.widget.rowCount() is 1) + self.assertEqual(self.widget.rowCount(), 1) self.plot.removeCurve('curve1') - self.assertTrue(self.widget.rowCount() is 0) + self.assertEqual(self.widget.rowCount(), 0) def testAddCurve(self): """Make sure the Curves stats take into account the add curve action""" self.plot.addCurve(legend='curve3', x=range(10), y=range(10)) - self.assertTrue(self.widget.rowCount() is 4) + self.assertEqual(self.widget.rowCount(), 4) - def testUpdateCurveFrmAddCurve(self): + def testUpdateCurveFromAddCurve(self): """Make sure the stats of the cuve will be removed after updating a curve""" self.plot.addCurve(legend='curve0', x=range(10), y=range(10)) - self.assertTrue(self.widget.rowCount() is 3) - itemMax = self.widget._getItem(name='max', legend='curve0', - kind='curve', indexTable=None) - self.assertTrue(itemMax.text() == '9') + self.qapp.processEvents() + self.assertEqual(self.widget.rowCount(), 3) + curve = self.plot._getItem(kind='curve', legend='curve0') + tableItems = self.widget._itemToTableItems(curve) + self.assertEqual(tableItems['max'].text(), '9') - def testUpdateCurveFrmCurveObj(self): + def testUpdateCurveFromCurveObj(self): self.plot.getCurve('curve0').setData(x=range(4), y=range(4)) - self.assertTrue(self.widget.rowCount() is 3) - itemMax = self.widget._getItem(name='max', legend='curve0', - kind='curve', indexTable=None) - self.assertTrue(itemMax.text() == '3') + self.qapp.processEvents() + self.assertEqual(self.widget.rowCount(), 3) + curve = self.plot._getItem(kind='curve', legend='curve0') + tableItems = self.widget._itemToTableItems(curve) + self.assertEqual(tableItems['max'].text(), '3') def testSetAnotherPlot(self): plot2 = Plot1D() plot2.addCurve(x=range(26), y=range(26), legend='new curve') self.widget.setPlot(plot2) - self.assertTrue(self.widget.rowCount() is 1) + self.assertEqual(self.widget.rowCount(), 1) self.qapp.processEvents() plot2.setAttribute(qt.Qt.WA_DeleteOnClose) plot2.close() @@ -444,12 +440,15 @@ class TestStatsWidgetWithCurves(TestCaseQt): class TestStatsWidgetWithImages(TestCaseQt): """Basic test for StatsWidget with images""" + + IMAGE_LEGEND = 'test image' + def setUp(self): TestCaseQt.setUp(self) self.plot = Plot2D() self.plot.addImage(data=numpy.arange(128*128).reshape(128, 128), - legend='test image', replace=False) + legend=self.IMAGE_LEGEND, replace=False) self.widget = StatsWidget.StatsTable(plot=self.plot) @@ -476,31 +475,30 @@ class TestStatsWidgetWithImages(TestCaseQt): TestCaseQt.tearDown(self) def test(self): - columnsIndex = self.widget._columns_index - itemLegend = self.widget._lgdAndKindToItems[('test image', 'image')]['legend'] - itemMin = self.widget.item(itemLegend.row(), columnsIndex['min']) - itemMax = self.widget.item(itemLegend.row(), columnsIndex['max']) - itemDelta = self.widget.item(itemLegend.row(), columnsIndex['delta']) - itemCoordsMin = self.widget.item(itemLegend.row(), - columnsIndex['coords min']) - itemCoordsMax = self.widget.item(itemLegend.row(), - columnsIndex['coords max']) - max = (128 * 128) - 1 - self.assertTrue(itemMin.text() == '0.000') - self.assertTrue(itemMax.text() == '{0:.3f}'.format(max)) - self.assertTrue(itemDelta.text() == '{0:.3f}'.format(max)) - self.assertTrue(itemCoordsMin.text() == '0.0, 0.0') - self.assertTrue(itemCoordsMax.text() == '127.0, 127.0') + image = self.plot._getItem( + kind='image', legend=self.IMAGE_LEGEND) + tableItems = self.widget._itemToTableItems(image) + + maxText = '{0:.3f}'.format((128 * 128) - 1) + self.assertEqual(tableItems['legend'].text(), self.IMAGE_LEGEND) + self.assertEqual(tableItems['min'].text(), '0.000') + self.assertEqual(tableItems['max'].text(), maxText) + self.assertEqual(tableItems['delta'].text(), maxText) + self.assertEqual(tableItems['coords min'].text(), '0.0, 0.0') + self.assertEqual(tableItems['coords max'].text(), '127.0, 127.0') class TestStatsWidgetWithScatters(TestCaseQt): + + SCATTER_LEGEND = 'scatter plot' + def setUp(self): TestCaseQt.setUp(self) self.scatterPlot = Plot2D() self.scatterPlot.addScatter([0, 1, 2, 20, 50, 60], [2, 3, 4, 26, 69, 6], [5, 6, 7, 10, 90, 20], - legend='scatter plot') + legend=self.SCATTER_LEGEND) self.widget = StatsWidget.StatsTable(plot=self.scatterPlot) mystats = statshandler.StatsHandler(( @@ -526,33 +524,89 @@ class TestStatsWidgetWithScatters(TestCaseQt): TestCaseQt.tearDown(self) def testStats(self): - columnsIndex = self.widget._columns_index - itemLegend = self.widget._lgdAndKindToItems[('scatter plot', 'scatter')]['legend'] - itemMin = self.widget.item(itemLegend.row(), columnsIndex['min']) - itemMax = self.widget.item(itemLegend.row(), columnsIndex['max']) - itemDelta = self.widget.item(itemLegend.row(), columnsIndex['delta']) - itemCoordsMin = self.widget.item(itemLegend.row(), - columnsIndex['coords min']) - itemCoordsMax = self.widget.item(itemLegend.row(), - columnsIndex['coords max']) - self.assertTrue(itemMin.text() == '5') - self.assertTrue(itemMax.text() == '90') - self.assertTrue(itemDelta.text() == '85') - self.assertTrue(itemCoordsMin.text() == '0, 2') - self.assertTrue(itemCoordsMax.text() == '50, 69') + scatter = self.scatterPlot._getItem( + kind='scatter', legend=self.SCATTER_LEGEND) + tableItems = self.widget._itemToTableItems(scatter) + self.assertEqual(tableItems['legend'].text(), self.SCATTER_LEGEND) + self.assertEqual(tableItems['min'].text(), '5') + self.assertEqual(tableItems['coords min'].text(), '0, 2') + self.assertEqual(tableItems['max'].text(), '90') + self.assertEqual(tableItems['coords max'].text(), '50, 69') + self.assertEqual(tableItems['delta'].text(), '85') class TestEmptyStatsWidget(TestCaseQt): def test(self): widget = StatsWidget.StatsWidget() widget.show() + self.qWaitForWindowExposed(widget) + + +# skip unit test for pyqt4 because there is some unrealised widget without +# apparent reason +@unittest.skipIf(qt.qVersion().split('.')[0] == '4', reason='PyQt4 not tested') +class TestLineWidget(TestCaseQt): + """Some test for the StatsLineWidget.""" + def setUp(self): + TestCaseQt.setUp(self) + + mystats = statshandler.StatsHandler(( + (stats.StatMin(), statshandler.StatFormatter()), + )) + + self.plot = Plot1D() + self.plot.show() + x = range(20) + y = range(20) + self.plot.addCurve(x, y, legend='curve0') + y = range(12, 32) + self.plot.addCurve(x, y, legend='curve1') + y = range(-2, 18) + self.plot.addCurve(x, y, legend='curve2') + self.widget = StatsWidget.BasicGridStatsWidget(plot=self.plot, + kind='curve', + stats=mystats) + + def tearDown(self): + self.qapp.processEvents() + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + self.widget.setPlot(None) + self.widget._statQlineEdit.clear() + self.widget.setAttribute(qt.Qt.WA_DeleteOnClose) + self.widget.close() + self.widget = None + self.plot = None + TestCaseQt.tearDown(self) + + def test(self): + self.widget.setStatsOnVisibleData(False) + self.qapp.processEvents() + self.plot.setActiveCurve(legend='curve0') + self.assertTrue(self.widget._statQlineEdit['min'].text() == '0.000') + self.plot.setActiveCurve(legend='curve1') + self.assertTrue(self.widget._statQlineEdit['min'].text() == '12.000') + self.plot.getXAxis().setLimitsConstraints(minPos=2, maxPos=5) + self.widget.setStatsOnVisibleData(True) + self.qapp.processEvents() + self.assertTrue(self.widget._statQlineEdit['min'].text() == '14.000') + self.plot.setActiveCurve(None) + self.assertTrue(self.plot.getActiveCurve() is None) + self.widget.setStatsOnVisibleData(False) + self.qapp.processEvents() + self.assertFalse(self.widget._statQlineEdit['min'].text() == '14.000') + self.widget.setKind('image') + self.plot.addImage(numpy.arange(100*100).reshape(100, 100) + 0.312) + self.qapp.processEvents() + self.assertTrue(self.widget._statQlineEdit['min'].text() == '0.312') def suite(): test_suite = unittest.TestSuite() for TestClass in (TestStats, TestStatsHandler, TestStatsWidgetWithScatters, TestStatsWidgetWithImages, TestStatsWidgetWithCurves, - TestStatsFormatter, TestEmptyStatsWidget): + TestStatsFormatter, TestEmptyStatsWidget, + TestLineWidget): test_suite.addTest( unittest.defaultTestLoader.loadTestsFromTestCase(TestClass)) return test_suite diff --git a/silx/gui/plot/test/testUtilsAxis.py b/silx/gui/plot/test/testUtilsAxis.py index 016fafe..64373b8 100644 --- a/silx/gui/plot/test/testUtilsAxis.py +++ b/silx/gui/plot/test/testUtilsAxis.py @@ -26,7 +26,7 @@ __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "14/02/2018" +__date__ = "20/11/2018" import unittest @@ -155,6 +155,53 @@ class TestAxisSync(TestCaseQt): self.assertEqual(self.plot2.getYAxis().isInverted(), True) self.assertEqual(self.plot3.getYAxis().isInverted(), True) + def testSyncCenter(self): + """Test direction change""" + # Not the same scale + self.plot1.getXAxis().setLimits(0, 200) + self.plot2.getXAxis().setLimits(0, 20) + self.plot3.getXAxis().setLimits(0, 2) + _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()], + syncLimits=False, syncCenter=True) + + self.assertEqual(self.plot1.getXAxis().getLimits(), (0, 200)) + self.assertEqual(self.plot2.getXAxis().getLimits(), (100 - 10, 100 + 10)) + self.assertEqual(self.plot3.getXAxis().getLimits(), (100 - 1, 100 + 1)) + + def testSyncCenterAndZoom(self): + """Test direction change""" + # Not the same scale + self.plot1.getXAxis().setLimits(0, 200) + self.plot2.getXAxis().setLimits(0, 20) + self.plot3.getXAxis().setLimits(0, 2) + _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()], + syncLimits=False, syncCenter=True, syncZoom=True) + + # Supposing all the plots use the same size + self.assertEqual(self.plot1.getXAxis().getLimits(), (0, 200)) + self.assertEqual(self.plot2.getXAxis().getLimits(), (0, 200)) + self.assertEqual(self.plot3.getXAxis().getLimits(), (0, 200)) + + def testAddAxis(self): + """Test synchronization after construction""" + sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis()]) + sync.addAxis(self.plot3.getXAxis()) + + self.plot1.getXAxis().setLimits(10, 500) + self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500)) + self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500)) + self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500)) + + def testRemoveAxis(self): + """Test synchronization after construction""" + sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]) + sync.removeAxis(self.plot3.getXAxis()) + + self.plot1.getXAxis().setLimits(10, 500) + self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500)) + self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500)) + self.assertNotEqual(self.plot3.getXAxis().getLimits(), (10, 500)) + def suite(): test_suite = unittest.TestSuite() diff --git a/silx/gui/plot/tools/roi.py b/silx/gui/plot/tools/roi.py index d58c041..98295ba 100644 --- a/silx/gui/plot/tools/roi.py +++ b/silx/gui/plot/tools/roi.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2018 European Synchrotron Radiation Facility +# Copyright (c) 2018-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -31,6 +31,7 @@ __date__ = "28/06/2018" import collections +import enum import functools import logging import time @@ -38,7 +39,6 @@ import weakref import numpy -from ....third_party import enum from ....utils.weakref import WeakMethodProxy from ... import qt, icons from .. import PlotWidget @@ -806,11 +806,17 @@ class RegionOfInterestTableWidget(qt.QTableWidget): self.itemChanged.connect(self.__itemChanged) - @staticmethod - def __itemChanged(item): + def __itemChanged(self, item): """Handle item updates""" column = item.column() - roi = item.data(qt.Qt.UserRole) + index = item.data(qt.Qt.UserRole) + + if index is not None: + manager = self.getRegionOfInterestManager() + roi = manager.getRois()[index] + else: + roi = None + if column == 0: roi.setLabel(item.text()) elif column == 1: @@ -882,13 +888,13 @@ class RegionOfInterestTableWidget(qt.QTableWidget): label = roi.getLabel() item = qt.QTableWidgetItem(label) item.setFlags(baseFlags | qt.Qt.ItemIsEditable) - item.setData(qt.Qt.UserRole, roi) + item.setData(qt.Qt.UserRole, index) self.setItem(index, 0, item) # Editable item = qt.QTableWidgetItem() item.setFlags(baseFlags | qt.Qt.ItemIsUserCheckable) - item.setData(qt.Qt.UserRole, roi) + item.setData(qt.Qt.UserRole, index) item.setCheckState( qt.Qt.Checked if roi.isEditable() else qt.Qt.Unchecked) self.setItem(index, 1, item) diff --git a/silx/gui/plot/tools/test/testScatterProfileToolBar.py b/silx/gui/plot/tools/test/testScatterProfileToolBar.py index b99cac7..0f4b668 100644 --- a/silx/gui/plot/tools/test/testScatterProfileToolBar.py +++ b/silx/gui/plot/tools/test/testScatterProfileToolBar.py @@ -97,7 +97,7 @@ class TestScatterProfileToolBar(TestCaseQt, ParametricTestCase): self.profile._getRoiManager().addRoi(roi) # Wait for async interpolator init - for _ in range(10): + for _ in range(20): self.qWait(200) if not self.profile.hasPendingOperations(): break diff --git a/silx/gui/plot/utils/axis.py b/silx/gui/plot/utils/axis.py index bd19996..693e8eb 100644 --- a/silx/gui/plot/utils/axis.py +++ b/silx/gui/plot/utils/axis.py @@ -27,13 +27,14 @@ __authors__ = ["V. Valls"] __license__ = "MIT" -__date__ = "23/02/2018" +__date__ = "20/11/2018" import functools import logging from contextlib import contextmanager import weakref import silx.utils.weakref as silxWeakref +from silx.gui.plot.items.axis import Axis, XAxis, YAxis try: from ...qt.inspect import isValid as _isQObjectValid @@ -61,7 +62,14 @@ class SyncAxes(object): .. versionadded:: 0.6 """ - def __init__(self, axes, syncLimits=True, syncScale=True, syncDirection=True): + def __init__(self, axes, + syncLimits=True, + syncScale=True, + syncDirection=True, + syncCenter=False, + syncZoom=False, + filterHiddenPlots=False + ): """ Constructor @@ -69,17 +77,34 @@ class SyncAxes(object): :param bool syncLimits: Synchronize axes limits :param bool syncScale: Synchronize axes scale :param bool syncDirection: Synchronize axes direction + :param bool syncCenter: Synchronize the center of the axes in the center + of the plots + :param bool syncZoom: Synchronize the zoom of the plot + :param bool filterHiddenPlots: True to avoid updating hidden plots. + Default: False. """ object.__init__(self) + + def implies(x, y): return bool(y ** x) + + assert(implies(syncZoom, not syncLimits)) + assert(implies(syncCenter, not syncLimits)) + assert(implies(syncLimits, not syncCenter)) + assert(implies(syncLimits, not syncZoom)) + + self.__filterHiddenPlots = filterHiddenPlots self.__locked = False self.__axisRefs = [] self.__syncLimits = syncLimits self.__syncScale = syncScale self.__syncDirection = syncDirection + self.__syncCenter = syncCenter + self.__syncZoom = syncZoom self.__callbacks = None + self.__lastMainAxis = None for axis in axes: - self.__axisRefs.append(weakref.ref(axis)) + self.addAxis(axis) self.start() @@ -90,47 +115,131 @@ class SyncAxes(object): After that, any changes to any axes will be used to synchronize other axes. """ - if self.__callbacks is not None: + if self.isSynchronizing(): raise RuntimeError("Axes already synchronized") self.__callbacks = {} axes = self.__getAxes() - if len(axes) == 0: - raise RuntimeError('No axis to synchronize') # register callback for further sync for axis in axes: - refAxis = weakref.ref(axis) - callbacks = [] - if self.__syncLimits: - # the weakref is needed to be able ignore self references - callback = silxWeakref.WeakMethodProxy(self.__axisLimitsChanged) - callback = functools.partial(callback, refAxis) - sig = axis.sigLimitsChanged - sig.connect(callback) - callbacks.append(("sigLimitsChanged", callback)) - if self.__syncScale: - # the weakref is needed to be able ignore self references - callback = silxWeakref.WeakMethodProxy(self.__axisScaleChanged) - callback = functools.partial(callback, refAxis) - sig = axis.sigScaleChanged - sig.connect(callback) - callbacks.append(("sigScaleChanged", callback)) - if self.__syncDirection: - # the weakref is needed to be able ignore self references - callback = silxWeakref.WeakMethodProxy(self.__axisInvertedChanged) - callback = functools.partial(callback, refAxis) - sig = axis.sigInvertedChanged - sig.connect(callback) - callbacks.append(("sigInvertedChanged", callback)) - - self.__callbacks[refAxis] = callbacks + self.__connectAxes(axis) + self.synchronize() + + def isSynchronizing(self): + """Returns true if events are connected to the axes to synchronize them + all together + + :rtype: bool + """ + return self.__callbacks is not None + + def __connectAxes(self, axis): + refAxis = weakref.ref(axis) + callbacks = [] + if self.__syncLimits: + # the weakref is needed to be able ignore self references + callback = silxWeakref.WeakMethodProxy(self.__axisLimitsChanged) + callback = functools.partial(callback, refAxis) + sig = axis.sigLimitsChanged + sig.connect(callback) + callbacks.append(("sigLimitsChanged", callback)) + elif self.__syncCenter and self.__syncZoom: + # the weakref is needed to be able ignore self references + callback = silxWeakref.WeakMethodProxy(self.__axisCenterAndZoomChanged) + callback = functools.partial(callback, refAxis) + sig = axis.sigLimitsChanged + sig.connect(callback) + callbacks.append(("sigLimitsChanged", callback)) + elif self.__syncZoom: + raise NotImplementedError() + elif self.__syncCenter: + # the weakref is needed to be able ignore self references + callback = silxWeakref.WeakMethodProxy(self.__axisCenterChanged) + callback = functools.partial(callback, refAxis) + sig = axis.sigLimitsChanged + sig.connect(callback) + callbacks.append(("sigLimitsChanged", callback)) + if self.__syncScale: + # the weakref is needed to be able ignore self references + callback = silxWeakref.WeakMethodProxy(self.__axisScaleChanged) + callback = functools.partial(callback, refAxis) + sig = axis.sigScaleChanged + sig.connect(callback) + callbacks.append(("sigScaleChanged", callback)) + if self.__syncDirection: + # the weakref is needed to be able ignore self references + callback = silxWeakref.WeakMethodProxy(self.__axisInvertedChanged) + callback = functools.partial(callback, refAxis) + sig = axis.sigInvertedChanged + sig.connect(callback) + callbacks.append(("sigInvertedChanged", callback)) + + if self.__filterHiddenPlots: + # the weakref is needed to be able ignore self references + callback = silxWeakref.WeakMethodProxy(self.__axisVisibilityChanged) + callback = functools.partial(callback, refAxis) + plot = axis._getPlot() + plot.sigVisibilityChanged.connect(callback) + callbacks.append(("sigVisibilityChanged", callback)) + + self.__callbacks[refAxis] = callbacks + def __disconnectAxes(self, axis): + if axis is not None and _isQObjectValid(axis): + ref = weakref.ref(axis) + callbacks = self.__callbacks.pop(ref) + for sigName, callback in callbacks: + if sigName == "sigVisibilityChanged": + obj = axis._getPlot() + else: + obj = axis + if obj is not None: + sig = getattr(obj, sigName) + sig.disconnect(callback) + + def addAxis(self, axis): + """Add a new axes to synchronize. + + :param ~silx.gui.plot.items.Axis axis: The axis to synchronize + """ + self.__axisRefs.append(weakref.ref(axis)) + if self.isSynchronizing(): + self.__connectAxes(axis) + # This could be done faster as only this axis have to be fixed + self.synchronize() + + def removeAxis(self, axis): + """Remove an axis from the synchronized axes. + + :param ~silx.gui.plot.items.Axis axis: The axis to remove + """ + ref = weakref.ref(axis) + self.__axisRefs.remove(ref) + if self.isSynchronizing(): + self.__disconnectAxes(axis) + + def synchronize(self, mainAxis=None): + """Synchronize programatically all the axes. + + :param ~silx.gui.plot.items.Axis mainAxis: + The axis to take as reference (Default: the first axis). + """ # sync the current state - mainAxis = axes[0] + axes = self.__getAxes() + if len(axes) == 0: + return + + if mainAxis is None: + mainAxis = axes[0] + refMainAxis = weakref.ref(mainAxis) if self.__syncLimits: self.__axisLimitsChanged(refMainAxis, *mainAxis.getLimits()) + elif self.__syncCenter and self.__syncZoom: + self.__axisCenterAndZoomChanged(refMainAxis, *mainAxis.getLimits()) + elif self.__syncCenter: + self.__axisCenterChanged(refMainAxis, *mainAxis.getLimits()) if self.__syncScale: self.__axisScaleChanged(refMainAxis, mainAxis.getScale()) if self.__syncDirection: @@ -138,14 +247,11 @@ class SyncAxes(object): def stop(self): """Stop the synchronization of the axes""" - if self.__callbacks is None: + if not self.isSynchronizing(): raise RuntimeError("Axes not synchronized") - for ref, callbacks in self.__callbacks.items(): + for ref in list(self.__callbacks.keys()): axis = ref() - if axis is not None and _isQObjectValid(axis): - for sigName, callback in callbacks: - sig = getattr(axis, sigName) - sig.disconnect(callback) + self.__disconnectAxes(axis) self.__callbacks = None def __del__(self): @@ -168,32 +274,130 @@ class SyncAxes(object): yield self.__locked = False - def __otherAxes(self, changedAxis): + def __axesToUpdate(self, changedAxis): for axis in self.__getAxes(): if axis is changedAxis: continue + if self.__filterHiddenPlots: + plot = axis._getPlot() + if not plot.isVisible(): + continue yield axis + def __axisVisibilityChanged(self, changedAxis, isVisible): + if not isVisible: + return + if self.__locked: + return + changedAxis = changedAxis() + if self.__lastMainAxis is None: + self.__lastMainAxis = self.__axisRefs[0] + mainAxis = self.__lastMainAxis + mainAxis = mainAxis() + self.synchronize(mainAxis=mainAxis) + # force back the main axis + self.__lastMainAxis = weakref.ref(mainAxis) + + def __getAxesCenter(self, axis, vmin, vmax): + """Returns the value displayed in the center of this axis range. + + :rtype: float + """ + scale = axis.getScale() + if scale == Axis.LINEAR: + center = (vmin + vmax) * 0.5 + else: + raise NotImplementedError("Log scale not implemented") + return center + + def __getRangeInPixel(self, axis): + """Returns the size of the axis in pixel""" + bounds = axis._getPlot().getPlotBoundsInPixels() + # bounds: left, top, width, height + if isinstance(axis, XAxis): + return bounds[2] + elif isinstance(axis, YAxis): + return bounds[3] + else: + assert(False) + + def __getLimitsFromCenter(self, axis, pos, pixelSize=None): + """Returns the limits to apply to this axis to move the `pos` into the + center of this axis. + + :param Axis axis: + :param float pos: Position in the center of the computed limits + :param Union[None,float] pixelSize: Pixel size to apply to compute the + limits. If `None` the current pixel size is applyed. + """ + scale = axis.getScale() + if scale == Axis.LINEAR: + if pixelSize is None: + # Use the current pixel size of the axis + limits = axis.getLimits() + valueRange = limits[0] - limits[1] + a = pos - valueRange * 0.5 + b = pos + valueRange * 0.5 + else: + pixelRange = self.__getRangeInPixel(axis) + a = pos - pixelRange * 0.5 * pixelSize + b = pos + pixelRange * 0.5 * pixelSize + + else: + raise NotImplementedError("Log scale not implemented") + if a > b: + return b, a + return a, b + def __axisLimitsChanged(self, changedAxis, vmin, vmax): if self.__locked: return + self.__lastMainAxis = changedAxis + changedAxis = changedAxis() + with self.__inhibitSignals(): + for axis in self.__axesToUpdate(changedAxis): + axis.setLimits(vmin, vmax) + + def __axisCenterAndZoomChanged(self, changedAxis, vmin, vmax): + if self.__locked: + return + self.__lastMainAxis = changedAxis + changedAxis = changedAxis() + with self.__inhibitSignals(): + center = self.__getAxesCenter(changedAxis, vmin, vmax) + pixelRange = self.__getRangeInPixel(changedAxis) + if pixelRange == 0: + return + pixelSize = (vmax - vmin) / pixelRange + for axis in self.__axesToUpdate(changedAxis): + vmin, vmax = self.__getLimitsFromCenter(axis, center, pixelSize) + axis.setLimits(vmin, vmax) + + def __axisCenterChanged(self, changedAxis, vmin, vmax): + if self.__locked: + return + self.__lastMainAxis = changedAxis changedAxis = changedAxis() with self.__inhibitSignals(): - for axis in self.__otherAxes(changedAxis): + center = self.__getAxesCenter(changedAxis, vmin, vmax) + for axis in self.__axesToUpdate(changedAxis): + vmin, vmax = self.__getLimitsFromCenter(axis, center) axis.setLimits(vmin, vmax) def __axisScaleChanged(self, changedAxis, scale): if self.__locked: return + self.__lastMainAxis = changedAxis changedAxis = changedAxis() with self.__inhibitSignals(): - for axis in self.__otherAxes(changedAxis): + for axis in self.__axesToUpdate(changedAxis): axis.setScale(scale) def __axisInvertedChanged(self, changedAxis, isInverted): if self.__locked: return + self.__lastMainAxis = changedAxis changedAxis = changedAxis() with self.__inhibitSignals(): - for axis in self.__otherAxes(changedAxis): + for axis in self.__axesToUpdate(changedAxis): axis.setInverted(isInverted) |