diff options
Diffstat (limited to 'silx/io/nxdata.py')
-rw-r--r-- | silx/io/nxdata.py | 669 |
1 files changed, 586 insertions, 83 deletions
diff --git a/silx/io/nxdata.py b/silx/io/nxdata.py index 977721f..cc153b0 100644 --- a/silx/io/nxdata.py +++ b/silx/io/nxdata.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017 European Synchrotron Radiation Facility +# Copyright (c) 2017-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -29,30 +29,47 @@ See http://download.nexusformat.org/sphinx/classes/base_classes/NXdata.html """ import logging +import os +import os.path import numpy -from .utils import is_dataset, is_group +from .utils import is_dataset, is_group, is_file from silx.third_party import six +try: + import h5py +except ImportError: + h5py = None + +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "12/02/2018" + _logger = logging.getLogger(__name__) _INTERPDIM = {"scalar": 0, "spectrum": 1, "image": 2, - # "rgba-image": 3, "hsla-image": 3, "cmyk-image": 3, # TODO + "rgba-image": 3, # "hsla-image": 3, "cmyk-image": 3, # TODO "vertex": 1} # 3D scatter: 1D signal + 3 axes (x, y, z) of same legth """Number of signal dimensions associated to each possible @interpretation attribute. """ -def _nxdata_warning(msg): +def _nxdata_warning(msg, group_name=""): """Log a warning message prefixed with *"NXdata warning: "* :param str msg: Warning message + :param str group_name: Name of NXdata group this warning relates to """ - _logger.warning("NXdata warning: " + msg) + warning_prefix = "NXdata warning" + if group_name: + warning_prefix += " (group %s): " % group_name + else: + warning_prefix += ": " + _logger.warning(warning_prefix + msg) def get_attr_as_string(item, attr_name, default=None): @@ -64,16 +81,26 @@ def get_attr_as_string(item, attr_name, default=None): :param item: Group or dataset :param attr_name: Attribute name + :param default: Value to be returned if attribute is not found. :return: item.attrs[attr_name] """ attr = item.attrs.get(attr_name, default) if six.PY2: - return attr + if isinstance(attr, six.text_type): + # unicode + return attr.encode("utf-8") + else: + return attr if six.PY3: if hasattr(attr, "decode"): # byte-string return attr.decode("utf-8") - elif isinstance(attr, numpy.ndarray) and hasattr(attr[0], "decode"): + elif isinstance(attr, numpy.ndarray) and not attr.shape and\ + hasattr(attr[()], "decode"): + # byte string as ndarray scalar + return attr[()].decode("utf-8") + elif isinstance(attr, numpy.ndarray) and len(attr.shape) and\ + hasattr(attr[0], "decode"): # array of byte-strings return [element.decode("utf-8") for element in attr] else: @@ -92,7 +119,7 @@ def is_valid_nxdata(group): # noqa :param group: h5py-like group :return: True if this NXdata group is valid. - :raise: TypeError if group is not a h5py group, a spech5 group, + :raise TypeError: if group is not a h5py group, a spech5 group, or a fabioh5 group """ if not is_group(group): @@ -100,28 +127,60 @@ def is_valid_nxdata(group): # noqa if get_attr_as_string(group, "NX_class") != "NXdata": return False if "signal" not in group.attrs: - _logger.warning("NXdata group does not define a signal attr.") - return False + _logger.info("NXdata group %s does not define a signal attr. " + "Testing legacy specification.", group.name) + signal_name = None + for key in group: + if "signal" in group[key].attrs: + signal_name = key + signal_attr = group[key].attrs["signal"] + if signal_attr in [1, b"1", u"1"]: + # This is the main (default) signal + break + if signal_name is None: + _nxdata_warning("No @signal attribute on the NXdata group, " + "and no dataset with a @signal=1 attr found", + group.name) + return False + else: + signal_name = get_attr_as_string(group, "signal") - signal_name = get_attr_as_string(group, "signal") if signal_name not in group or not is_dataset(group[signal_name]): - _logger.warning( - "Cannot find signal dataset '%s' in NXdata group" % signal_name) + _nxdata_warning( + "Cannot find signal dataset '%s'" % signal_name, + group.name) return False + auxiliary_signals_names = get_attr_as_string(group, "auxiliary_signals", + default=[]) + if isinstance(auxiliary_signals_names, (six.text_type, six.binary_type)): + auxiliary_signals_names = [auxiliary_signals_names] + for asn in auxiliary_signals_names: + if asn not in group or not is_dataset(group[asn]): + _nxdata_warning( + "Cannot find auxiliary signal dataset '%s'" % asn, + group.name) + return False + if group[signal_name].shape != group[asn].shape: + _nxdata_warning("Auxiliary signal dataset '%s' does not" % asn + + " have the same shape as the main signal.", + group.name) + return False + ndim = len(group[signal_name].shape) if "axes" in group.attrs: axes_names = get_attr_as_string(group, "axes") - if isinstance(axes_names, str): + if isinstance(axes_names, (six.text_type, six.binary_type)): axes_names = [axes_names] if 1 < ndim < len(axes_names): - # ndim = 1 and several axes could be a scatter + # ndim = 1 with several axes could be a scatter _nxdata_warning( "More @axes defined than there are " + "signal dimensions: " + - "%d axes, %d dimensions." % (len(axes_names), ndim)) + "%d axes, %d dimensions." % (len(axes_names), ndim), + group.name) return False # case of less axes than dimensions: number of axes must match @@ -132,18 +191,33 @@ def is_valid_nxdata(group): # noqa interpretation = get_attr_as_string(group, "interpretation") if interpretation is None: _nxdata_warning("No @interpretation and not enough" + - " @axes defined.") + " @axes defined.", group.name) return False if interpretation not in _INTERPDIM: _nxdata_warning("Unrecognized @interpretation=" + interpretation + - " for data with wrong number of defined @axes.") + " for data with wrong number of defined @axes.", + group.name) return False + if interpretation == "rgba-image": + if ndim != 3 or group[signal_name].shape[-1] not in [3, 4]: + _nxdata_warning( + "Inconsistent RGBA Image. Expected 3 dimensions with " + + "last one of length 3 or 4. Got ndim=%d " % ndim + + "with last dimension of length %d." % group[signal_name].shape[-1], + group.name) + return False + if len(axes_names) != 2: + _nxdata_warning( + "Inconsistent number of axes for RGBA Image. Expected " + "3, but got %d." % ndim, group.name) + return False - if len(axes_names) != _INTERPDIM[interpretation]: + elif len(axes_names) != _INTERPDIM[interpretation]: _nxdata_warning( "%d-D signal with @interpretation=%s " % (ndim, interpretation) + - "must define %d or %d axes." % (ndim, _INTERPDIM[interpretation])) + "must define %d or %d axes." % (ndim, _INTERPDIM[interpretation]), + group.name) return False # Test consistency of @uncertainties @@ -155,7 +229,7 @@ def is_valid_nxdata(group): # noqa if uncertainties_names is not None: if len(uncertainties_names) != len(axes_names): _nxdata_warning("@uncertainties does not define the same " + - "number of fields than @axes") + "number of fields than @axes", group.name) return False # Test individual axes @@ -165,10 +239,12 @@ def is_valid_nxdata(group): # noqa signal_size *= dim polynomial_axes_names = [] for i, axis_name in enumerate(axes_names): + if axis_name == ".": continue if axis_name not in group or not is_dataset(group[axis_name]): - _nxdata_warning("Could not find axis dataset '%s'" % axis_name) + _nxdata_warning("Could not find axis dataset '%s'" % axis_name, + group.name) return False axis_size = 1 @@ -180,7 +256,8 @@ def is_valid_nxdata(group): # noqa # size is exactly the signal's size (weird n-d scatter) if axis_size != signal_size: _nxdata_warning("Axis %s is not a 1D dataset" % axis_name + - " and its shape does not match the signal's shape") + " and its shape does not match the signal's shape", + group.name) return False axis_len = axis_size else: @@ -195,7 +272,7 @@ def is_valid_nxdata(group): # noqa "Axis %s number of elements does not " % axis_name + "correspond to the length of any signal dimension," " it does not appear to be a constant or a linear calibration," + - " and this does not seem to be a scatter plot.") + " and this does not seem to be a scatter plot.", group.name) return False elif axis_len in (1, 2): polynomial_axes_names.append(axis_name) @@ -205,7 +282,8 @@ def is_valid_nxdata(group): # noqa _nxdata_warning( "Axis %s number of elements is equal " % axis_name + "to the length of the signal, but this does not seem" + - " to be a scatter (other axes have different sizes)") + " to be a scatter (other axes have different sizes)", + group.name) return False # Test individual uncertainties @@ -216,14 +294,15 @@ def is_valid_nxdata(group): # noqa if group[errors_name].shape != group[axis_name].shape: _nxdata_warning( "Errors '%s' does not have the same " % errors_name + - "dimensions as axis '%s'." % axis_name) + "dimensions as axis '%s'." % axis_name, group.name) return False # test dimensions of errors associated with signal if "errors" in group and is_dataset(group["errors"]): if group["errors"].shape != group[signal_name].shape: _nxdata_warning("Dataset containing standard deviations must " + - "have the same dimensions as the signal.") + "have the same dimensions as the signal.", + group.name) return False return True @@ -245,10 +324,19 @@ class NXdata(object): """h5py-like group object compliant with NeXus NXdata specification. """ - self.signal = self.group[get_attr_as_string(self.group, "signal")] - """Signal dataset in this NXdata group. + self.signal = self.group[self.signal_dataset_name] + """Main signal dataset in this NXdata group. + + In case more than one signal is present in this group, + the other ones can be found in :attr:`auxiliary_signals`. """ + self.signal_name = get_attr_as_string(self.signal, "long_name") + """Signal long name, as specified in the @long_name attribute of the + signal dataset. If not specified, the dataset name is used.""" + if self.signal_name is None: + self.signal_name = self.signal_dataset_name + # ndim will be available in very recent h5py versions only self.signal_ndim = getattr(self.signal, "ndim", len(self.signal.shape)) @@ -276,6 +364,86 @@ class NXdata(object): self.signal_is_1d = self.signal_is_1d and len(self.axes) <= 1 # excludes n-D scatters @property + def signal_dataset_name(self): + """Name of the main signal dataset.""" + signal_dataset_name = get_attr_as_string(self.group, "signal") + if signal_dataset_name is None: + # find a dataset with @signal == 1 + for dsname in self.group: + signal_attr = self.group[dsname].attrs.get("signal") + if signal_attr in [1, b"1", u"1"]: + # This is the main (default) signal + signal_dataset_name = dsname + break + assert signal_dataset_name is not None + return signal_dataset_name + + @property + def auxiliary_signals_dataset_names(self): + """Sorted list of names of the auxiliary signals datasets. + + These are the names provided by the *@auxiliary_signals* attribute + on the NXdata group. + + In case the NXdata group does not specify a *@signal* attribute + but has a dataset with an attribute *@signal=1*, + we look for datasets with attributes *@signal=2, @signal=3...* + (deprecated NXdata specification).""" + signal_dataset_name = get_attr_as_string(self.group, "signal") + if signal_dataset_name is not None: + auxiliary_signals_names = get_attr_as_string(self.group, "auxiliary_signals") + if auxiliary_signals_names is not None: + if not isinstance(auxiliary_signals_names, + (tuple, list, numpy.ndarray)): + # tolerate a single string, but coerce into a list + return [auxiliary_signals_names] + return list(auxiliary_signals_names) + return [] + + # try old spec, @signal=1 (2, 3...) on dataset + numbered_names = [] + for dsname in self.group: + if dsname == self.signal_dataset_name: + # main signal, not auxiliary + continue + ds = self.group[dsname] + signal_attr = ds.attrs.get("signal") + if signal_attr is not None and not is_dataset(ds): + _logger.warning("Item %s with @signal=%s is not a dataset (%s)", + dsname, signal_attr, type(ds)) + continue + if signal_attr is not None: + try: + signal_number = int(signal_attr) + except (ValueError, TypeError): + _logger.warning("Could not parse attr @signal=%s on " + "dataset %s as an int", + signal_attr, dsname) + continue + numbered_names.append((signal_number, dsname)) + return [a[1] for a in sorted(numbered_names)] + + @property + def auxiliary_signals_names(self): + """List of names of the auxiliary signals. + + Similar to :attr:`auxiliary_signals_dataset_names`, but the @long_name + is used when this attribute is present, instead of the dataset name. + """ + signal_names = [] + for asdn in self.auxiliary_signals_dataset_names: + if "long_name" in self.group[asdn].attrs: + signal_names.append(self.group[asdn].attrs["long_name"]) + else: + signal_names.append(asdn) + return signal_names + + @property + def auxiliary_signals(self): + """List of all auxiliary signal datasets.""" + return [self.group[dsname] for dsname in self.auxiliary_signals_dataset_names] + + @property def interpretation(self): """*@interpretation* attribute associated with the *signal* dataset of the NXdata group. ``None`` if no interpretation @@ -300,7 +468,7 @@ class NXdata(object): interpretation is returned anyway. """ allowed_interpretations = [None, "scalar", "spectrum", "image", - # "rgba-image", "hsla-image", "cmyk-image" # TODO + "rgba-image", # "hsla-image", "cmyk-image" "vertex"] interpretation = get_attr_as_string(self.signal, "interpretation") @@ -317,20 +485,19 @@ class NXdata(object): """List of the axes datasets. The list typically has as many elements as there are dimensions in the - signal dataset, the exception being scatter plots which typically - use a 1D signal and several 1D axes of the same size. + signal dataset, the exception being scatter plots which use a 1D + signal and multiple 1D axes of the same size. If an axis dataset applies to several dimensions of the signal, it will be repeated in the list. - If a dimension of the signal has no dimension scale (i.e. there is a - "." in that position in the *@axes* array), `None` is inserted in the - output list in its position. + If a dimension of the signal has no dimension scale, `None` is + inserted in its position in the list. .. note:: - In theory, the *@axes* attribute defines as many entries as there - are dimensions in the signal. In such a case, there is no ambiguity. + The *@axes* attribute should define as many entries as there + are dimensions in the signal, to avoid any ambiguity. If this is not the case, this implementation relies on the existence of an *@interpretation* (*spectrum* or *image*) attribute in the *signal* dataset. @@ -339,47 +506,20 @@ class NXdata(object): If an axis dataset defines attributes @first_good or @last_good, the output will be a numpy array resulting from slicing that - axis to keep only the good index range: axis[first_good:last_good + 1] + axis (*axis[first_good:last_good + 1]*). - :rtype: list[Dataset or 1D array or None] + :rtype: List[Dataset or 1D array or None] """ if self._axes is not None: # use cache return self._axes - ndims = len(self.signal.shape) - axes_names = get_attr_as_string(self.group, "axes") - interpretation = self.interpretation - - if axes_names is None: - self._axes = [None for _i in range(ndims)] - return self._axes - - if isinstance(axes_names, str): - axes_names = [axes_names] + axes = [] + for axis_name in self.axes_dataset_names: + if axis_name is None: + axes.append(None) + else: + axes.append(self.group[axis_name]) - if len(axes_names) == ndims: - # axes is a list of strings, one axis per dim is explicitly defined - axes = [None] * ndims - for i, axis_n in enumerate(axes_names): - if axis_n != ".": - axes[i] = self.group[axis_n] - elif interpretation is not None: - # case of @interpretation attribute defined: we expect 1, 2 or 3 axes - # corresponding to the 1, 2, or 3 last dimensions of the signal - assert len(axes_names) == _INTERPDIM[interpretation] - axes = [None] * (ndims - _INTERPDIM[interpretation]) - for axis_n in axes_names: - if axis_n != ".": - axes.append(self.group[axis_n]) - else: - axes.append(None) - else: # scatter - axes = [] - for axis_n in axes_names: - if axis_n != ".": - axes.append(self.group[axis_n]) - else: - axes.append(None) # keep only good range of axis data for i, axis in enumerate(axes): if axis is None: @@ -395,7 +535,8 @@ class NXdata(object): @property def axes_dataset_names(self): - """ + """List of axes dataset names. + If an axis dataset applies to several dimensions of the signal, its name will be repeated in the list. @@ -403,15 +544,46 @@ class NXdata(object): "." in that position in the *@axes* array), `None` is inserted in the output list in its position. """ + numbered_names = [] # used in case of @axis=0 (old spec) axes_dataset_names = get_attr_as_string(self.group, "axes") if axes_dataset_names is None: - axes_dataset_names = get_attr_as_string(self.group, "axes") + # try @axes on signal dataset (older NXdata specification) + axes_dataset_names = get_attr_as_string(self.signal, "axes") + if axes_dataset_names is not None: + # we expect a comma separated string + if hasattr(axes_dataset_names, "split"): + axes_dataset_names = axes_dataset_names.split(":") + else: + # try @axis on the individual datasets (oldest NXdata specification) + for dsname in self.group: + if not is_dataset(self.group[dsname]): + continue + axis_attr = self.group[dsname].attrs.get("axis") + if axis_attr is not None: + try: + axis_num = int(axis_attr) + except (ValueError, TypeError): + _logger.warning("Could not interpret attr @axis as" + "int on dataset %s", dsname) + continue + numbered_names.append((axis_num, dsname)) ndims = len(self.signal.shape) if axes_dataset_names is None: - return [None] * ndims + if numbered_names: + axes_dataset_names = [] + numbers = [a[0] for a in numbered_names] + names = [a[1] for a in numbered_names] + for i in range(ndims): + if i in numbers: + axes_dataset_names.append(names[numbers.index(i)]) + else: + axes_dataset_names.append(None) + return axes_dataset_names + else: + return [None] * ndims - if isinstance(axes_dataset_names, str): + if isinstance(axes_dataset_names, (six.text_type, six.binary_type)): axes_dataset_names = [axes_dataset_names] for i, axis_name in enumerate(axes_dataset_names): @@ -422,17 +594,48 @@ class NXdata(object): if len(axes_dataset_names) != ndims: if self.is_scatter and ndims == 1: + # case of a 1D signal with arbitrary number of axes return list(axes_dataset_names) - # @axes may only define 1 or 2 axes if @interpretation=spectrum/image. - # Use the existing names for the last few dims, and prepend with Nones. - assert len(axes_dataset_names) == _INTERPDIM[self.interpretation] - all_dimensions_names = [None] * (ndims - _INTERPDIM[self.interpretation]) - for axis_name in axes_dataset_names: - all_dimensions_names.append(axis_name) + if self.interpretation != "rgba-image": + # @axes may only define 1 or 2 axes if @interpretation=spectrum/image. + # Use the existing names for the last few dims, and prepend with Nones. + assert len(axes_dataset_names) == _INTERPDIM[self.interpretation] + all_dimensions_names = [None] * (ndims - _INTERPDIM[self.interpretation]) + for axis_name in axes_dataset_names: + all_dimensions_names.append(axis_name) + else: + # 2 axes applying to the first two dimensions. + # The 3rd signal dimension is expected to contain 3(4) RGB(A) values. + assert len(axes_dataset_names) == 2 + all_dimensions_names = [axn for axn in axes_dataset_names] + all_dimensions_names.append(None) return all_dimensions_names return list(axes_dataset_names) + @property + def title(self): + """Plot title. If not found, returns an empty string. + + This attribute does not appear in the NXdata specification, but it is + implemented in *nexpy* as a dataset named "title" inside the NXdata + group. This dataset is expected to contain text. + + Because the *nexpy* approach could cause a conflict if the signal + dataset or an axis dataset happened to be called "title", we also + support providing the title as an attribute of the NXdata group. + """ + title = self.group.get("title") + data_dataset_names = [self.signal_name] + self.axes_dataset_names + if (title is not None and is_dataset(title) and + "title" not in data_dataset_names): + return str(title[()]) + + title = self.group.attrs.get("title") + if title is None: + return "" + return str(title) + def get_axis_errors(self, axis_name): """Return errors (uncertainties) associated with an axis. @@ -442,7 +645,7 @@ class NXdata(object): :param str axis_name: Name of axis dataset. This dataset **must exist**. :return: Dataset with axis errors, or None - :raise: KeyError if this group does not contain a dataset named axis_name + :raise KeyError: if this group does not contain a dataset named axis_name """ # ensure axis_name is decoded, before comparing it with decoded attributes if hasattr(axis_name, "decode"): @@ -541,3 +744,303 @@ class NXdata(object): def is_unsupported_scatter(self): """True if this is a scatter with a signal and more than 2 axes.""" return self.is_scatter and len(self.axes) > 2 + + @property + def is_curve(self): + """This property is True if the signal is 1D or :attr:`interpretation` is + *"spectrum"*, and there is at most one axis with a consistent length. + """ + if self.signal_is_0d or self.interpretation not in [None, "spectrum"]: + return False + # the axis, if any, must be of the same length as the last dimension + # of the signal, or of length 2 (a + b *x scale) + if self.axes[-1] is not None and len(self.axes[-1]) not in [ + self.signal.shape[-1], 2]: + return False + if self.interpretation is None: + # We no longer test whether x values are monotonic + # (in the past, in that case, we used to consider it a scatter) + return self.signal_is_1d + # everything looks good + return True + + @property + def is_image(self): + """True if the signal is 2D, or 3D with last dimension of length 3 or 4 + and interpretation *rgba-image*, or >2D with interpretation *image*. + The axes (if any) length must also be consistent with the signal shape. + """ + if self.interpretation in ["scalar", "spectrum", "scaler"]: + return False + if self.signal_is_0d or self.signal_is_1d: + return False + if not self.signal_is_2d and \ + self.interpretation not in ["image", "rgba-image"]: + return False + if self.signal_is_3d and self.interpretation == "rgba-image": + if self.signal.shape[-1] not in [3, 4]: + return False + img_axes = self.axes[0:2] + img_shape = self.signal.shape[0:2] + else: + img_axes = self.axes[-2:] + img_shape = self.signal.shape[-2:] + for i, axis in enumerate(img_axes): + if axis is not None and len(axis) not in [img_shape[i], 2]: + return False + + return True + + @property + def is_stack(self): + """True in the signal is at least 3D and interpretation is not + "scalar", "spectrum", "image" or "rgba-image". + The axes length must also be consistent with the last 3 dimensions + of the signal. + """ + if self.signal_ndim < 3 or self.interpretation in [ + "scalar", "scaler", "spectrum", "image", "rgba-image"]: + return False + stack_shape = self.signal.shape[-3:] + for i, axis in enumerate(self.axes[-3:]): + if axis is not None and len(axis) not in [stack_shape[i], 2]: + return False + return True + + +def is_NXentry_with_default_NXdata(group): + """Return True if group is a valid NXentry defining a valid default + NXdata.""" + if not is_group(group): + return False + + if get_attr_as_string(group, "NX_class") != "NXentry": + return False + + default_nxdata_name = group.attrs.get("default") + if default_nxdata_name is None or default_nxdata_name not in group: + return False + + default_nxdata_group = group.get(default_nxdata_name) + + if not is_group(default_nxdata_group): + return False + + return is_valid_nxdata(default_nxdata_group) + + +def is_NXroot_with_default_NXdata(group): + """Return True if group is a valid NXroot defining a default NXentry + defining a valid default NXdata.""" + if not is_group(group): + return False + + # A NXroot is supposed to be at the root of a data file, and @NX_class + # is therefore optional. We accept groups that are not located at the root + # if they have @NX_class=NXroot (use case: several nexus files archived + # in a single HDF5 file) + if get_attr_as_string(group, "NX_class") != "NXroot" and not is_file(group): + return False + + default_nxentry_name = group.attrs.get("default") + if default_nxentry_name is None or default_nxentry_name not in group: + return False + + default_nxentry_group = group.get(default_nxentry_name) + return is_NXentry_with_default_NXdata(default_nxentry_group) + + +def get_default(group): + """Return a :class:`NXdata` object corresponding to the default NXdata group + in the group specified as parameter. + + This function can find the NXdata if the group is already a NXdata, or + if it is a NXentry defining a default NXdata, or if it is a NXroot + defining such a default valid NXentry. + + Return None if no valid NXdata could be found. + + :param group: h5py-like group following the Nexus specification + (NXdata, NXentry or NXroot). + :return: :class:`NXdata` object or None + :raise TypeError: if group is not a h5py-like group + """ + if not is_group(group): + raise TypeError("Provided parameter is not a h5py-like group") + + if is_NXroot_with_default_NXdata(group): + default_entry = group[group.attrs["default"]] + default_data = default_entry[default_entry.attrs["default"]] + elif is_NXentry_with_default_NXdata(group): + default_data = group[group.attrs["default"]] + elif is_valid_nxdata(group): + default_data = group + else: + return None + + return NXdata(default_data) + + +def _str_to_utf8(text): + return numpy.array(text, dtype=h5py.special_dtype(vlen=six.text_type)) + + +def save_NXdata(filename, signal, axes=None, + signal_name="data", axes_names=None, + signal_long_name=None, axes_long_names=None, + signal_errors=None, axes_errors=None, + title=None, interpretation=None, + nxentry_name="entry", nxdata_name=None): + """Write data to an NXdata group. + + .. note:: + + No consistency checks are made regarding the dimensionality of the + signal and number of axes. The user is responsible for providing + meaningful data, that can be interpreted by visualization software. + + :param str filename: Path to output file. If the file does not + exists, it is created. + :param numpy.ndarray signal: Signal array. + :param List[numpy.ndarray] axes: List of axes arrays. + :param str signal_name: Name of signal dataset, in output file + :param List[str] axes_names: List of dataset names for axes, in + output file + :param str signal_long_name: *@long_name* attribute for signal, or None. + :param axes_long_names: None, or list of long names + for axes + :type axes_long_names: List[str, None] + :param numpy.ndarray signal_errors: Array of errors associated with the + signal + :param axes_errors: List of arrays of errors + associated with each axis + :type axes_errors: List[numpy.ndarray, None] + :param str title: Graph title (saved as a "title" dataset) or None. + :param str interpretation: *@interpretation* attribute ("spectrum", + "image", "rgba-image" or None). This is only needed in cases of + ambiguous dimensionality, e.g. a 3D array which represents a RGBA + image rather than a stack. + :param str nxentry_name: Name of group in which the NXdata group + is created. By default, "/entry" is used. + + .. note:: + + The Nexus format specification requires for NXdata groups + be part of a NXentry group. + The specified group should have attribute *@NX_class=NXentry*, in + order for the created file to be nexus compliant. + :param str nxdata_name: Name of NXdata group. If omitted (None), the + function creates a new group using the first available name ("data0", + or "data1"...). + Overwriting an existing group (or dataset) is not supported, you must + delete it yourself prior to calling this function if this is what you + want. + :return: True if save was successful, else False. + """ + if h5py is None: + raise ImportError("h5py could not be imported, but is required by " + "save_NXdata function") + + if axes_names is not None: + assert axes is not None, "Axes names defined, but missing axes arrays" + assert len(axes) == len(axes_names), \ + "Mismatch between number of axes and axes_names" + + if axes is not None and axes_names is None: + axes_names = [] + for i, axis in enumerate(axes): + axes_names.append("dim%d" % i if axis is not None else ".") + if axes is None: + axes = [] + + # Open file in + if os.path.exists(filename): + errmsg = "Cannot write/append to existing path %s" + if not os.path.isfile(filename): + errmsg += " (not a file)" + _logger.error(errmsg, filename) + return False + if not os.access(filename, os.W_OK): + errmsg += " (no permission to write)" + _logger.error(errmsg, filename) + return False + mode = "r+" + else: + mode = "w-" + + with h5py.File(filename, mode=mode) as h5f: + # get or create entry + if nxentry_name is not None: + entry = h5f.require_group(nxentry_name) + if "default" not in h5f.attrs: + # set this entry as default + h5f.attrs["default"] = _str_to_utf8(nxentry_name) + if "NX_class" not in entry.attrs: + entry.attrs["NX_class"] = u"NXentry" + else: + # write NXdata into the root of the file (invalid nexus!) + entry = h5f + + # Create NXdata group + if nxdata_name is not None: + if nxdata_name in entry: + _logger.error("Cannot assign an NXdata group to an existing" + " group or dataset") + return False + else: + # no name specified, take one that is available + nxdata_name = "data0" + i = 1 + while nxdata_name in entry: + _logger.info("%s item already exists in NXentry group," + + " trying %s", nxdata_name, "data%d" % i) + nxdata_name = "data%d" % i + i += 1 + + data_group = entry.create_group(nxdata_name) + data_group.attrs["NX_class"] = u"NXdata" + data_group.attrs["signal"] = _str_to_utf8(signal_name) + if axes: + data_group.attrs["axes"] = _str_to_utf8(axes_names) + if title: + # not in NXdata spec, but implemented by nexpy + data_group["title"] = title + # better way imho + data_group.attrs["title"] = _str_to_utf8(title) + + signal_dataset = data_group.create_dataset(signal_name, + data=signal) + if signal_long_name: + signal_dataset.attrs["long_name"] = _str_to_utf8(signal_long_name) + if interpretation: + signal_dataset.attrs["interpretation"] = _str_to_utf8(interpretation) + + for i, axis_array in enumerate(axes): + if axis_array is None: + assert axes_names[i] in [".", None], \ + "Axis name defined for dim %d but no axis array" % i + continue + axis_dataset = data_group.create_dataset(axes_names[i], + data=axis_array) + if axes_long_names is not None: + axis_dataset.attrs["long_name"] = _str_to_utf8(axes_long_names[i]) + + if signal_errors is not None: + data_group.create_dataset("errors", + data=signal_errors) + + if axes_errors is not None: + assert isinstance(axes_errors, (list, tuple)), \ + "axes_errors must be a list or a tuple of ndarray or None" + assert len(axes_errors) == len(axes_names), \ + "Mismatch between number of axes_errors and axes_names" + for i, axis_errors in enumerate(axes_errors): + if axis_errors is not None: + dsname = axes_names[i] + "_errors" + data_group.create_dataset(dsname, + data=axis_errors) + if "default" not in entry.attrs: + # set this NXdata as default + entry.attrs["default"] = nxdata_name + + return True |