diff options
Diffstat (limited to 'silx/gui/plot/actions/io.py')
-rw-r--r-- | silx/gui/plot/actions/io.py | 344 |
1 files changed, 224 insertions, 120 deletions
diff --git a/silx/gui/plot/actions/io.py b/silx/gui/plot/actions/io.py index d6d5909..ac06942 100644 --- a/silx/gui/plot/actions/io.py +++ b/silx/gui/plot/actions/io.py @@ -44,13 +44,16 @@ from silx.io.utils import save1D, savespec from silx.io.nxdata import save_NXdata import logging import sys +import os.path from collections import OrderedDict import traceback import numpy -from silx.gui import qt +from silx.utils.deprecation import deprecated +from silx.gui import qt, printer +from silx.gui.dialog.GroupDialog import GroupDialog from silx.third_party.EdfFile import EdfFile from silx.third_party.TiffIO import TiffIO -from silx.gui._utils import convertArrayToQImage +from ...utils._image import convertArrayToQImage if sys.version_info[0] == 3: from io import BytesIO else: @@ -60,10 +63,26 @@ else: _logger = logging.getLogger(__name__) -_NEXUS_HDF5_EXT = [".nx5", ".nxs", ".hdf", ".hdf5", ".cxi", ".h5"] +_NEXUS_HDF5_EXT = [".h5", ".nx5", ".nxs", ".hdf", ".hdf5", ".cxi"] _NEXUS_HDF5_EXT_STR = ' '.join(['*' + ext for ext in _NEXUS_HDF5_EXT]) +def selectOutputGroup(h5filename): + """Open a dialog to prompt the user to select a group in + which to output data. + + :param str h5filename: name of an existing HDF5 file + :rtype: str + :return: Name of output group, or None if the dialog was cancelled + """ + dialog = GroupDialog() + dialog.addFile(h5filename) + dialog.setWindowTitle("Select an output group") + if not dialog.exec_(): + return None + return dialog.getSelectedDataUrl().data_path() + + class SaveAction(PlotAction): """QAction for saving Plot content. @@ -72,12 +91,11 @@ class SaveAction(PlotAction): :param plot: :class:`.PlotWidget` instance on which to operate. :param parent: See :class:`QAction`. """ - # TODO find a way to make the filter list selectable and extensible SNAPSHOT_FILTER_SVG = 'Plot Snapshot as SVG (*.svg)' SNAPSHOT_FILTER_PNG = 'Plot Snapshot as PNG (*.png)' - SNAPSHOT_FILTERS = (SNAPSHOT_FILTER_PNG, SNAPSHOT_FILTER_SVG) + DEFAULT_ALL_FILTERS = (SNAPSHOT_FILTER_PNG, SNAPSHOT_FILTER_SVG) # Dict of curve filters with CSV-like format # Using ordered dict to guarantee filters order @@ -101,10 +119,10 @@ class SaveAction(PlotAction): CURVE_FILTER_NXDATA = 'Curve as NXdata (%s)' % _NEXUS_HDF5_EXT_STR - CURVE_FILTERS = list(CURVE_FILTERS_TXT.keys()) + [CURVE_FILTER_NPY, - CURVE_FILTER_NXDATA] + DEFAULT_CURVE_FILTERS = list(CURVE_FILTERS_TXT.keys()) + [ + CURVE_FILTER_NPY, CURVE_FILTER_NXDATA] - ALL_CURVES_FILTERS = ("All curves as SpecFile (*.dat)", ) + DEFAULT_ALL_CURVES_FILTERS = ("All curves as SpecFile (*.dat)",) IMAGE_FILTER_EDF = 'Image data as EDF (*.edf)' IMAGE_FILTER_TIFF = 'Image data as TIFF (*.tif)' @@ -114,23 +132,53 @@ class SaveAction(PlotAction): IMAGE_FILTER_CSV_SEMICOLON = 'Image data as ;-separated CSV (*.csv)' IMAGE_FILTER_CSV_TAB = 'Image data as tab-separated CSV (*.csv)' IMAGE_FILTER_RGB_PNG = 'Image as PNG (*.png)' - IMAGE_FILTER_RGB_TIFF = 'Image as TIFF (*.tif)' IMAGE_FILTER_NXDATA = 'Image as NXdata (%s)' % _NEXUS_HDF5_EXT_STR - IMAGE_FILTERS = (IMAGE_FILTER_EDF, - IMAGE_FILTER_TIFF, - IMAGE_FILTER_NUMPY, - IMAGE_FILTER_ASCII, - IMAGE_FILTER_CSV_COMMA, - IMAGE_FILTER_CSV_SEMICOLON, - IMAGE_FILTER_CSV_TAB, - IMAGE_FILTER_RGB_PNG, - IMAGE_FILTER_RGB_TIFF, - IMAGE_FILTER_NXDATA) + DEFAULT_IMAGE_FILTERS = (IMAGE_FILTER_EDF, + IMAGE_FILTER_TIFF, + IMAGE_FILTER_NUMPY, + IMAGE_FILTER_ASCII, + IMAGE_FILTER_CSV_COMMA, + IMAGE_FILTER_CSV_SEMICOLON, + IMAGE_FILTER_CSV_TAB, + IMAGE_FILTER_RGB_PNG, + IMAGE_FILTER_NXDATA) SCATTER_FILTER_NXDATA = 'Scatter as NXdata (%s)' % _NEXUS_HDF5_EXT_STR - SCATTER_FILTERS = (SCATTER_FILTER_NXDATA, ) + DEFAULT_SCATTER_FILTERS = (SCATTER_FILTER_NXDATA,) + + # filters for which we don't want an "overwrite existing file" warning + DEFAULT_APPEND_FILTERS = (CURVE_FILTER_NXDATA, IMAGE_FILTER_NXDATA, + SCATTER_FILTER_NXDATA) def __init__(self, plot, parent=None): + self._filters = { + 'all': OrderedDict(), + 'curve': OrderedDict(), + 'curves': OrderedDict(), + 'image': OrderedDict(), + 'scatter': OrderedDict()} + + # Initialize filters + for nameFilter in self.DEFAULT_ALL_FILTERS: + self.setFileFilter( + dataKind='all', nameFilter=nameFilter, func=self._saveSnapshot) + + for nameFilter in self.DEFAULT_CURVE_FILTERS: + self.setFileFilter( + dataKind='curve', nameFilter=nameFilter, func=self._saveCurve) + + for nameFilter in self.DEFAULT_ALL_CURVES_FILTERS: + self.setFileFilter( + dataKind='curves', nameFilter=nameFilter, func=self._saveCurves) + + for nameFilter in self.DEFAULT_IMAGE_FILTERS: + self.setFileFilter( + dataKind='image', nameFilter=nameFilter, func=self._saveImage) + + for nameFilter in self.DEFAULT_SCATTER_FILTERS: + self.setFileFilter( + dataKind='scatter', nameFilter=nameFilter, func=self._saveScatter) + super(SaveAction, self).__init__( plot, icon='document-save', text='Save as...', tooltip='Save curve/image/plot snapshot dialog', @@ -148,7 +196,7 @@ class SaveAction(PlotAction): msg.setDetailedText(traceback.format_exc()) msg.exec_() - def _saveSnapshot(self, filename, nameFilter): + def _saveSnapshot(self, plot, filename, nameFilter): """Save a snapshot of the :class:`PlotWindow` widget. :param str filename: The name of the file to write @@ -165,10 +213,51 @@ class SaveAction(PlotAction): 'Saving plot snapshot failed: format not supported') return False - self.plot.saveGraph(filename, fileFormat=fileFormat) + plot.saveGraph(filename, fileFormat=fileFormat) return True - def _saveCurve(self, filename, nameFilter): + def _getAxesLabels(self, item): + # If curve has no associated label, get the default from the plot + xlabel = item.getXLabel() or self.plot.getXAxis().getLabel() + ylabel = item.getYLabel() or self.plot.getYAxis().getLabel() + return xlabel, ylabel + + def _selectWriteableOutputGroup(self, filename): + if os.path.exists(filename) and os.path.isfile(filename) \ + and os.access(filename, os.W_OK): + entryPath = selectOutputGroup(filename) + if entryPath is None: + _logger.info("Save operation cancelled") + return None + return entryPath + elif not os.path.exists(filename): + # create new entry in new file + return "/entry" + else: + self._errorMessage('Save failed (file access issue)\n') + return None + + def _saveCurveAsNXdata(self, curve, filename): + entryPath = self._selectWriteableOutputGroup(filename) + if entryPath is None: + return False + + xlabel, ylabel = self._getAxesLabels(curve) + + return save_NXdata( + filename, + nxentry_name=entryPath, + signal=curve.getYData(copy=False), + axes=[curve.getXData(copy=False)], + signal_name="y", + axes_names=["x"], + signal_long_name=ylabel, + axes_long_names=[xlabel], + signal_errors=curve.getYErrorData(copy=False), + axes_errors=[curve.getXErrorData(copy=True)], + title=self.plot.getGraphTitle()) + + def _saveCurve(self, plot, filename, nameFilter): """Save a curve from the plot. :param str filename: The name of the file to write @@ -176,15 +265,15 @@ class SaveAction(PlotAction): :return: False if format is not supported or save failed, True otherwise. """ - if nameFilter not in self.CURVE_FILTERS: + if nameFilter not in self.DEFAULT_CURVE_FILTERS: return False # Check if a curve is to be saved - curve = self.plot.getActiveCurve() + curve = plot.getActiveCurve() # before calling _saveCurve, if there is no selected curve, we # make sure there is only one curve on the graph if curve is None: - curves = self.plot.getAllCurves() + curves = plot.getAllCurves() if not curves: self._errorMessage("No curve to be saved") return False @@ -199,26 +288,10 @@ class SaveAction(PlotAction): # .npy or nxdata fmt, csvdelim, autoheader = ("", "", False) - # If curve has no associated label, get the default from the plot - xlabel = curve.getXLabel() - if xlabel is None: - xlabel = self.plot.getXAxis().getLabel() - ylabel = curve.getYLabel() - if ylabel is None: - ylabel = self.plot.getYAxis().getLabel() + xlabel, ylabel = self._getAxesLabels(curve) if nameFilter == self.CURVE_FILTER_NXDATA: - return save_NXdata( - filename, - signal=curve.getYData(copy=False), - axes=[curve.getXData(copy=False)], - signal_name="y", - axes_names=["x"], - signal_long_name=ylabel, - axes_long_names=[xlabel], - signal_errors=curve.getYErrorData(copy=False), - axes_errors=[curve.getXErrorData(copy=True)], - title=self.plot.getGraphTitle()) + return self._saveCurveAsNXdata(curve, filename) try: save1D(filename, @@ -233,7 +306,7 @@ class SaveAction(PlotAction): return True - def _saveCurves(self, filename, nameFilter): + def _saveCurves(self, plot, filename, nameFilter): """Save all curves from the plot. :param str filename: The name of the file to write @@ -241,10 +314,10 @@ class SaveAction(PlotAction): :return: False if format is not supported or save failed, True otherwise. """ - if nameFilter not in self.ALL_CURVES_FILTERS: + if nameFilter not in self.DEFAULT_ALL_CURVES_FILTERS: return False - curves = self.plot.getAllCurves() + curves = plot.getAllCurves() if not curves: self._errorMessage("No curves to be saved") return False @@ -252,8 +325,8 @@ class SaveAction(PlotAction): curve = curves[0] scanno = 1 try: - xlabel = curve.getXLabel() or self.plot.getGraphXLabel() - ylabel = curve.getYLabel() or self.plot.getGraphYLabel(curve.getYAxis()) + xlabel = curve.getXLabel() or plot.getGraphXLabel() + ylabel = curve.getYLabel() or plot.getGraphYLabel(curve.getYAxis()) specfile = savespec(filename, curve.getXData(copy=False), curve.getYData(copy=False), @@ -269,8 +342,8 @@ class SaveAction(PlotAction): for curve in curves[1:]: try: scanno += 1 - xlabel = curve.getXLabel() or self.plot.getGraphXLabel() - ylabel = curve.getYLabel() or self.plot.getGraphYLabel(curve.getYAxis()) + xlabel = curve.getXLabel() or plot.getGraphXLabel() + ylabel = curve.getYLabel() or plot.getGraphYLabel(curve.getYAxis()) specfile = savespec(specfile, curve.getXData(copy=False), curve.getYData(copy=False), @@ -286,7 +359,7 @@ class SaveAction(PlotAction): return True - def _saveImage(self, filename, nameFilter): + def _saveImage(self, plot, filename, nameFilter): """Save an image from the plot. :param str filename: The name of the file to write @@ -294,13 +367,13 @@ class SaveAction(PlotAction): :return: False if format is not supported or save failed, True otherwise. """ - if nameFilter not in self.IMAGE_FILTERS: + if nameFilter not in self.DEFAULT_IMAGE_FILTERS: return False - image = self.plot.getActiveImage() + image = plot.getActiveImage() if image is None: qt.QMessageBox.warning( - self.plot, "No Data", "No image to be saved") + plot, "No Data", "No image to be saved") return False data = image.getData(copy=False) @@ -325,21 +398,24 @@ class SaveAction(PlotAction): return True elif nameFilter == self.IMAGE_FILTER_NXDATA: + entryPath = self._selectWriteableOutputGroup(filename) + if entryPath is None: + return False xorigin, yorigin = image.getOrigin() xscale, yscale = image.getScale() xaxis = xorigin + xscale * numpy.arange(data.shape[1]) yaxis = yorigin + yscale * numpy.arange(data.shape[0]) - xlabel = image.getXLabel() or self.plot.getGraphXLabel() - ylabel = image.getYLabel() or self.plot.getGraphYLabel() + xlabel, ylabel = self._getAxesLabels(image) interpretation = "image" if len(data.shape) == 2 else "rgba-image" return save_NXdata(filename, + nxentry_name=entryPath, signal=data, axes=[yaxis, xaxis], signal_name="image", axes_names=["y", "x"], axes_long_names=[ylabel, xlabel], - title=self.plot.getGraphTitle(), + title=plot.getGraphTitle(), interpretation=interpretation) elif nameFilter in (self.IMAGE_FILTER_ASCII, @@ -368,19 +444,13 @@ class SaveAction(PlotAction): return False return True - elif nameFilter in (self.IMAGE_FILTER_RGB_PNG, - self.IMAGE_FILTER_RGB_TIFF): + elif nameFilter == self.IMAGE_FILTER_RGB_PNG: # Get displayed image rgbaImage = image.getRgbaImageData(copy=False) # Convert RGB QImage qimage = convertArrayToQImage(rgbaImage[:, :, :3]) - if nameFilter == self.IMAGE_FILTER_RGB_PNG: - fileFormat = 'PNG' - else: - fileFormat = 'TIFF' - - if qimage.save(filename, fileFormat): + if qimage.save(filename, 'PNG'): return True else: _logger.error('Failed to save image as %s', filename) @@ -391,7 +461,7 @@ class SaveAction(PlotAction): return False - def _saveScatter(self, filename, nameFilter): + def _saveScatter(self, plot, filename, nameFilter): """Save an image from the plot. :param str filename: The name of the file to write @@ -399,12 +469,15 @@ class SaveAction(PlotAction): :return: False if format is not supported or save failed, True otherwise. """ - if nameFilter not in self.SCATTER_FILTERS: + if nameFilter not in self.DEFAULT_SCATTER_FILTERS: return False if nameFilter == self.SCATTER_FILTER_NXDATA: - scatter = self.plot.getScatter() - # TODO: we could get all scatters on this plot and concatenate their (x, y, values) + entryPath = self._selectWriteableOutputGroup(filename) + if entryPath is None: + return False + scatter = plot.getScatter() + x = scatter.getXData(copy=False) y = scatter.getYData(copy=False) z = scatter.getValueData(copy=False) @@ -417,51 +490,92 @@ class SaveAction(PlotAction): if isinstance(yerror, float): yerror = yerror * numpy.ones(x.shape, dtype=numpy.float32) - xlabel = self.plot.getGraphXLabel() - ylabel = self.plot.getGraphYLabel() + xlabel = plot.getGraphXLabel() + ylabel = plot.getGraphYLabel() return save_NXdata( filename, + nxentry_name=entryPath, signal=z, axes=[x, y], signal_name="values", axes_names=["x", "y"], axes_long_names=[xlabel, ylabel], axes_errors=[xerror, yerror], - title=self.plot.getGraphTitle()) + title=plot.getGraphTitle()) + + def setFileFilter(self, dataKind, nameFilter, func): + """Set a name filter to add/replace a file format support + + :param str dataKind: + The kind of data for which the provided filter is valid. + One of: 'all', 'curve', 'curves', 'image', 'scatter' + :param str nameFilter: The name filter in the QFileDialog. + See :meth:`QFileDialog.setNameFilters`. + :param callable func: The function to call to perform saving. + Expected signature is: + bool func(PlotWidget plot, str filename, str nameFilter) + """ + assert dataKind in ('all', 'curve', 'curves', 'image', 'scatter') + + self._filters[dataKind][nameFilter] = func + + def getFileFilters(self, dataKind): + """Returns the nameFilter and associated function for a kind of data. + + :param str dataKind: + The kind of data for which the provided filter is valid. + On of: 'all', 'curve', 'curves', 'image', 'scatter' + :return: {nameFilter: function} associations. + :rtype: collections.OrderedDict + """ + assert dataKind in ('all', 'curve', 'curves', 'image', 'scatter') + + return self._filters[dataKind].copy() def _actionTriggered(self, checked=False): """Handle save action.""" # Set-up filters - filters = [] + filters = OrderedDict() # Add image filters if there is an active image if self.plot.getActiveImage() is not None: - filters.extend(self.IMAGE_FILTERS) + filters.update(self._filters['image'].items()) # Add curve filters if there is a curve to save if (self.plot.getActiveCurve() is not None or len(self.plot.getAllCurves()) == 1): - filters.extend(self.CURVE_FILTERS) + filters.update(self._filters['curve'].items()) if len(self.plot.getAllCurves()) > 1: - filters.extend(self.ALL_CURVES_FILTERS) + filters.update(self._filters['curves'].items()) # Add scatter filters if there is a scatter # todo: CSV if self.plot.getScatter() is not None: - filters.extend(self.SCATTER_FILTERS) + filters.update(self._filters['scatter'].items()) - filters.extend(self.SNAPSHOT_FILTERS) + filters.update(self._filters['all'].items()) # Create and run File dialog dialog = qt.QFileDialog(self.plot) + dialog.setOption(dialog.DontUseNativeDialog) dialog.setWindowTitle("Output File Selection") dialog.setModal(1) - dialog.setNameFilters(filters) + dialog.setNameFilters(list(filters.keys())) dialog.setFileMode(dialog.AnyFile) dialog.setAcceptMode(dialog.AcceptSave) + def onFilterSelection(filt_): + # disable overwrite confirmation for NXdata types, + # because we append the data to existing files + if filt_ in self.DEFAULT_APPEND_FILTERS: + dialog.setOption(dialog.DontConfirmOverwrite) + else: + dialog.setOption(dialog.DontConfirmOverwrite, False) + + dialog.filterSelected.connect(onFilterSelection) + if not dialog.exec_(): return False @@ -469,34 +583,25 @@ class SaveAction(PlotAction): filename = dialog.selectedFiles()[0] dialog.close() - # Forces the filename extension to match the chosen filter - if "NXdata" in nameFilter: - has_allowed_ext = False - for ext in _NEXUS_HDF5_EXT: + if '(' in nameFilter and ')' == nameFilter.strip()[-1]: + # Check for correct file extension + # Extract file extensions as .something + extensions = [ext[ext.find('.'):] for ext in + nameFilter[nameFilter.find('(')+1:-1].split()] + for ext in extensions: if (len(filename) > len(ext) and filename[-len(ext):].lower() == ext.lower()): - has_allowed_ext = True - if not has_allowed_ext: - filename += ".h5" - else: - default_extension = nameFilter.split()[-1][2:-1] - if (len(filename) <= len(default_extension) or - filename[-len(default_extension):].lower() != default_extension.lower()): - filename += default_extension + break + else: # filename has no extension supported in nameFilter, add one + if len(extensions) >= 1: + filename += extensions[0] # Handle save - if nameFilter in self.SNAPSHOT_FILTERS: - return self._saveSnapshot(filename, nameFilter) - elif nameFilter in self.CURVE_FILTERS: - return self._saveCurve(filename, nameFilter) - elif nameFilter in self.ALL_CURVES_FILTERS: - return self._saveCurves(filename, nameFilter) - elif nameFilter in self.IMAGE_FILTERS: - return self._saveImage(filename, nameFilter) - elif nameFilter in self.SCATTER_FILTERS: - return self._saveScatter(filename, nameFilter) + func = filters.get(nameFilter, None) + if func is not None: + return func(self.plot, filename, nameFilter) else: - _logger.warning('Unsupported file filter: %s', nameFilter) + _logger.error('Unsupported file filter: %s', nameFilter) return False @@ -526,9 +631,6 @@ class PrintAction(PlotAction): :param parent: See :class:`QAction`. """ - # Share QPrinter instance to propose latest used as default - _printer = None - def __init__(self, plot, parent=None): super(PrintAction, self).__init__( plot, icon='document-print', text='Print...', @@ -538,15 +640,17 @@ class PrintAction(PlotAction): self.setShortcut(qt.QKeySequence.Print) self.setShortcutContext(qt.Qt.WidgetShortcut) - @property - def printer(self): - """The QPrinter instance used by the actions. + def getPrinter(self): + """The QPrinter instance used by the PrintAction. - This is shared accross all instances of PrintAct + :rtype: QPrinter """ - if self._printer is None: - PrintAction._printer = qt.QPrinter() - return self._printer + return printer.getDefaultPrinter() + + @property + @deprecated(replacement="getPrinter()", since_version="0.8.0") + def printer(self): + return self.getPrinter() def printPlotAsWidget(self): """Open the print dialog and print the plot. @@ -555,7 +659,7 @@ class PrintAction(PlotAction): :return: True if successful """ - dialog = qt.QPrintDialog(self.printer, self.plot) + dialog = qt.QPrintDialog(self.getPrinter(), self.plot) dialog.setWindowTitle('Print Plot') if not dialog.exec_(): return False @@ -564,10 +668,10 @@ class PrintAction(PlotAction): widget = self.plot.centralWidget() painter = qt.QPainter() - if not painter.begin(self.printer): + if not painter.begin(self.getPrinter()): return False - pageRect = self.printer.pageRect() + pageRect = self.getPrinter().pageRect() xScale = pageRect.width() / widget.width() yScale = pageRect.height() / widget.height() scale = min(xScale, yScale) @@ -588,7 +692,7 @@ class PrintAction(PlotAction): :return: True if successful """ # Init printer and start printer dialog - dialog = qt.QPrintDialog(self.printer, self.plot) + dialog = qt.QPrintDialog(self.getPrinter(), self.plot) dialog.setWindowTitle('Print Plot') if not dialog.exec_(): return False @@ -599,13 +703,13 @@ class PrintAction(PlotAction): pixmap = qt.QPixmap() pixmap.loadFromData(pngData, 'png') - xScale = self.printer.pageRect().width() / pixmap.width() - yScale = self.printer.pageRect().height() / pixmap.height() + xScale = self.getPrinter().pageRect().width() / pixmap.width() + yScale = self.getPrinter().pageRect().height() / pixmap.height() scale = min(xScale, yScale) # Draw pixmap with painter painter = qt.QPainter() - if not painter.begin(self.printer): + if not painter.begin(self.getPrinter()): return False painter.drawPixmap(0, 0, |