summaryrefslogtreecommitdiff
path: root/silx/gui/data/NumpyAxesSelector.py
diff options
context:
space:
mode:
Diffstat (limited to 'silx/gui/data/NumpyAxesSelector.py')
-rw-r--r--silx/gui/data/NumpyAxesSelector.py236
1 files changed, 168 insertions, 68 deletions
diff --git a/silx/gui/data/NumpyAxesSelector.py b/silx/gui/data/NumpyAxesSelector.py
index 4530aa9..e6da0d4 100644
--- a/silx/gui/data/NumpyAxesSelector.py
+++ b/silx/gui/data/NumpyAxesSelector.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -31,13 +31,18 @@ __authors__ = ["V. Valls"]
__license__ = "MIT"
__date__ = "29/01/2018"
+import logging
import numpy
import functools
from silx.gui.widgets.FrameBrowser import HorizontalSliderWithBrowser
from silx.gui import qt
+from silx.gui.utils import blockSignals
import silx.utils.weakref
+_logger = logging.getLogger(__name__)
+
+
class _Axis(qt.QWidget):
"""Widget displaying an axis.
@@ -110,6 +115,8 @@ class _Axis(qt.QWidget):
if axisName == "" and self.__axes.count() == 0:
self.__axes.setCurrentIndex(-1)
self.__updateSliderVisibility()
+ return
+
for index in range(self.__axes.count()):
name = self.__axes.itemData(index)
if name == axisName:
@@ -121,7 +128,7 @@ class _Axis(qt.QWidget):
def axisName(self):
"""Returns the selected axis name.
- If no names are selected, an empty string is retruned.
+ If no name is selected, an empty string is returned.
:rtype: str
"""
@@ -136,11 +143,11 @@ class _Axis(qt.QWidget):
:param List[str] axesNames: List of available names
"""
self.__axes.clear()
- previous = self.__axes.blockSignals(True)
- self.__axes.addItem(" ", "")
- for axis in axesNames:
- self.__axes.addItem(axis, axis)
- self.__axes.blockSignals(previous)
+ with blockSignals(self.__axes):
+ self.__axes.addItem(" ", "")
+ for axis in axesNames:
+ self.__axes.addItem(axis, axis)
+
self.__updateSliderVisibility()
def setCustomAxis(self, axesNames):
@@ -168,12 +175,19 @@ class _Axis(qt.QWidget):
self.__slider.setVisible(isVisible)
def value(self):
- """Returns the current selected position in the axis.
+ """Returns the currently selected position in the axis.
:rtype: int
"""
return self.__slider.value()
+ def setValue(self, value):
+ """Set the currently selected position in the axis.
+
+ :param int value:
+ """
+ self.__slider.setValue(value)
+
def __sliderValueChanged(self, value):
"""Called when the selected position in the axis change.
@@ -183,18 +197,14 @@ class _Axis(qt.QWidget):
def setNamedAxisSelectorVisibility(self, visible):
"""Hide or show the named axis combobox.
- If both the selector and the slider are hidden,
- hide the entire widget.
+
+ If both the selector and the slider are hidden, hide the entire widget.
:param visible: boolean
"""
self.__axes.setVisible(visible)
name = self.axisName()
-
- if not visible and name != "":
- self.setVisible(False)
- else:
- self.setVisible(True)
+ self.setVisible(visible or name == "")
class NumpyAxesSelector(qt.QWidget):
@@ -236,7 +246,6 @@ class NumpyAxesSelector(qt.QWidget):
self.__data = None
self.__selectedData = None
- self.__selection = tuple()
self.__axis = []
self.__axisNames = []
self.__customAxisNames = set([])
@@ -268,13 +277,12 @@ class NumpyAxesSelector(qt.QWidget):
if delta < 0:
delta = 0
for index, axis in enumerate(self.__axis):
- previous = axis.blockSignals(True)
- axis.setAxisNames(self.__axisNames)
- if index >= delta and index - delta < len(self.__axisNames):
- axis.setAxisName(self.__axisNames[index - delta])
- else:
- axis.setAxisName("")
- axis.blockSignals(previous)
+ with blockSignals(axis):
+ axis.setAxisNames(self.__axisNames)
+ if index >= delta and index - delta < len(self.__axisNames):
+ axis.setAxisName(self.__axisNames[index - delta])
+ else:
+ axis.setAxisName("")
self.__updateSelectedData()
def setCustomAxis(self, axesNames):
@@ -372,9 +380,8 @@ class NumpyAxesSelector(qt.QWidget):
# If there is no other solution we set the name at the same place
axisChanged = False
availableWidget = axis
- previous = availableWidget.blockSignals(True)
- availableWidget.setAxisName(missingName)
- availableWidget.blockSignals(previous)
+ with blockSignals(availableWidget):
+ availableWidget.setAxisName(missingName)
else:
# there is a duplicated name somewhere
# we swap it with the missing name or with nothing
@@ -387,9 +394,8 @@ class NumpyAxesSelector(qt.QWidget):
break
if missingName is None:
missingName = ""
- previous = dupWidget.blockSignals(True)
- dupWidget.setAxisName(missingName)
- dupWidget.blockSignals(previous)
+ with blockSignals(dupWidget):
+ dupWidget.setAxisName(missingName)
if self.__data is None:
return
@@ -402,70 +408,164 @@ class NumpyAxesSelector(qt.QWidget):
It fires a `selectionChanged` event.
"""
- if self.__data is None:
+ permutation = self.permutation()
+
+ if self.__data is None or permutation is None:
+ # No data or not all the expected axes are there
if self.__selectedData is not None:
self.__selectedData = None
- self.__selection = tuple()
self.selectionChanged.emit()
return
- selection = []
- axisNames = []
- for slider in self.__axis:
- name = slider.axisName()
- if name == "":
- selection.append(slider.value())
- else:
- selection.append(slice(None))
- axisNames.append(name)
- self.__selection = tuple(selection)
# get a view with few fixed dimensions
# with a h5py dataset, it create a copy
# TODO we can reuse the same memory in case of a copy
- view = self.__data[self.__selection]
-
- if set(self.__axisNames) - set(axisNames) != set([]):
- # Not all the expected axis are there
- if self.__selectedData is not None:
- self.__selectedData = None
- self.__selection = tuple()
- self.selectionChanged.emit()
- return
-
- # order axis as expected
- source = []
- destination = []
- order = []
- for index, name in enumerate(self.__axisNames):
- destination.append(index)
- source.append(axisNames.index(name))
- for _, s in sorted(zip(destination, source)):
- order.append(s)
- view = numpy.transpose(view, order)
-
- self.__selectedData = view
+ self.__selectedData = numpy.transpose(self.__data[self.selection()], permutation)
self.selectionChanged.emit()
def data(self):
"""Returns the input data.
- :rtype: numpy.ndarray
+ :rtype: Union[numpy.ndarray,None]
"""
- return self.__data
+ if self.__data is None:
+ return None
+ else:
+ return numpy.array(self.__data, copy=False)
def selectedData(self):
"""Returns the output data.
- :rtype: numpy.ndarray
+ This is equivalent to::
+
+ numpy.transpose(self.data()[self.selection()], self.permutation())
+
+ :rtype: Union[numpy.ndarray,None]
"""
- return self.__selectedData
+ if self.__selectedData is None:
+ return None
+ else:
+ return numpy.array(self.__selectedData, copy=False)
+
+ def permutation(self):
+ """Returns the axes permutation to convert data subset to selected data.
+
+ If permutation cannot be computer, it returns None.
+
+ :rtype: Union[List[int],None]
+ """
+ if self.__data is None:
+ return None
+ else:
+ indices = []
+ for name in self.__axisNames:
+ index = 0
+ for axis in self.__axis:
+ if axis.axisName() == name:
+ indices.append(index)
+ break
+ if axis.axisName() != "":
+ index += 1
+ else:
+ _logger.warning("No axis corresponding to: %s", name)
+ return None
+ return tuple(indices)
def selection(self):
"""Returns the selection tuple used to slice the data.
:rtype: tuple
"""
- return self.__selection
+ if self.__data is None:
+ return tuple()
+ else:
+ return tuple([axis.value() if axis.axisName() == "" else slice(None)
+ for axis in self.__axis])
+
+ def setSelection(self, selection, permutation=None):
+ """Set the selection along each dimension.
+
+ tuple returned by :meth:`selection` can be provided as input,
+ provided that it is for the same the number of axes and
+ the same number of dimensions of the data.
+
+ :param List[Union[int,slice,None]] selection:
+ The selection tuple with as one element for each dimension of the data.
+ If an element is None, then the whole dimension is selected.
+ :param Union[List[int],None] permutation:
+ The data axes indices to transpose.
+ If not given, no permutation is applied
+ :raise ValueError:
+ When the selection does not match current data shape and number of axes.
+ """
+ data_shape = self.__data.shape if self.__data is not None else ()
+
+ # Check selection
+ if len(selection) != len(data_shape):
+ raise ValueError(
+ "Selection length (%d) and data ndim (%d) mismatch" %
+ (len(selection), len(data_shape)))
+
+ # Check selection type
+ selectedDataNDim = 0
+ for element, size in zip(selection, data_shape):
+ if isinstance(element, int):
+ if not 0 <= element < size:
+ raise ValueError(
+ "Selected index (%d) outside data dimension range [0-%d]" %
+ (element, size))
+ elif element is None or element == slice(None):
+ selectedDataNDim += 1
+ else:
+ raise ValueError("Unsupported element in selection: %s" % element)
+
+ ndim = len(self.__axisNames)
+ if selectedDataNDim != ndim:
+ raise ValueError(
+ "Selection dimensions (%d) and number of axes (%d) mismatch" %
+ (selectedDataNDim, ndim))
+
+ # check permutation
+ if permutation is None:
+ permutation = tuple(range(ndim))
+
+ if set(permutation) != set(range(ndim)):
+ raise ValueError(
+ "Error in provided permutation: "
+ "Wrong size, elements out of range or duplicates")
+
+ inversePermutation = numpy.argsort(permutation)
+
+ axisNameChanged = False
+ customValueChanged = []
+ with blockSignals(*self.__axis):
+ index = 0
+ for element, axis in zip(selection, self.__axis):
+ if isinstance(element, int):
+ name = ""
+ else:
+ name = self.__axisNames[inversePermutation[index]]
+ index += 1
+
+ if axis.axisName() != name:
+ axis.setAxisName(name)
+ axisNameChanged = True
+
+ for element, axis in zip(selection, self.__axis):
+ value = element if isinstance(element, int) else 0
+ if axis.value() != value:
+ axis.setValue(value)
+
+ name = axis.axisName()
+ if name in self.__customAxisNames:
+ customValueChanged.append((name, value))
+
+ # Send signals that where disabled
+ if axisNameChanged:
+ self.selectedAxisChanged.emit()
+ for name, value in customValueChanged:
+ self.customAxisChanged.emit(name, value)
+ self.__updateSelectedData()
def setNamedAxesSelectorVisibility(self, visible):
"""Show or hide the combo-boxes allowing to map the plot axes