diff options
Diffstat (limited to 'silx/sx/_plot.py')
-rw-r--r-- | silx/sx/_plot.py | 40 |
1 files changed, 23 insertions, 17 deletions
diff --git a/silx/sx/_plot.py b/silx/sx/_plot.py index d434fec..00dcabe 100644 --- a/silx/sx/_plot.py +++ b/silx/sx/_plot.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -27,7 +27,7 @@ __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "28/06/2018" +__date__ = "06/11/2018" import collections @@ -35,16 +35,15 @@ import logging import weakref import numpy +import six from ..utils.weakref import WeakList from ..gui import qt from ..gui.plot import Plot1D, Plot2D, ScatterView -from ..gui.colors import COLORDICT -from ..gui.colors import Colormap +from ..gui import colors from ..gui.plot.tools import roi from ..gui.plot.items import roi as roi_items from ..gui.plot.tools.toolbars import InteractiveModeToolBar -from silx.third_party import six _logger = logging.getLogger(__name__) @@ -165,7 +164,7 @@ def plot(*args, **kwargs): # Parse style if style: # Handle color first - possible_colors = [c for c in COLORDICT if style.startswith(c)] + possible_colors = [c for c in colors.COLORDICT if style.startswith(c)] if possible_colors: # Take the longest string matching a color name curve_color = possible_colors[0] for c in possible_colors[1:]: @@ -203,7 +202,7 @@ def plot(*args, **kwargs): return plt -def imshow(data=None, cmap=None, norm=Colormap.LINEAR, +def imshow(data=None, cmap=None, norm=colors.Colormap.LINEAR, vmin=None, vmax=None, aspect=False, origin='upper', scale=(1., 1.), @@ -231,7 +230,9 @@ def imshow(data=None, cmap=None, norm=Colormap.LINEAR, :param data: data to plot as an image :type data: numpy.ndarray-like with 2 dimensions - :param str cmap: The name of the colormap to use for the plot. + :param str cmap: The name of the colormap to use for the plot. It also + supports a numpy array containing a RGB LUT, or a `colors.Colormap` + instance. :param str norm: The normalization of the colormap: 'linear' (default) or 'log' :param float vmin: The value to use for the min of the colormap @@ -257,13 +258,17 @@ def imshow(data=None, cmap=None, norm=Colormap.LINEAR, # Update default colormap with input parameters colormap = plt.getDefaultColormap() - if cmap is not None: + if isinstance(cmap, colors.Colormap): + colormap = cmap + plt.setDefaultColormap(colormap) + elif isinstance(cmap, numpy.ndarray): + colormap.setColors(cmap) + elif cmap is not None: colormap.setName(cmap) - assert norm in Colormap.NORMALIZATIONS + assert norm in colors.Colormap.NORMALIZATIONS colormap.setNormalization(norm) colormap.setVMin(vmin) colormap.setVMax(vmax) - plt.setDefaultColormap(colormap) # Handle aspect if aspect in (None, False, 'auto', 'normal'): @@ -294,8 +299,8 @@ def imshow(data=None, cmap=None, norm=Colormap.LINEAR, def scatter(x=None, y=None, value=None, size=None, - marker='o', - cmap=None, norm=Colormap.LINEAR, + marker=None, + cmap=None, norm=colors.Colormap.LINEAR, vmin=None, vmax=None): """ Plot scattered data in a :class:`~silx.gui.plot.ScatterView` widget. @@ -328,8 +333,8 @@ def scatter(x=None, y=None, value=None, size=None, :param numpy.ndarray y: 1D array-like of y coordinates :param numpy.ndarray value: 1D array-like of data values :param float size: Size^2 of the markers - :param str marker: Symbol used to represent the points (default: 'o') - :param str cmap: The name of the colormap to use for the plot. + :param str marker: Symbol used to represent the points + :param str cmap: The name of the colormap to use for the plot :param str norm: The normalization of the colormap: 'linear' (default) or 'log' :param float vmin: The value to use for the min of the colormap @@ -343,7 +348,7 @@ def scatter(x=None, y=None, value=None, size=None, colormap = plt.getPlotWidget().getDefaultColormap() if cmap is not None: colormap.setName(cmap) - assert norm in Colormap.NORMALIZATIONS + assert norm in colors.Colormap.NORMALIZATIONS colormap.setNormalization(norm) colormap.setVMin(vmin) colormap.setVMax(vmax) @@ -366,7 +371,8 @@ def scatter(x=None, y=None, value=None, size=None, plt.setData(x, y, value) item = plt.getScatterItem() - item.setSymbol(marker) + if marker is not None: + item.setSymbol(marker) if size is not None: item.setSymbolSize(numpy.sqrt(size)) |