diff options
Diffstat (limited to 'silx/gui/plot/actions/io.py')
-rw-r--r-- | silx/gui/plot/actions/io.py | 133 |
1 files changed, 118 insertions, 15 deletions
diff --git a/silx/gui/plot/actions/io.py b/silx/gui/plot/actions/io.py index 50410e3..d6d5909 100644 --- a/silx/gui/plot/actions/io.py +++ b/silx/gui/plot/actions/io.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# Copyright (c) 2004-2018 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -37,10 +37,11 @@ from __future__ import division __authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"] __license__ = "MIT" -__date__ = "27/06/2017" +__date__ = "02/02/2018" from . import PlotAction from silx.io.utils import save1D, savespec +from silx.io.nxdata import save_NXdata import logging import sys from collections import OrderedDict @@ -59,6 +60,10 @@ else: _logger = logging.getLogger(__name__) +_NEXUS_HDF5_EXT = [".nx5", ".nxs", ".hdf", ".hdf5", ".cxi", ".h5"] +_NEXUS_HDF5_EXT_STR = ' '.join(['*' + ext for ext in _NEXUS_HDF5_EXT]) + + class SaveAction(PlotAction): """QAction for saving Plot content. @@ -89,12 +94,15 @@ class SaveAction(PlotAction): ('Curve as OMNIC CSV (*.csv)', {'fmt': '%.7E', 'delimiter': ',', 'header': False}), ('Curve as SpecFile (*.dat)', - {'fmt': '%.7g', 'delimiter': '', 'header': False}) + {'fmt': '%.10g', 'delimiter': '', 'header': False}) )) CURVE_FILTER_NPY = 'Curve as NumPy binary file (*.npy)' - CURVE_FILTERS = list(CURVE_FILTERS_TXT.keys()) + [CURVE_FILTER_NPY] + CURVE_FILTER_NXDATA = 'Curve as NXdata (%s)' % _NEXUS_HDF5_EXT_STR + + CURVE_FILTERS = list(CURVE_FILTERS_TXT.keys()) + [CURVE_FILTER_NPY, + CURVE_FILTER_NXDATA] ALL_CURVES_FILTERS = ("All curves as SpecFile (*.dat)", ) @@ -107,6 +115,7 @@ class SaveAction(PlotAction): IMAGE_FILTER_CSV_TAB = 'Image data as tab-separated CSV (*.csv)' IMAGE_FILTER_RGB_PNG = 'Image as PNG (*.png)' IMAGE_FILTER_RGB_TIFF = 'Image as TIFF (*.tif)' + IMAGE_FILTER_NXDATA = 'Image as NXdata (%s)' % _NEXUS_HDF5_EXT_STR IMAGE_FILTERS = (IMAGE_FILTER_EDF, IMAGE_FILTER_TIFF, IMAGE_FILTER_NUMPY, @@ -115,7 +124,11 @@ class SaveAction(PlotAction): IMAGE_FILTER_CSV_SEMICOLON, IMAGE_FILTER_CSV_TAB, IMAGE_FILTER_RGB_PNG, - IMAGE_FILTER_RGB_TIFF) + IMAGE_FILTER_RGB_TIFF, + IMAGE_FILTER_NXDATA) + + SCATTER_FILTER_NXDATA = 'Scatter as NXdata (%s)' % _NEXUS_HDF5_EXT_STR + SCATTER_FILTERS = (SCATTER_FILTER_NXDATA, ) def __init__(self, plot, parent=None): super(SaveAction, self).__init__( @@ -183,7 +196,7 @@ class SaveAction(PlotAction): csvdelim = filter_['delimiter'] autoheader = filter_['header'] else: - # .npy + # .npy or nxdata fmt, csvdelim, autoheader = ("", "", False) # If curve has no associated label, get the default from the plot @@ -194,6 +207,19 @@ class SaveAction(PlotAction): if ylabel is None: ylabel = self.plot.getYAxis().getLabel() + if nameFilter == self.CURVE_FILTER_NXDATA: + return save_NXdata( + filename, + signal=curve.getYData(copy=False), + axes=[curve.getXData(copy=False)], + signal_name="y", + axes_names=["x"], + signal_long_name=ylabel, + axes_long_names=[xlabel], + signal_errors=curve.getYErrorData(copy=False), + axes_errors=[curve.getXErrorData(copy=True)], + title=self.plot.getGraphTitle()) + try: save1D(filename, curve.getXData(copy=False), @@ -226,11 +252,13 @@ class SaveAction(PlotAction): curve = curves[0] scanno = 1 try: + xlabel = curve.getXLabel() or self.plot.getGraphXLabel() + ylabel = curve.getYLabel() or self.plot.getGraphYLabel(curve.getYAxis()) specfile = savespec(filename, curve.getXData(copy=False), curve.getYData(copy=False), - curve.getXLabel(), - curve.getYLabel(), + xlabel, + ylabel, fmt="%.7g", scan_number=1, mode="w", write_file_header=True, close_file=False) @@ -241,12 +269,14 @@ class SaveAction(PlotAction): for curve in curves[1:]: try: scanno += 1 + xlabel = curve.getXLabel() or self.plot.getGraphXLabel() + ylabel = curve.getYLabel() or self.plot.getGraphYLabel(curve.getYAxis()) specfile = savespec(specfile, curve.getXData(copy=False), curve.getYData(copy=False), - curve.getXLabel(), - curve.getYLabel(), - fmt="%.7g", scan_number=scanno, mode="w", + xlabel, + ylabel, + fmt="%.7g", scan_number=scanno, write_file_header=False, close_file=False) except IOError: @@ -294,6 +324,24 @@ class SaveAction(PlotAction): return False return True + elif nameFilter == self.IMAGE_FILTER_NXDATA: + xorigin, yorigin = image.getOrigin() + xscale, yscale = image.getScale() + xaxis = xorigin + xscale * numpy.arange(data.shape[1]) + yaxis = yorigin + yscale * numpy.arange(data.shape[0]) + xlabel = image.getXLabel() or self.plot.getGraphXLabel() + ylabel = image.getYLabel() or self.plot.getGraphYLabel() + interpretation = "image" if len(data.shape) == 2 else "rgba-image" + + return save_NXdata(filename, + signal=data, + axes=[yaxis, xaxis], + signal_name="image", + axes_names=["y", "x"], + axes_long_names=[ylabel, xlabel], + title=self.plot.getGraphTitle(), + interpretation=interpretation) + elif nameFilter in (self.IMAGE_FILTER_ASCII, self.IMAGE_FILTER_CSV_COMMA, self.IMAGE_FILTER_CSV_SEMICOLON, @@ -343,6 +391,45 @@ class SaveAction(PlotAction): return False + def _saveScatter(self, filename, nameFilter): + """Save an image from the plot. + + :param str filename: The name of the file to write + :param str nameFilter: The selected name filter + :return: False if format is not supported or save failed, + True otherwise. + """ + if nameFilter not in self.SCATTER_FILTERS: + return False + + if nameFilter == self.SCATTER_FILTER_NXDATA: + scatter = self.plot.getScatter() + # TODO: we could get all scatters on this plot and concatenate their (x, y, values) + x = scatter.getXData(copy=False) + y = scatter.getYData(copy=False) + z = scatter.getValueData(copy=False) + + xerror = scatter.getXErrorData(copy=False) + if isinstance(xerror, float): + xerror = xerror * numpy.ones(x.shape, dtype=numpy.float32) + + yerror = scatter.getYErrorData(copy=False) + if isinstance(yerror, float): + yerror = yerror * numpy.ones(x.shape, dtype=numpy.float32) + + xlabel = self.plot.getGraphXLabel() + ylabel = self.plot.getGraphYLabel() + + return save_NXdata( + filename, + signal=z, + axes=[x, y], + signal_name="values", + axes_names=["x", "y"], + axes_long_names=[xlabel, ylabel], + axes_errors=[xerror, yerror], + title=self.plot.getGraphTitle()) + def _actionTriggered(self, checked=False): """Handle save action.""" # Set-up filters @@ -359,6 +446,11 @@ class SaveAction(PlotAction): if len(self.plot.getAllCurves()) > 1: filters.extend(self.ALL_CURVES_FILTERS) + # Add scatter filters if there is a scatter + # todo: CSV + if self.plot.getScatter() is not None: + filters.extend(self.SCATTER_FILTERS) + filters.extend(self.SNAPSHOT_FILTERS) # Create and run File dialog @@ -378,10 +470,19 @@ class SaveAction(PlotAction): dialog.close() # Forces the filename extension to match the chosen filter - extension = nameFilter.split()[-1][2:-1] - if (len(filename) <= len(extension) or - filename[-len(extension):].lower() != extension.lower()): - filename += extension + if "NXdata" in nameFilter: + has_allowed_ext = False + for ext in _NEXUS_HDF5_EXT: + if (len(filename) > len(ext) and + filename[-len(ext):].lower() == ext.lower()): + has_allowed_ext = True + if not has_allowed_ext: + filename += ".h5" + else: + default_extension = nameFilter.split()[-1][2:-1] + if (len(filename) <= len(default_extension) or + filename[-len(default_extension):].lower() != default_extension.lower()): + filename += default_extension # Handle save if nameFilter in self.SNAPSHOT_FILTERS: @@ -392,6 +493,8 @@ class SaveAction(PlotAction): return self._saveCurves(filename, nameFilter) elif nameFilter in self.IMAGE_FILTERS: return self._saveImage(filename, nameFilter) + elif nameFilter in self.SCATTER_FILTERS: + return self._saveScatter(filename, nameFilter) else: _logger.warning('Unsupported file filter: %s', nameFilter) return False |