summaryrefslogtreecommitdiff
path: root/silx/sx/_plot.py
diff options
context:
space:
mode:
Diffstat (limited to 'silx/sx/_plot.py')
-rw-r--r--silx/sx/_plot.py40
1 files changed, 23 insertions, 17 deletions
diff --git a/silx/sx/_plot.py b/silx/sx/_plot.py
index d434fec..00dcabe 100644
--- a/silx/sx/_plot.py
+++ b/silx/sx/_plot.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -27,7 +27,7 @@
__authors__ = ["T. Vincent"]
__license__ = "MIT"
-__date__ = "28/06/2018"
+__date__ = "06/11/2018"
import collections
@@ -35,16 +35,15 @@ import logging
import weakref
import numpy
+import six
from ..utils.weakref import WeakList
from ..gui import qt
from ..gui.plot import Plot1D, Plot2D, ScatterView
-from ..gui.colors import COLORDICT
-from ..gui.colors import Colormap
+from ..gui import colors
from ..gui.plot.tools import roi
from ..gui.plot.items import roi as roi_items
from ..gui.plot.tools.toolbars import InteractiveModeToolBar
-from silx.third_party import six
_logger = logging.getLogger(__name__)
@@ -165,7 +164,7 @@ def plot(*args, **kwargs):
# Parse style
if style:
# Handle color first
- possible_colors = [c for c in COLORDICT if style.startswith(c)]
+ possible_colors = [c for c in colors.COLORDICT if style.startswith(c)]
if possible_colors: # Take the longest string matching a color name
curve_color = possible_colors[0]
for c in possible_colors[1:]:
@@ -203,7 +202,7 @@ def plot(*args, **kwargs):
return plt
-def imshow(data=None, cmap=None, norm=Colormap.LINEAR,
+def imshow(data=None, cmap=None, norm=colors.Colormap.LINEAR,
vmin=None, vmax=None,
aspect=False,
origin='upper', scale=(1., 1.),
@@ -231,7 +230,9 @@ def imshow(data=None, cmap=None, norm=Colormap.LINEAR,
:param data: data to plot as an image
:type data: numpy.ndarray-like with 2 dimensions
- :param str cmap: The name of the colormap to use for the plot.
+ :param str cmap: The name of the colormap to use for the plot. It also
+ supports a numpy array containing a RGB LUT, or a `colors.Colormap`
+ instance.
:param str norm: The normalization of the colormap:
'linear' (default) or 'log'
:param float vmin: The value to use for the min of the colormap
@@ -257,13 +258,17 @@ def imshow(data=None, cmap=None, norm=Colormap.LINEAR,
# Update default colormap with input parameters
colormap = plt.getDefaultColormap()
- if cmap is not None:
+ if isinstance(cmap, colors.Colormap):
+ colormap = cmap
+ plt.setDefaultColormap(colormap)
+ elif isinstance(cmap, numpy.ndarray):
+ colormap.setColors(cmap)
+ elif cmap is not None:
colormap.setName(cmap)
- assert norm in Colormap.NORMALIZATIONS
+ assert norm in colors.Colormap.NORMALIZATIONS
colormap.setNormalization(norm)
colormap.setVMin(vmin)
colormap.setVMax(vmax)
- plt.setDefaultColormap(colormap)
# Handle aspect
if aspect in (None, False, 'auto', 'normal'):
@@ -294,8 +299,8 @@ def imshow(data=None, cmap=None, norm=Colormap.LINEAR,
def scatter(x=None, y=None, value=None, size=None,
- marker='o',
- cmap=None, norm=Colormap.LINEAR,
+ marker=None,
+ cmap=None, norm=colors.Colormap.LINEAR,
vmin=None, vmax=None):
"""
Plot scattered data in a :class:`~silx.gui.plot.ScatterView` widget.
@@ -328,8 +333,8 @@ def scatter(x=None, y=None, value=None, size=None,
:param numpy.ndarray y: 1D array-like of y coordinates
:param numpy.ndarray value: 1D array-like of data values
:param float size: Size^2 of the markers
- :param str marker: Symbol used to represent the points (default: 'o')
- :param str cmap: The name of the colormap to use for the plot.
+ :param str marker: Symbol used to represent the points
+ :param str cmap: The name of the colormap to use for the plot
:param str norm: The normalization of the colormap:
'linear' (default) or 'log'
:param float vmin: The value to use for the min of the colormap
@@ -343,7 +348,7 @@ def scatter(x=None, y=None, value=None, size=None,
colormap = plt.getPlotWidget().getDefaultColormap()
if cmap is not None:
colormap.setName(cmap)
- assert norm in Colormap.NORMALIZATIONS
+ assert norm in colors.Colormap.NORMALIZATIONS
colormap.setNormalization(norm)
colormap.setVMin(vmin)
colormap.setVMax(vmax)
@@ -366,7 +371,8 @@ def scatter(x=None, y=None, value=None, size=None,
plt.setData(x, y, value)
item = plt.getScatterItem()
- item.setSymbol(marker)
+ if marker is not None:
+ item.setSymbol(marker)
if size is not None:
item.setSymbolSize(numpy.sqrt(size))