diff options
Diffstat (limited to 'silx/gui/data/NumpyAxesSelector.py')
-rw-r--r-- | silx/gui/data/NumpyAxesSelector.py | 236 |
1 files changed, 168 insertions, 68 deletions
diff --git a/silx/gui/data/NumpyAxesSelector.py b/silx/gui/data/NumpyAxesSelector.py index 4530aa9..e6da0d4 100644 --- a/silx/gui/data/NumpyAxesSelector.py +++ b/silx/gui/data/NumpyAxesSelector.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -31,13 +31,18 @@ __authors__ = ["V. Valls"] __license__ = "MIT" __date__ = "29/01/2018" +import logging import numpy import functools from silx.gui.widgets.FrameBrowser import HorizontalSliderWithBrowser from silx.gui import qt +from silx.gui.utils import blockSignals import silx.utils.weakref +_logger = logging.getLogger(__name__) + + class _Axis(qt.QWidget): """Widget displaying an axis. @@ -110,6 +115,8 @@ class _Axis(qt.QWidget): if axisName == "" and self.__axes.count() == 0: self.__axes.setCurrentIndex(-1) self.__updateSliderVisibility() + return + for index in range(self.__axes.count()): name = self.__axes.itemData(index) if name == axisName: @@ -121,7 +128,7 @@ class _Axis(qt.QWidget): def axisName(self): """Returns the selected axis name. - If no names are selected, an empty string is retruned. + If no name is selected, an empty string is returned. :rtype: str """ @@ -136,11 +143,11 @@ class _Axis(qt.QWidget): :param List[str] axesNames: List of available names """ self.__axes.clear() - previous = self.__axes.blockSignals(True) - self.__axes.addItem(" ", "") - for axis in axesNames: - self.__axes.addItem(axis, axis) - self.__axes.blockSignals(previous) + with blockSignals(self.__axes): + self.__axes.addItem(" ", "") + for axis in axesNames: + self.__axes.addItem(axis, axis) + self.__updateSliderVisibility() def setCustomAxis(self, axesNames): @@ -168,12 +175,19 @@ class _Axis(qt.QWidget): self.__slider.setVisible(isVisible) def value(self): - """Returns the current selected position in the axis. + """Returns the currently selected position in the axis. :rtype: int """ return self.__slider.value() + def setValue(self, value): + """Set the currently selected position in the axis. + + :param int value: + """ + self.__slider.setValue(value) + def __sliderValueChanged(self, value): """Called when the selected position in the axis change. @@ -183,18 +197,14 @@ class _Axis(qt.QWidget): def setNamedAxisSelectorVisibility(self, visible): """Hide or show the named axis combobox. - If both the selector and the slider are hidden, - hide the entire widget. + + If both the selector and the slider are hidden, hide the entire widget. :param visible: boolean """ self.__axes.setVisible(visible) name = self.axisName() - - if not visible and name != "": - self.setVisible(False) - else: - self.setVisible(True) + self.setVisible(visible or name == "") class NumpyAxesSelector(qt.QWidget): @@ -236,7 +246,6 @@ class NumpyAxesSelector(qt.QWidget): self.__data = None self.__selectedData = None - self.__selection = tuple() self.__axis = [] self.__axisNames = [] self.__customAxisNames = set([]) @@ -268,13 +277,12 @@ class NumpyAxesSelector(qt.QWidget): if delta < 0: delta = 0 for index, axis in enumerate(self.__axis): - previous = axis.blockSignals(True) - axis.setAxisNames(self.__axisNames) - if index >= delta and index - delta < len(self.__axisNames): - axis.setAxisName(self.__axisNames[index - delta]) - else: - axis.setAxisName("") - axis.blockSignals(previous) + with blockSignals(axis): + axis.setAxisNames(self.__axisNames) + if index >= delta and index - delta < len(self.__axisNames): + axis.setAxisName(self.__axisNames[index - delta]) + else: + axis.setAxisName("") self.__updateSelectedData() def setCustomAxis(self, axesNames): @@ -372,9 +380,8 @@ class NumpyAxesSelector(qt.QWidget): # If there is no other solution we set the name at the same place axisChanged = False availableWidget = axis - previous = availableWidget.blockSignals(True) - availableWidget.setAxisName(missingName) - availableWidget.blockSignals(previous) + with blockSignals(availableWidget): + availableWidget.setAxisName(missingName) else: # there is a duplicated name somewhere # we swap it with the missing name or with nothing @@ -387,9 +394,8 @@ class NumpyAxesSelector(qt.QWidget): break if missingName is None: missingName = "" - previous = dupWidget.blockSignals(True) - dupWidget.setAxisName(missingName) - dupWidget.blockSignals(previous) + with blockSignals(dupWidget): + dupWidget.setAxisName(missingName) if self.__data is None: return @@ -402,70 +408,164 @@ class NumpyAxesSelector(qt.QWidget): It fires a `selectionChanged` event. """ - if self.__data is None: + permutation = self.permutation() + + if self.__data is None or permutation is None: + # No data or not all the expected axes are there if self.__selectedData is not None: self.__selectedData = None - self.__selection = tuple() self.selectionChanged.emit() return - selection = [] - axisNames = [] - for slider in self.__axis: - name = slider.axisName() - if name == "": - selection.append(slider.value()) - else: - selection.append(slice(None)) - axisNames.append(name) - self.__selection = tuple(selection) # get a view with few fixed dimensions # with a h5py dataset, it create a copy # TODO we can reuse the same memory in case of a copy - view = self.__data[self.__selection] - - if set(self.__axisNames) - set(axisNames) != set([]): - # Not all the expected axis are there - if self.__selectedData is not None: - self.__selectedData = None - self.__selection = tuple() - self.selectionChanged.emit() - return - - # order axis as expected - source = [] - destination = [] - order = [] - for index, name in enumerate(self.__axisNames): - destination.append(index) - source.append(axisNames.index(name)) - for _, s in sorted(zip(destination, source)): - order.append(s) - view = numpy.transpose(view, order) - - self.__selectedData = view + self.__selectedData = numpy.transpose(self.__data[self.selection()], permutation) self.selectionChanged.emit() def data(self): """Returns the input data. - :rtype: numpy.ndarray + :rtype: Union[numpy.ndarray,None] """ - return self.__data + if self.__data is None: + return None + else: + return numpy.array(self.__data, copy=False) def selectedData(self): """Returns the output data. - :rtype: numpy.ndarray + This is equivalent to:: + + numpy.transpose(self.data()[self.selection()], self.permutation()) + + :rtype: Union[numpy.ndarray,None] """ - return self.__selectedData + if self.__selectedData is None: + return None + else: + return numpy.array(self.__selectedData, copy=False) + + def permutation(self): + """Returns the axes permutation to convert data subset to selected data. + + If permutation cannot be computer, it returns None. + + :rtype: Union[List[int],None] + """ + if self.__data is None: + return None + else: + indices = [] + for name in self.__axisNames: + index = 0 + for axis in self.__axis: + if axis.axisName() == name: + indices.append(index) + break + if axis.axisName() != "": + index += 1 + else: + _logger.warning("No axis corresponding to: %s", name) + return None + return tuple(indices) def selection(self): """Returns the selection tuple used to slice the data. :rtype: tuple """ - return self.__selection + if self.__data is None: + return tuple() + else: + return tuple([axis.value() if axis.axisName() == "" else slice(None) + for axis in self.__axis]) + + def setSelection(self, selection, permutation=None): + """Set the selection along each dimension. + + tuple returned by :meth:`selection` can be provided as input, + provided that it is for the same the number of axes and + the same number of dimensions of the data. + + :param List[Union[int,slice,None]] selection: + The selection tuple with as one element for each dimension of the data. + If an element is None, then the whole dimension is selected. + :param Union[List[int],None] permutation: + The data axes indices to transpose. + If not given, no permutation is applied + :raise ValueError: + When the selection does not match current data shape and number of axes. + """ + data_shape = self.__data.shape if self.__data is not None else () + + # Check selection + if len(selection) != len(data_shape): + raise ValueError( + "Selection length (%d) and data ndim (%d) mismatch" % + (len(selection), len(data_shape))) + + # Check selection type + selectedDataNDim = 0 + for element, size in zip(selection, data_shape): + if isinstance(element, int): + if not 0 <= element < size: + raise ValueError( + "Selected index (%d) outside data dimension range [0-%d]" % + (element, size)) + elif element is None or element == slice(None): + selectedDataNDim += 1 + else: + raise ValueError("Unsupported element in selection: %s" % element) + + ndim = len(self.__axisNames) + if selectedDataNDim != ndim: + raise ValueError( + "Selection dimensions (%d) and number of axes (%d) mismatch" % + (selectedDataNDim, ndim)) + + # check permutation + if permutation is None: + permutation = tuple(range(ndim)) + + if set(permutation) != set(range(ndim)): + raise ValueError( + "Error in provided permutation: " + "Wrong size, elements out of range or duplicates") + + inversePermutation = numpy.argsort(permutation) + + axisNameChanged = False + customValueChanged = [] + with blockSignals(*self.__axis): + index = 0 + for element, axis in zip(selection, self.__axis): + if isinstance(element, int): + name = "" + else: + name = self.__axisNames[inversePermutation[index]] + index += 1 + + if axis.axisName() != name: + axis.setAxisName(name) + axisNameChanged = True + + for element, axis in zip(selection, self.__axis): + value = element if isinstance(element, int) else 0 + if axis.value() != value: + axis.setValue(value) + + name = axis.axisName() + if name in self.__customAxisNames: + customValueChanged.append((name, value)) + + # Send signals that where disabled + if axisNameChanged: + self.selectedAxisChanged.emit() + for name, value in customValueChanged: + self.customAxisChanged.emit(name, value) + self.__updateSelectedData() def setNamedAxesSelectorVisibility(self, visible): """Show or hide the combo-boxes allowing to map the plot axes |