summaryrefslogtreecommitdiff
path: root/silx/sx
diff options
context:
space:
mode:
Diffstat (limited to 'silx/sx')
-rw-r--r--silx/sx/__init__.py70
-rw-r--r--silx/sx/_plot.py373
-rw-r--r--silx/sx/_plot3d.py8
-rw-r--r--silx/sx/test/__init__.py38
-rw-r--r--silx/sx/test/test_sx.py288
5 files changed, 591 insertions, 186 deletions
diff --git a/silx/sx/__init__.py b/silx/sx/__init__.py
index bdec6e6..e3641c8 100644
--- a/silx/sx/__init__.py
+++ b/silx/sx/__init__.py
@@ -28,11 +28,14 @@ It loads the main features of silx and provides high-level functions.
>>> from silx import sx
When used in an interpreter is sets-up Qt and loads some silx widgets.
-When used in a `jupyter <https://jupyter.org/>`_ /
-`IPython <https://ipython.org/>`_ notebook, neither Qt nor silx widgets are loaded.
+In a `jupyter <https://jupyter.org/>`_ / `IPython <https://ipython.org/>`_
+notebook, to set-up Qt and loads silx widgets, you must then call:
+
+>>> sx.enable_gui()
When used in `IPython <https://ipython.org/>`_, it also runs ``%pylab``,
-thus importing `numpy <http://www.numpy.org/>`_ and `matplotlib <https://matplotlib.org/>`_.
+thus importing `numpy <http://www.numpy.org/>`_ and
+`matplotlib <https://matplotlib.org/>`_.
"""
@@ -43,6 +46,7 @@ __date__ = "16/01/2017"
import logging as _logging
import sys as _sys
+import os as _os
_logger = _logging.getLogger(__name__)
@@ -52,6 +56,9 @@ _logger = _logging.getLogger(__name__)
if hasattr(_sys, 'ps1'):
_logging.basicConfig()
+# Probe DISPLAY available on linux
+_NO_DISPLAY = _sys.platform.startswith('linux') and not _os.environ.get('DISPLAY')
+
# Probe ipython
try:
from IPython import get_ipython as _get_ipython
@@ -68,46 +75,61 @@ else:
_IS_NOTEBOOK = False
-# Load Qt and widgets only if running from console
-if _IS_NOTEBOOK:
- _logger.warning(
- 'Not loading silx.gui features: Running from the notebook')
+def enable_gui():
+ """Populate silx.sx module with silx.gui features and initialise Qt"""
+ if _NO_DISPLAY: # Missing DISPLAY under linux
+ _logger.warning(
+ 'Not loading silx.gui features: No DISPLAY available')
+ return
-else:
- from silx.gui import qt
+ global qt, qapp
+
+ if _IS_NOTEBOOK:
+ _get_ipython().enable_pylab(gui='qt', import_all=False)
- if hasattr(_sys, 'ps1'): # If from console, make sure QApplication runs
- qapp = qt.QApplication.instance() or qt.QApplication([])
+ from silx.gui import qt
+ qapp = qt.QApplication.instance() or qt.QApplication([])
+ if hasattr(_sys, 'ps1'): # If from console, change windows icon
# Change windows default icon
- from silx.gui import icons as _icons
- qapp.setWindowIcon(_icons.getQIcon('silx'))
- del _icons # clean-up namespace
+ from silx.gui import icons
+ qapp.setWindowIcon(icons.getQIcon('silx'))
+
+ global ImageView, PlotWidget, PlotWindow, Plot1D
+ global Plot2D, StackView, ScatterView, TickMode
+ from silx.gui.plot import (ImageView, PlotWidget, PlotWindow, Plot1D,
+ Plot2D, StackView, ScatterView, TickMode) # noqa
- from silx.gui.plot import * # noqa
- from ._plot import plot, imshow, ginput # noqa
+ global plot, imshow, scatter, ginput
+ from ._plot import plot, imshow, scatter, ginput # noqa
try:
- import OpenGL as _OpenGL
+ import OpenGL
except ImportError:
_logger.warning(
'Not loading silx.gui.plot3d features: PyOpenGL is not installed')
else:
- del _OpenGL # clean-up namespace
+ global contour3d, points3d
from ._plot3d import contour3d, points3d # noqa
+# Load Qt and widgets only if running from console and display available
+if _IS_NOTEBOOK:
+ _logger.warning(
+ 'Not loading silx.gui features: Running from the notebook')
+else:
+ enable_gui()
+
+
# %pylab
if _get_ipython is not None and _get_ipython() is not None:
- _get_ipython().enable_pylab(gui='inline' if _IS_NOTEBOOK else 'qt',
- import_all=False)
+ if not _NO_DISPLAY: # Not loading pylab without display
+ from IPython.core.pylabtools import import_pylab as _import_pylab
+ _import_pylab(_get_ipython().user_ns, import_all=False)
# Clean-up
-del _sys
-del _get_ipython
-del _IS_NOTEBOOK
-
+del _os
# Load some silx stuff in namespace
from silx import version # noqa
diff --git a/silx/sx/_plot.py b/silx/sx/_plot.py
index dfc24d9..d434fec 100644
--- a/silx/sx/_plot.py
+++ b/silx/sx/_plot.py
@@ -27,21 +27,23 @@
__authors__ = ["T. Vincent"]
__license__ = "MIT"
-__date__ = "26/02/2018"
+__date__ = "28/06/2018"
import collections
import logging
-import time
import weakref
import numpy
from ..utils.weakref import WeakList
from ..gui import qt
-from ..gui.plot import Plot1D, Plot2D, PlotWidget
-from ..gui.plot.Colors import COLORDICT
-from ..gui.plot.Colormap import Colormap
+from ..gui.plot import Plot1D, Plot2D, ScatterView
+from ..gui.colors import COLORDICT
+from ..gui.colors import Colormap
+from ..gui.plot.tools import roi
+from ..gui.plot.items import roi as roi_items
+from ..gui.plot.tools.toolbars import InteractiveModeToolBar
from silx.third_party import six
@@ -111,6 +113,8 @@ def plot(*args, **kwargs):
:param str title: The title of the Plot widget (default: None)
:param str xlabel: The label of the X axis (default: None)
:param str ylabel: The label of the Y axis (default: None)
+ :return: The widget plotting the curve(s)
+ :rtype: silx.gui.plot.Plot1D
"""
plt = Plot1D()
if 'title' in kwargs:
@@ -202,7 +206,7 @@ def plot(*args, **kwargs):
def imshow(data=None, cmap=None, norm=Colormap.LINEAR,
vmin=None, vmax=None,
aspect=False,
- origin=(0., 0.), scale=(1., 1.),
+ origin='upper', scale=(1., 1.),
title='', xlabel='X', ylabel='Y'):
"""
Plot an image in a :class:`~silx.gui.plot.PlotWindow.Plot2D` widget.
@@ -215,6 +219,12 @@ def imshow(data=None, cmap=None, norm=Colormap.LINEAR,
>>> data = numpy.random.random(1024 * 1024).reshape(1024, 1024)
>>> plt = sx.imshow(data, title='Random data')
+ By default, the image origin is displayed in the upper left
+ corner of the plot. To invert the Y axis, and place the image origin
+ in the lower left corner of the plot, use the *origin* parameter:
+
+ >>> plt = sx.imshow(data, origin='lower')
+
This function supports a subset of `matplotlib.pyplot.imshow
<http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.imshow>`_
arguments.
@@ -227,14 +237,18 @@ def imshow(data=None, cmap=None, norm=Colormap.LINEAR,
:param float vmin: The value to use for the min of the colormap
:param float vmax: The value to use for the max of the colormap
:param bool aspect: True to keep aspect ratio (Default: False)
- :param origin: (ox, oy) The coordinates of the image origin in the plot
- :type origin: 2-tuple of floats
+ :param origin: Either image origin as the Y axis orientation:
+ 'upper' (default) or 'lower'
+ or the coordinates (ox, oy) of the image origin in the plot.
+ :type origin: str or 2-tuple of floats
:param scale: (sx, sy) The scale of the image in the plot
(i.e., the size of the image's pixel in plot coordinates)
:type scale: 2-tuple of floats
:param str title: The title of the Plot widget
:param str xlabel: The label of the X axis
:param str ylabel: The label of the Y axis
+ :return: The widget plotting the image
+ :rtype: silx.gui.plot.Plot2D
"""
plt = Plot2D()
plt.setGraphTitle(title)
@@ -260,6 +274,11 @@ def imshow(data=None, cmap=None, norm=Colormap.LINEAR,
_logger.warning(
'imshow: Unhandled aspect argument: %s', str(aspect))
+ # Handle matplotlib-like origin
+ if origin in ('upper', 'lower'):
+ plt.setYAxisInverted(origin == 'upper')
+ origin = 0., 0. # Set origin to the definition of silx
+
if data is not None:
data = numpy.array(data, copy=True)
@@ -274,6 +293,90 @@ def imshow(data=None, cmap=None, norm=Colormap.LINEAR,
return plt
+def scatter(x=None, y=None, value=None, size=None,
+ marker='o',
+ cmap=None, norm=Colormap.LINEAR,
+ vmin=None, vmax=None):
+ """
+ Plot scattered data in a :class:`~silx.gui.plot.ScatterView` widget.
+
+ How to use:
+
+ >>> from silx import sx
+ >>> import numpy
+
+ >>> x = numpy.random.random(100)
+ >>> y = numpy.random.random(100)
+ >>> values = numpy.random.random(100)
+ >>> plt = sx.scatter(x, y, values, cmap='viridis')
+
+ Supported symbols:
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+
+ This function supports a subset of `matplotlib.pyplot.scatter
+ <http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.scatter>`_
+ arguments.
+
+ :param numpy.ndarray x: 1D array-like of x coordinates
+ :param numpy.ndarray y: 1D array-like of y coordinates
+ :param numpy.ndarray value: 1D array-like of data values
+ :param float size: Size^2 of the markers
+ :param str marker: Symbol used to represent the points (default: 'o')
+ :param str cmap: The name of the colormap to use for the plot.
+ :param str norm: The normalization of the colormap:
+ 'linear' (default) or 'log'
+ :param float vmin: The value to use for the min of the colormap
+ :param float vmax: The value to use for the max of the colormap
+ :return: The widget plotting the scatter plot
+ :rtype: silx.gui.plot.ScatterView.ScatterView
+ """
+ plt = ScatterView()
+
+ # Update default colormap with input parameters
+ colormap = plt.getPlotWidget().getDefaultColormap()
+ if cmap is not None:
+ colormap.setName(cmap)
+ assert norm in Colormap.NORMALIZATIONS
+ colormap.setNormalization(norm)
+ colormap.setVMin(vmin)
+ colormap.setVMax(vmax)
+ plt.getPlotWidget().setDefaultColormap(colormap)
+
+ if x is not None and y is not None: # Add a scatter plot
+ x = numpy.array(x, copy=True).reshape(-1)
+ y = numpy.array(y, copy=True).reshape(-1)
+ assert len(x) == len(y)
+
+ if value is None:
+ value = numpy.ones(len(x), dtype=numpy.float32)
+
+ elif isinstance(value, collections.Iterable):
+ value = numpy.array(value, copy=True).reshape(-1)
+ assert len(x) == len(value)
+
+ else:
+ value = numpy.ones(len(x), dtype=numpy.float) * value
+
+ plt.setData(x, y, value)
+ item = plt.getScatterItem()
+ item.setSymbol(marker)
+ if size is not None:
+ item.setSymbolSize(numpy.sqrt(size))
+
+ plt.resetZoom()
+
+ plt.show()
+ _plots.insert(0, plt.getPlotWidget())
+ return plt
+
+
class _GInputResult(tuple):
"""Object storing :func:`ginput` result
@@ -328,175 +431,128 @@ class _GInputResult(tuple):
return self._data
-class _GInputHandler(qt.QEventLoop):
+class _GInputHandler(roi.InteractiveRegionOfInterestManager):
"""Implements :func:`ginput`
:param PlotWidget plot:
- :param int n:
- :param float timeout:
+ :param int n: Max number of points to request
+ :param float timeout: Timeout in seconds
"""
def __init__(self, plot, n, timeout):
- super(_GInputHandler, self).__init__()
-
- if not isinstance(plot, PlotWidget):
- raise ValueError('plot is not a PlotWidget: %s', plot)
+ super(_GInputHandler, self).__init__(plot)
- self._plot = plot
self._timeout = timeout
- self._markersAndResult = []
- self._totalPoints = n
- self._endTime = 0.
+ self.__selections = collections.OrderedDict()
- def eventFilter(self, obj, event):
- """Event filter for plot hide event"""
- if event.type() == qt.QEvent.Hide:
- self.quit()
+ window = plot.window() # Retrieve window containing PlotWidget
+ statusBar = window.statusBar()
+ self.sigMessageChanged.connect(statusBar.showMessage)
+ self.setMaxRois(n)
+ self.setValidationMode(self.ValidationMode.AUTO_ENTER)
+ self.sigRoiAdded.connect(self.__added)
+ self.sigRoiAboutToBeRemoved.connect(self.__removed)
- elif event.type() == qt.QEvent.KeyPress:
- if event.key() in (qt.Qt.Key_Delete, qt.Qt.Key_Backspace) or (
- event.key() == qt.Qt.Key_Z and event.modifiers() & qt.Qt.ControlModifier):
- if len(self._markersAndResult) > 0:
- legend, _ = self._markersAndResult.pop()
- self._plot.remove(legend=legend, kind='marker')
+ def exec_(self):
+ """Request user inputs
- self._updateStatusBar()
- return True # Stop further handling of those keys
+ :return: List of selection points information
+ """
+ plot = self.parent()
+ if plot is None:
+ return
- elif event.key() == qt.Qt.Key_Return:
- self.quit()
- return True # Stop further handling of those keys
+ window = plot.window() # Retrieve window containing PlotWidget
- return super(_GInputHandler, self).eventFilter(obj, event)
+ # Add ROI point interactive mode action
+ for toolbar in window.findChildren(qt.QToolBar):
+ if isinstance(toolbar, InteractiveModeToolBar):
+ break
+ else: # Add a toolbar
+ toolbar = qt.QToolBar()
+ window.addToolBar(toolbar)
+ toolbar.addAction(self.getInteractionModeAction(roi_items.PointROI))
- def exec_(self):
- """Run blocking ginput handler
+ super(_GInputHandler, self).exec_(roiClass=roi_items.PointROI, timeout=self._timeout)
- :returns: List of selected points
- """
- # Bootstrap
- self._plot.setInteractiveMode(mode='zoom')
- self._handleInteractiveModeChanged(None)
- self._plot.sigInteractiveModeChanged.connect(
- self._handleInteractiveModeChanged)
-
- self._plot.installEventFilter(self)
+ if isinstance(toolbar, InteractiveModeToolBar):
+ toolbar.removeAction(self.getInteractionModeAction(roi_items.PointROI))
+ else:
+ toolbar.setParent(None)
- # Run
- if self._timeout:
- timeoutTimer = qt.QTimer()
- timeoutTimer.timeout.connect(self._updateStatusBar)
- timeoutTimer.start(1000)
+ return tuple(self.__selections.values())
- self._endTime = time.time() + self._timeout
- self._updateStatusBar()
+ def __updateSelection(self, roi):
+ """Perform picking and update selection list
- returnCode = super(_GInputHandler, self).exec_()
+ :param RegionOfInterest roi:
+ """
- timeoutTimer.stop()
- else:
- returnCode = super(_GInputHandler, self).exec_()
-
- # Clean-up
- self._plot.removeEventFilter(self)
-
- self._plot.sigInteractiveModeChanged.disconnect(
- self._handleInteractiveModeChanged)
-
- currentMode = self._plot.getInteractiveMode()
- if currentMode['mode'] == 'zoom': # Stop handling mouse click
- self._plot.sigPlotSignal.disconnect(self._handleSelect)
-
- self._plot.statusBar().clearMessage()
-
- points = tuple(result for _, result in self._markersAndResult)
-
- for legend, _ in self._markersAndResult:
- self._plot.remove(legend=legend, kind='marker')
- self._markersAndResult = []
-
- return points if returnCode == 0 else ()
-
- def _updateStatusBar(self):
- """Update status bar message"""
- msg = 'ginput: %d/%d input points' % (len(self._markersAndResult),
- self._totalPoints)
- if self._timeout:
- remaining = self._endTime - time.time()
- if remaining < 0:
- self.quit()
- return
- msg += ', %d seconds remaining' % max(1, int(remaining))
-
- currentMode = self._plot.getInteractiveMode()
- if currentMode['mode'] != 'zoom':
- msg += ' (Use zoom mode to add/remove points)'
-
- self._plot.statusBar().showMessage(msg)
-
- def _handleSelect(self, event):
- """Handle mouse events"""
- if event['event'] == 'mouseClicked' and event['button'] == 'left':
- x, y = event['x'], event['y']
- xPixel, yPixel = event['xpixel'], event['ypixel']
-
- # Add marker
- legend = "sx.ginput %d" % len(self._markersAndResult)
- self._plot.addMarker(
- x, y,
- legend=legend,
- text='%d' % len(self._markersAndResult),
- color='red',
- draggable=False)
-
- # Pick item at selected position
- picked = self._plot._pickImageOrCurve(xPixel, yPixel)
-
- if picked is None:
- result = _GInputResult((x, y),
- item=None,
- indices=numpy.array((), dtype=int),
- data=None)
-
- elif picked[0] == 'curve':
- curve = picked[1]
- indices = picked[2]
- xData = curve.getXData(copy=False)[indices]
- yData = curve.getYData(copy=False)[indices]
- result = _GInputResult((x, y),
- item=curve,
- indices=indices,
- data=numpy.array((xData, yData)).T)
-
- elif picked[0] == 'image':
- image = picked[1]
- # Get corresponding coordinate in image
- origin = image.getOrigin()
- scale = image.getScale()
- column = int((x - origin[0]) / float(scale[0]))
- row = int((y - origin[1]) / float(scale[1]))
- data = image.getData(copy=False)[row, column]
- result = _GInputResult((x, y),
- item=image,
- indices=(row, column),
- data=data)
-
- self._markersAndResult.append((legend, result))
- self._updateStatusBar()
- if len(self._markersAndResult) == self._totalPoints:
- self.quit()
-
- def _handleInteractiveModeChanged(self, source):
- """Handle change of interactive mode in the plot
-
- :param source: Objects that triggered the mode change
+ plot = self.parent()
+ if plot is None:
+ return # No plot, abort
+
+ if not isinstance(roi, roi_items.PointROI):
+ # Only handle points
+ raise RuntimeError("Unexpected item")
+
+ x, y = roi.getPosition()
+ xPixel, yPixel = plot.dataToPixel(x, y, axis='left', check=False)
+
+ # Pick item at selected position
+ picked = plot._pickImageOrCurve(xPixel, yPixel)
+
+ if picked is None:
+ result = _GInputResult((x, y),
+ item=None,
+ indices=numpy.array((), dtype=int),
+ data=None)
+
+ elif picked[0] == 'curve':
+ curve = picked[1]
+ indices = picked[2]
+ xData = curve.getXData(copy=False)[indices]
+ yData = curve.getYData(copy=False)[indices]
+ result = _GInputResult((x, y),
+ item=curve,
+ indices=indices,
+ data=numpy.array((xData, yData)).T)
+
+ elif picked[0] == 'image':
+ image = picked[1]
+ # Get corresponding coordinate in image
+ origin = image.getOrigin()
+ scale = image.getScale()
+ column = int((x - origin[0]) / float(scale[0]))
+ row = int((y - origin[1]) / float(scale[1]))
+ data = image.getData(copy=False)[row, column]
+ result = _GInputResult((x, y),
+ item=image,
+ indices=(row, column),
+ data=data)
+
+ self.__selections[roi] = result
+
+ def __added(self, roi):
+ """Handle new ROI added
+
+ :param RegionOfInterest roi:
"""
- mode = self._plot.getInteractiveMode()
- if mode['mode'] == 'zoom': # Handle click events
- self._plot.sigPlotSignal.connect(self._handleSelect)
- else: # Do not handle click event
- self._plot.sigPlotSignal.disconnect(self._handleSelect)
- self._updateStatusBar()
+ if isinstance(roi, roi_items.PointROI):
+ # Only handle points
+ roi.setLabel('%d' % len(self.__selections))
+ self.__updateSelection(roi)
+ roi.sigRegionChanged.connect(self.__regionChanged)
+
+ def __removed(self, roi):
+ """Handle ROI removed"""
+ if self.__selections.pop(roi, None) is not None:
+ roi.sigRegionChanged.disconnect(self.__regionChanged)
+
+ def __regionChanged(self):
+ """Handle update of a ROI"""
+ roi = self.sender()
+ self.__updateSelection(roi)
def ginput(n=1, timeout=30, plot=None):
@@ -537,7 +593,7 @@ def ginput(n=1, timeout=30, plot=None):
if widget.isVisible():
plot = widget
break
- else: # If no plot widgets are visible, take most recent one
+ else: # If no plot widget is visible, take the most recent one
try:
plot = _plots[0]
except IndexError:
@@ -549,6 +605,7 @@ def ginput(n=1, timeout=30, plot=None):
_logger.warning('No plot available to perform ginput, create one')
plot = Plot1D()
plot.show()
+ _plots.insert(0, plot)
plot.raise_() # So window becomes the top level one
diff --git a/silx/sx/_plot3d.py b/silx/sx/_plot3d.py
index 3e67fe0..42ebf80 100644
--- a/silx/sx/_plot3d.py
+++ b/silx/sx/_plot3d.py
@@ -27,7 +27,7 @@
__authors__ = ["T. Vincent"]
__license__ = "MIT"
-__date__ = "07/02/2018"
+__date__ = "24/04/2018"
from collections import Iterable
@@ -38,8 +38,8 @@ from ..gui import qt
from ..gui.plot3d.SceneWindow import SceneWindow
from ..gui.plot3d.ScalarFieldView import ScalarFieldView
from ..gui.plot3d import SFViewParamTree
-from ..gui.plot.Colormap import Colormap
-from ..gui.plot.Colors import rgba
+from ..gui.colors import Colormap
+from ..gui.colors import rgba
_logger = logging.getLogger(__name__)
@@ -148,7 +148,7 @@ def contour3d(scalars,
treeView.setSfView(scalarField) # Attach the parameter tree to the view
# Add the parameter tree to the main window in a dock widget
- dock = qt.QDockWidget()
+ dock = qt.QDockWidget(scalarField)
dock.setWindowTitle('Parameters')
dock.setWidget(treeView)
scalarField.addDockWidget(qt.Qt.RightDockWidgetArea, dock)
diff --git a/silx/sx/test/__init__.py b/silx/sx/test/__init__.py
new file mode 100644
index 0000000..c9401b6
--- /dev/null
+++ b/silx/sx/test/__init__.py
@@ -0,0 +1,38 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "06/03/2018"
+
+import unittest
+
+
+from . import test_sx
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTest(test_sx.suite())
+ return test_suite
diff --git a/silx/sx/test/test_sx.py b/silx/sx/test/test_sx.py
new file mode 100644
index 0000000..9de1f8b
--- /dev/null
+++ b/silx/sx/test/test_sx.py
@@ -0,0 +1,288 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+
+import logging
+import unittest
+import numpy
+
+from silx.utils.testutils import ParametricTestCase
+from silx.test.utils import test_options
+
+from silx.gui import qt
+# load TestCaseQt before sx
+from silx.gui.test.utils import TestCaseQt
+from silx.gui.colors import rgba
+from silx import sx
+
+
+_logger = logging.getLogger(__name__)
+
+
+class SXTest(TestCaseQt, ParametricTestCase):
+ """Test the sx module"""
+
+ def _expose_and_close(self, plot):
+ self.qWaitForWindowExposed(plot)
+ self.qapp.processEvents()
+ plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ plot.close()
+
+ def test_plot(self):
+ """Test plot function"""
+ y = numpy.random.random(100)
+ x = numpy.arange(len(y)) * 0.5
+
+ # Nothing
+ plt = sx.plot()
+ self._expose_and_close(plt)
+
+ # y
+ plt = sx.plot(y, title='y')
+ self._expose_and_close(plt)
+
+ # y, style
+ plt = sx.plot(y, 'blued ', title='y, "blued "')
+ self._expose_and_close(plt)
+
+ # x, y
+ plt = sx.plot(x, y, title='x, y')
+ self._expose_and_close(plt)
+
+ # x, y, style
+ plt = sx.plot(x, y, 'ro-', xlabel='x', title='x, y, "ro-"')
+ self._expose_and_close(plt)
+
+ # x, y, style, y
+ plt = sx.plot(x, y, 'ro-', y ** 2, xlabel='x', ylabel='y',
+ title='x, y, "ro-", y ** 2')
+ self._expose_and_close(plt)
+
+ # x, y, style, y, style
+ plt = sx.plot(x, y, 'ro-', y ** 2, 'b--',
+ title='x, y, "ro-", y ** 2, "b--"')
+ self._expose_and_close(plt)
+
+ # x, y, style, x, y, style
+ plt = sx.plot(x, y, 'ro-', x, y ** 2, 'b--',
+ title='x, y, "ro-", x, y ** 2, "b--"')
+ self._expose_and_close(plt)
+
+ # x, y, x, y
+ plt = sx.plot(x, y, x, y ** 2, title='x, y, x, y ** 2')
+ self._expose_and_close(plt)
+
+ def test_imshow(self):
+ """Test imshow function"""
+ img = numpy.arange(100.).reshape(10, 10) + 1
+
+ # Nothing
+ plt = sx.imshow()
+ self._expose_and_close(plt)
+
+ # image
+ plt = sx.imshow(img)
+ self._expose_and_close(plt)
+
+ # image, gray cmap
+ plt = sx.imshow(img, cmap='jet', title='jet cmap')
+ self._expose_and_close(plt)
+
+ # image, log cmap
+ plt = sx.imshow(img, norm='log', title='log cmap')
+ self._expose_and_close(plt)
+
+ # image, fixed range
+ plt = sx.imshow(img, vmin=10, vmax=20,
+ title='[10,20] cmap')
+ self._expose_and_close(plt)
+
+ # image, keep ratio
+ plt = sx.imshow(img, aspect=True,
+ title='keep ratio')
+ self._expose_and_close(plt)
+
+ # image, change origin and scale
+ plt = sx.imshow(img, origin=(10, 10), scale=(2, 2),
+ title='origin=(10, 10), scale=(2, 2)')
+ self._expose_and_close(plt)
+
+ # image, origin='lower'
+ plt = sx.imshow(img, origin='upper', title='origin="lower"')
+ self._expose_and_close(plt)
+
+ def test_scatter(self):
+ """Test scatter function"""
+ x = numpy.arange(100)
+ y = numpy.arange(100)
+ values = numpy.arange(100)
+
+ # simple scatter
+ plt = sx.scatter(x, y, values)
+ self._expose_and_close(plt)
+
+ # No value
+ plt = sx.scatter(x, y, values)
+ self._expose_and_close(plt)
+
+ # single value
+ plt = sx.scatter(x, y, 10.)
+ self._expose_and_close(plt)
+
+ # set size
+ plt = sx.scatter(x, y, values, size=20)
+ self._expose_and_close(plt)
+
+ # set colormap
+ plt = sx.scatter(x, y, values, cmap='jet')
+ self._expose_and_close(plt)
+
+ # set colormap range
+ plt = sx.scatter(x, y, values, vmin=2, vmax=50)
+ self._expose_and_close(plt)
+
+ # set colormap normalisation
+ plt = sx.scatter(x, y, values, norm='log')
+ self._expose_and_close(plt)
+
+ def test_ginput(self):
+ """Test ginput function
+
+ This does NOT perform interactive tests
+ """
+
+ for create_plot in (sx.plot, sx.imshow, sx.scatter):
+ with self.subTest(create_plot.__name__):
+ plt = create_plot()
+ self.qWaitForWindowExposed(plt)
+ self.qapp.processEvents()
+
+ result = sx.ginput(1, timeout=0.1)
+ self.assertEqual(len(result), 0)
+
+ plt.setAttribute(qt.Qt.WA_DeleteOnClose)
+ plt.close()
+
+ @unittest.skipUnless(test_options.WITH_GL_TEST,
+ test_options.WITH_GL_TEST_REASON)
+ def test_contour3d(self):
+ """Test contour3d function"""
+ coords = numpy.linspace(-10, 10, 64)
+ z = coords.reshape(-1, 1, 1)
+ y = coords.reshape(1, -1, 1)
+ x = coords.reshape(1, 1, -1)
+ data = numpy.sin(x * y * z) / (x * y * z)
+
+ # Just data
+ window = sx.contour3d(data)
+
+ isosurfaces = window.getIsosurfaces()
+ self.assertEqual(len(isosurfaces), 1)
+
+ self._expose_and_close(window)
+ if not window.getPlot3DWidget().isValid():
+ self.skipTest("OpenGL context is not valid")
+
+ # N contours + color
+ colors = ['red', 'green', 'blue']
+ window = sx.contour3d(data, copy=False, contours=len(colors),
+ color=colors)
+
+ isosurfaces = window.getIsosurfaces()
+ self.assertEqual(len(isosurfaces), len(colors))
+ for iso, color in zip(isosurfaces, colors):
+ self.assertEqual(rgba(iso.getColor()), rgba(color))
+
+ self._expose_and_close(window)
+
+ # by isolevel, single color
+ contours = 0.2, 0.5
+ window = sx.contour3d(data, copy=False, contours=contours,
+ color='yellow')
+
+ isosurfaces = window.getIsosurfaces()
+ self.assertEqual(len(isosurfaces), len(contours))
+ for iso, level in zip(isosurfaces, contours):
+ self.assertEqual(iso.getLevel(), level)
+ self.assertEqual(rgba(iso.getColor()),
+ rgba('yellow'))
+
+ self._expose_and_close(window)
+
+ # Single isolevel, colormap
+ window = sx.contour3d(data, copy=False, contours=0.5,
+ colormap='gray', vmin=0.6, opacity=0.4)
+
+ isosurfaces = window.getIsosurfaces()
+ self.assertEqual(len(isosurfaces), 1)
+ self.assertEqual(isosurfaces[0].getLevel(), 0.5)
+ self.assertEqual(rgba(isosurfaces[0].getColor()),
+ (0., 0., 0., 0.4))
+
+ self._expose_and_close(window)
+
+ @unittest.skipUnless(test_options.WITH_GL_TEST,
+ test_options.WITH_GL_TEST_REASON)
+ def test_points3d(self):
+ """Test points3d function"""
+ x = numpy.random.random(1024)
+ y = numpy.random.random(1024)
+ z = numpy.random.random(1024)
+ values = numpy.random.random(1024)
+
+ # 3D positions, no value
+ window = sx.points3d(x, y, z)
+
+ self._expose_and_close(window)
+ if not window.getSceneWidget().isValid():
+ self.skipTest("OpenGL context is not valid")
+
+ # 3D positions, values
+ window = sx.points3d(x, y, z, values, mode='2dsquare',
+ colormap='magma', vmin=0.4, vmax=0.5)
+ self._expose_and_close(window)
+
+ # 2D positions, no value
+ window = sx.points3d(x, y)
+ self._expose_and_close(window)
+
+ # 2D positions, values
+ window = sx.points3d(x, y, values=values, mode=',',
+ colormap='magma', vmin=0.4, vmax=0.5)
+ self._expose_and_close(window)
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(SXTest))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')