diff options
Diffstat (limited to 'silx/gui')
182 files changed, 63585 insertions, 0 deletions
diff --git a/silx/gui/__init__.py b/silx/gui/__init__.py new file mode 100644 index 0000000..6baf238 --- /dev/null +++ b/silx/gui/__init__.py @@ -0,0 +1,29 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Set of Qt widgets""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "23/05/2016" diff --git a/silx/gui/_glutils/Context.py b/silx/gui/_glutils/Context.py new file mode 100644 index 0000000..7600992 --- /dev/null +++ b/silx/gui/_glutils/Context.py @@ -0,0 +1,63 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Abstraction of OpenGL context. + +It defines a way to get current OpenGL context to support multiple +OpenGL contexts. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "25/07/2016" + + +# context ##################################################################### + + +def _defaultGLContextGetter(): + return None + +_glContextGetter = _defaultGLContextGetter + + +def getGLContext(): + """Returns platform dependent object of current OpenGL context. + + This is useful to associate OpenGL resources with the context they are + created in. + + :return: Platform specific OpenGL context + :rtype: None by default or a platform dependent object""" + return _glContextGetter() + + +def setGLContextGetter(getter=_defaultGLContextGetter): + """Set a platform dependent function to retrieve the current OpenGL context + + :param getter: Platform dependent GL context getter + :type getter: Function with no args returning the current OpenGL context + """ + global _glContextGetter + _glContextGetter = getter diff --git a/silx/gui/_glutils/FramebufferTexture.py b/silx/gui/_glutils/FramebufferTexture.py new file mode 100644 index 0000000..b01eb41 --- /dev/null +++ b/silx/gui/_glutils/FramebufferTexture.py @@ -0,0 +1,164 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Association of a texture and a framebuffer object for off-screen rendering. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "25/07/2016" + + +import logging + +from . import gl +from .Texture import Texture + + +_logger = logging.getLogger(__name__) + + +class FramebufferTexture(object): + """Framebuffer with a texture. + + Aimed at off-screen rendering to texture. + + :param internalFormat: OpenGL texture internal format + :param shape: Shape (height, width) of the framebuffer and texture + :type shape: 2-tuple of int + :param stencilFormat: Stencil renderbuffer format + :param depthFormat: Depth renderbuffer format + :param kwargs: Extra arguments for :class:`Texture` constructor + """ + + _PACKED_FORMAT = gl.GL_DEPTH24_STENCIL8, gl.GL_DEPTH_STENCIL + + def __init__(self, + internalFormat, + shape, + stencilFormat=gl.GL_DEPTH24_STENCIL8, + depthFormat=gl.GL_DEPTH24_STENCIL8, + **kwargs): + + self._texture = Texture(internalFormat, shape=shape, **kwargs) + + self._previousFramebuffer = 0 # Used by with statement + + self._name = gl.glGenFramebuffers(1) + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self._name) + + # Attachments + gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, + gl.GL_COLOR_ATTACHMENT0, + gl.GL_TEXTURE_2D, + self._texture.name, + 0) + + height, width = self._texture.shape + + if stencilFormat is not None: + self._stencilId = gl.glGenRenderbuffers(1) + gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self._stencilId) + gl.glRenderbufferStorage(gl.GL_RENDERBUFFER, + stencilFormat, + width, height) + gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, + gl.GL_STENCIL_ATTACHMENT, + gl.GL_RENDERBUFFER, + self._stencilId) + else: + self._stencilId = None + + if depthFormat is not None: + if self._stencilId and depthFormat in self._PACKED_FORMAT: + self._depthId = self._stencilId + else: + self._depthId = gl.glGenRenderbuffers(1) + gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self._depthId) + gl.glRenderbufferStorage(gl.GL_RENDERBUFFER, + depthFormat, + width, height) + gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, + gl.GL_DEPTH_ATTACHMENT, + gl.GL_RENDERBUFFER, + self._depthId) + else: + self._depthId = None + + assert gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) == \ + gl.GL_FRAMEBUFFER_COMPLETE + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0) + + @property + def shape(self): + """Shape of the framebuffer (height, width)""" + return self._texture.shape + + @property + def texture(self): + """The texture this framebuffer is rendering to. + + The life-cycle of the texture is managed by this object""" + return self._texture + + @property + def name(self): + """OpenGL name of the framebuffer""" + if self._name is not None: + return self._name + else: + raise RuntimeError("No OpenGL framebuffer resource, \ + discard has already been called") + + def bind(self): + """Bind this framebuffer for rendering""" + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.name) + + # with statement + + def __enter__(self): + self._previousFramebuffer = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING) + self.bind() + + def __exit__(self, exctype, excvalue, traceback): + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self._previousFramebuffer) + + def discard(self): + """Delete associated OpenGL resources including texture""" + if self._name is not None: + gl.glDeleteFramebuffers(self._name) + self._name = None + + if self._stencilId is not None: + gl.glDeleteRenderbuffers(self._stencilId) + if self._stencilId == self._depthId: + self._depthId = None + self._stencilId = None + if self._depthId is not None: + gl.glDeleteRenderbuffers(self._depthId) + self._depthId = None + + self._texture.discard() # Also discard the texture + else: + _logger.warning("Discard has already been called") diff --git a/silx/gui/_glutils/Program.py b/silx/gui/_glutils/Program.py new file mode 100644 index 0000000..48c12f5 --- /dev/null +++ b/silx/gui/_glutils/Program.py @@ -0,0 +1,202 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides a class to handle shader program compilation.""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "25/07/2016" + + +import logging + +import numpy + +from . import gl +from .Context import getGLContext + +_logger = logging.getLogger(__name__) + + +class Program(object): + """Wrap OpenGL shader program. + + The program is compiled lazily (i.e., at first program :meth:`use`). + When the program is compiled, it stores attributes and uniforms locations. + So, attributes and uniforms must be used after :meth:`use`. + + This object supports multiple OpenGL contexts. + + :param str vertexShader: The source of the vertex shader. + :param str fragmentShader: The source of the fragment shader. + :param str attrib0: + Attribute's name to bind to position 0 (default: 'position'). + On certain platform, this attribute MUST be active and with an + array attached to it in order for the rendering to occur.... + """ + + def __init__(self, vertexShader, fragmentShader, + attrib0='position'): + self._vertexShader = vertexShader + self._fragmentShader = fragmentShader + self._attrib0 = attrib0 + self._programs = {} + + @staticmethod + def _compileGL(vertexShader, fragmentShader, attrib0): + program = gl.glCreateProgram() + + gl.glBindAttribLocation(program, 0, attrib0.encode('ascii')) + + vertex = gl.glCreateShader(gl.GL_VERTEX_SHADER) + gl.glShaderSource(vertex, vertexShader) + gl.glCompileShader(vertex) + if gl.glGetShaderiv(vertex, gl.GL_COMPILE_STATUS) != gl.GL_TRUE: + raise RuntimeError(gl.glGetShaderInfoLog(vertex)) + gl.glAttachShader(program, vertex) + gl.glDeleteShader(vertex) + + fragment = gl.glCreateShader(gl.GL_FRAGMENT_SHADER) + gl.glShaderSource(fragment, fragmentShader) + gl.glCompileShader(fragment) + if gl.glGetShaderiv(fragment, + gl.GL_COMPILE_STATUS) != gl.GL_TRUE: + raise RuntimeError(gl.glGetShaderInfoLog(fragment)) + gl.glAttachShader(program, fragment) + gl.glDeleteShader(fragment) + + gl.glLinkProgram(program) + if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE: + raise RuntimeError(gl.glGetProgramInfoLog(program)) + + attributes = {} + for index in range(gl.glGetProgramiv(program, + gl.GL_ACTIVE_ATTRIBUTES)): + name = gl.glGetActiveAttrib(program, index)[0] + namestr = name.decode('ascii') + attributes[namestr] = gl.glGetAttribLocation(program, name) + + uniforms = {} + for index in range(gl.glGetProgramiv(program, gl.GL_ACTIVE_UNIFORMS)): + name = gl.glGetActiveUniform(program, index)[0] + namestr = name.decode('ascii') + uniforms[namestr] = gl.glGetUniformLocation(program, name) + + return program, attributes, uniforms + + def _getProgramInfo(self): + glcontext = getGLContext() + if glcontext not in self._programs: + raise RuntimeError( + "Program was not compiled for current OpenGL context.") + return self._programs[glcontext] + + @property + def attributes(self): + """Vertex attributes names and locations as a dict of {str: int}. + + WARNING: + Read-only usage. + To use only with a valid OpenGL context and after :meth:`use` + has been called for this context. + """ + return self._getProgramInfo()[1] + + @property + def uniforms(self): + """Program uniforms names and locations as a dict of {str: int}. + + WARNING: + Read-only usage. + To use only with a valid OpenGL context and after :meth:`use` + has been called for this context. + """ + return self._getProgramInfo()[2] + + @property + def program(self): + """OpenGL id of the program. + + WARNING: + To use only with a valid OpenGL context and after :meth:`use` + has been called for this context. + """ + return self._getProgramInfo()[0] + + # def discard(self): + # pass # Not implemented yet + + def use(self): + """Make use of the program, compiling it if necessary""" + glcontext = getGLContext() + + if glcontext not in self._programs: + self._programs[glcontext] = self._compileGL( + self._vertexShader, + self._fragmentShader, + self._attrib0) + + if _logger.getEffectiveLevel() <= logging.DEBUG: + gl.glValidateProgram(self.program) + if gl.glGetProgramiv( + self.program, gl.GL_VALIDATE_STATUS) != gl.GL_TRUE: + _logger.debug('Cannot validate program: %s', + gl.glGetProgramInfoLog(self.program)) + + gl.glUseProgram(self.program) + + def setUniformMatrix(self, name, value, transpose=True, safe=False): + """Wrap glUniformMatrix[2|3|4]fv + + :param str name: The name of the uniform. + :param value: The 2D matrix (or the array of matrices, 3D). + Matrices are 2x2, 3x3 or 4x4. + :type value: numpy.ndarray with 2 or 3 dimensions of float32 + :param bool transpose: Whether to transpose (True, default) or not. + :param bool safe: False: raise an error if no uniform with this name; + True: silently ignores it. + + :raises KeyError: if no uniform corresponds to name. + """ + assert value.dtype == numpy.float32 + + shape = value.shape + assert len(shape) in (2, 3) + assert shape[-1] in (2, 3, 4) + assert shape[-1] == shape[-2] # As in OpenGL|ES 2.0 + + location = self.uniforms.get(name) + if location is not None: + count = 1 if len(shape) == 2 else shape[0] + transpose = gl.GL_TRUE if transpose else gl.GL_FALSE + + if shape[-1] == 2: + gl.glUniformMatrix2fv(location, count, transpose, value) + elif shape[-1] == 3: + gl.glUniformMatrix3fv(location, count, transpose, value) + elif shape[-1] == 4: + gl.glUniformMatrix4fv(location, count, transpose, value) + + elif not safe: + raise KeyError('No uniform: %s' % name) diff --git a/silx/gui/_glutils/Texture.py b/silx/gui/_glutils/Texture.py new file mode 100644 index 0000000..9f09a86 --- /dev/null +++ b/silx/gui/_glutils/Texture.py @@ -0,0 +1,308 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides a class wrapping OpenGL 2D and 3D texture.""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "04/10/2016" + + +import collections +from ctypes import c_void_p +import logging + +import numpy + +from . import gl, utils + + +_logger = logging.getLogger(__name__) + + +class Texture(object): + """Base class to wrap OpenGL 2D and 3D texture + + :param internalFormat: OpenGL texture internal format + :param data: The data to copy to the texture or None for an empty texture + :type data: numpy.ndarray or None + :param format_: Input data format if different from internalFormat + :param shape: If data is None, shape of the texture + :type shape: 2 or 3-tuple of int (height, width) or (depth, height, width) + :param int texUnit: The texture unit to use + :param minFilter: OpenGL texture minimization filter (default: GL_NEAREST) + :param magFilter: OpenGL texture magnification filter (default: GL_LINEAR) + :param wrap: Texture wrap mode for dimensions: (t, s) or (r, t, s) + If a single value is provided, it used for all dimensions. + :type wrap: OpenGL wrap mode or 2 or 3-tuple of wrap mode + """ + + def __init__(self, internalFormat, data=None, format_=None, + shape=None, texUnit=0, + minFilter=None, magFilter=None, wrap=None): + + self._internalFormat = internalFormat + if format_ is None: + format_ = self.internalFormat + + if data is None: + assert shape is not None + else: + assert shape is None + data = numpy.array(data, copy=False, order='C') + if format_ != gl.GL_RED: + shape = data.shape[:-1] # Last dimension is channels + else: + shape = data.shape + + assert len(shape) in (2, 3) + self._shape = tuple(shape) + self._ndim = len(shape) + + self.texUnit = texUnit + + self._name = gl.glGenTextures(1) + self.bind(self.texUnit) + + self._minFilter = None + self.minFilter = minFilter if minFilter is not None else gl.GL_NEAREST + + self._magFilter = None + self.magFilter = magFilter if magFilter is not None else gl.GL_LINEAR + + if wrap is not None: + if not isinstance(wrap, collections.Iterable): + wrap = [wrap] * self.ndim + + assert len(wrap) == self.ndim + + gl.glTexParameter(self.target, + gl.GL_TEXTURE_WRAP_S, + wrap[-1]) + gl.glTexParameter(self.target, + gl.GL_TEXTURE_WRAP_T, + wrap[-2]) + if self.ndim == 3: + gl.glTexParameter(self.target, + gl.GL_TEXTURE_WRAP_R, + wrap[0]) + + gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1) + + # This are the defaults, useless to set if not modified + # gl.glPixelStorei(gl.GL_UNPACK_ROW_LENGTH, 0) + # gl.glPixelStorei(gl.GL_UNPACK_SKIP_PIXELS, 0) + # gl.glPixelStorei(gl.GL_UNPACK_SKIP_ROWS, 0) + # gl.glPixelStorei(gl.GL_UNPACK_IMAGE_HEIGHT, 0) + # gl.glPixelStorei(gl.GL_UNPACK_SKIP_IMAGES, 0) + + if data is None: + data = c_void_p(0) + type_ = gl.GL_UNSIGNED_BYTE + else: + type_ = utils.numpyToGLType(data.dtype) + + if self.ndim == 2: + _logger.debug( + 'Creating 2D texture shape: (%d, %d),' + ' internal format: %s, format: %s, type: %s', + self.shape[0], self.shape[1], + str(self.internalFormat), str(format_), str(type_)) + + gl.glTexImage2D( + gl.GL_TEXTURE_2D, + 0, + self.internalFormat, + self.shape[1], + self.shape[0], + 0, + format_, + type_, + data) + else: + _logger.debug( + 'Creating 3D texture shape: (%d, %d, %d),' + ' internal format: %s, format: %s, type: %s', + self.shape[0], self.shape[1], self.shape[2], + str(self.internalFormat), str(format_), str(type_)) + + gl.glTexImage3D( + gl.GL_TEXTURE_3D, + 0, + self.internalFormat, + self.shape[2], + self.shape[1], + self.shape[0], + 0, + format_, + type_, + data) + + gl.glBindTexture(self.target, 0) + + @property + def target(self): + """OpenGL target type of this texture""" + return gl.GL_TEXTURE_2D if self.ndim == 2 else gl.GL_TEXTURE_3D + + @property + def ndim(self): + """The number of dimensions: 2 or 3""" + return self._ndim + + @property + def internalFormat(self): + """Texture internal format""" + return self._internalFormat + + @property + def shape(self): + """Shape of the texture: (height, width) or (depth, height, width)""" + return self._shape + + @property + def name(self): + """OpenGL texture name""" + if self._name is not None: + return self._name + else: + raise RuntimeError( + "No OpenGL texture resource, discard has already been called") + + @property + def minFilter(self): + """Minifying function parameter (GL_TEXTURE_MIN_FILTER)""" + return self._minFilter + + @minFilter.setter + def minFilter(self, minFilter): + if minFilter != self.minFilter: + self._minFilter = minFilter + self.bind() + gl.glTexParameter(self.target, + gl.GL_TEXTURE_MIN_FILTER, + self.minFilter) + + @property + def magFilter(self): + """Magnification function parameter (GL_TEXTURE_MAG_FILTER)""" + return self._magFilter + + @magFilter.setter + def magFilter(self, magFilter): + if magFilter != self.magFilter: + self._magFilter = magFilter + self.bind() + gl.glTexParameter(self.target, + gl.GL_TEXTURE_MAG_FILTER, + self.magFilter) + + def discard(self): + """Delete associated OpenGL texture""" + if self._name is not None: + gl.glDeleteTextures(self._name) + self._name = None + else: + _logger.warning("Discard as already been called") + + def bind(self, texUnit=None): + """Bind the texture to a texture unit. + + :param int texUnit: The texture unit to use + """ + if texUnit is None: + texUnit = self.texUnit + gl.glActiveTexture(gl.GL_TEXTURE0 + texUnit) + gl.glBindTexture(self.target, self.name) + + # with statement + + def __enter__(self): + self.bind() + + def __exit__(self, exc_type, exc_val, exc_tb): + gl.glActiveTexture(gl.GL_TEXTURE0 + self.texUnit) + gl.glBindTexture(self.target, 0) + + def update(self, + format_, + data, + offset=(0, 0, 0), + texUnit=None): + """Update the content of the texture. + + Texture is not resized, so data must fit into texture with the + given offset. + + :param format_: The OpenGL format of the data + :param data: The data to use to update the texture + :param offset: The offset in the texture where to copy the data + :type offset: 2 or 3-tuple of int + :param int texUnit: + The texture unit to use (default: the one provided at init) + """ + data = numpy.array(data, copy=False, order='C') + + assert data.ndim == self.ndim + assert len(offset) >= self.ndim + for i in range(self.ndim): + assert offset[i] + data.shape[i] <= self.shape[i] + + gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1) + + # This are the defaults, useless to set if not modified + # gl.glPixelStorei(gl.GL_UNPACK_ROW_LENGTH, 0) + # gl.glPixelStorei(gl.GL_UNPACK_SKIP_PIXELS, 0) + # gl.glPixelStorei(gl.GL_UNPACK_SKIP_ROWS, 0) + # gl.glPixelStorei(gl.GL_UNPACK_IMAGE_HEIGHT, 0) + # gl.glPixelStorei(gl.GL_UNPACK_SKIP_IMAGES, 0) + + self.bind(texUnit) + + type_ = utils.numpyToGLType(data.dtype) + + if self.ndim == 2: + gl.glTexSubImage2D(gl.GL_TEXTURE_2D, + 0, + offset[1], + offset[0], + data.shape[1], + data.shape[0], + format_, + type_, + data) + gl.glBindTexture(gl.GL_TEXTURE_2D, 0) + else: + gl.glTexSubImage3D(gl.GL_TEXTURE_3D, + 0, + offset[2], + offset[1], + offset[0], + data.shape[2], + data.shape[1], + data.shape[0], + format_, + type_, + data) + gl.glBindTexture(gl.GL_TEXTURE_3D, 0) diff --git a/silx/gui/_glutils/VertexBuffer.py b/silx/gui/_glutils/VertexBuffer.py new file mode 100644 index 0000000..689b543 --- /dev/null +++ b/silx/gui/_glutils/VertexBuffer.py @@ -0,0 +1,266 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides a class managing an OpenGL vertex buffer.""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "10/01/2017" + + +import logging +from ctypes import c_void_p +import numpy + +from . import gl +from .utils import numpyToGLType, sizeofGLType + + +_logger = logging.getLogger(__name__) + + +class VertexBuffer(object): + """Object handling an OpenGL vertex buffer object + + :param data: Data used to fill the vertex buffer + :type data: numpy.ndarray or None + :param int size: Size in bytes of the buffer or None for data size + :param usage: OpenGL vertex buffer expected usage pattern: + GL_STREAM_DRAW, GL_STATIC_DRAW (default) or GL_DYNAMIC_DRAW + :param target: Target buffer: + GL_ARRAY_BUFFER (default) or GL_ELEMENT_ARRAY_BUFFER + """ + # OpenGL|ES 2.0 subset: + _USAGES = gl.GL_STREAM_DRAW, gl.GL_STATIC_DRAW, gl.GL_DYNAMIC_DRAW + _TARGETS = gl.GL_ARRAY_BUFFER, gl.GL_ELEMENT_ARRAY_BUFFER + + def __init__(self, + data=None, + size=None, + usage=None, + target=None): + if usage is None: + usage = gl.GL_STATIC_DRAW + assert usage in self._USAGES + + if target is None: + target = gl.GL_ARRAY_BUFFER + assert target in self._TARGETS + + self._target = target + self._usage = usage + + self._name = gl.glGenBuffers(1) + self.bind() + + if data is None: + assert size is not None + self._size = size + gl.glBufferData(self._target, + self._size, + c_void_p(0), + self._usage) + else: + data = numpy.array(data, copy=False, order='C') + if size is not None: + assert size <= data.nbytes + + self._size = size or data.nbytes + gl.glBufferData(self._target, + self._size, + data, + self._usage) + + gl.glBindBuffer(self._target, 0) + + @property + def target(self): + """The target buffer of the vertex buffer""" + return self._target + + @property + def usage(self): + """The expected usage of the vertex buffer""" + return self._usage + + @property + def name(self): + """OpenGL Vertex Buffer object name (int)""" + if self._name is not None: + return self._name + else: + raise RuntimeError("No OpenGL buffer resource, \ + discard has already been called") + + @property + def size(self): + """Size in bytes of the Vertex Buffer Object (int)""" + if self._size is not None: + return self._size + else: + raise RuntimeError("No OpenGL buffer resource, \ + discard has already been called") + + def bind(self): + """Bind the vertex buffer""" + gl.glBindBuffer(self._target, self.name) + + def update(self, data, offset=0, size=None): + """Update vertex buffer content. + + :param numpy.ndarray data: The data to put in the vertex buffer + :param int offset: Offset in bytes in the buffer where to put the data + :param int size: If provided, size of data to copy + """ + data = numpy.array(data, copy=False, order='C') + if size is None: + size = data.nbytes + assert offset + size <= self.size + with self: + gl.glBufferSubData(self._target, offset, size, data) + + def discard(self): + """Delete the vertex buffer""" + if self._name is not None: + gl.glDeleteBuffers(self._name) + self._name = None + self._size = None + else: + _logger.warning("Discard has already been called") + + # with statement + + def __enter__(self): + self.bind() + + def __exit__(self, exctype, excvalue, traceback): + gl.glBindBuffer(self._target, 0) + + +class VertexBufferAttrib(object): + """Describes data stored in a vertex buffer + + Convenient class to store info for glVertexAttribPointer calls + + :param VertexBuffer vbo: The vertex buffer storing the data + :param int type_: The OpenGL type of the data + :param int size: The number of data elements stored in the VBO + :param int dimension: The number of `type_` element(s) in [1, 4] + :param int offset: Start offset of data in the vertex buffer + :param int stride: Data stride in the vertex buffer + """ + + _GL_TYPES = gl.GL_UNSIGNED_BYTE, gl.GL_FLOAT, gl.GL_INT + + def __init__(self, + vbo, + type_, + size, + dimension=1, + offset=0, + stride=0, + normalisation=False): + self.vbo = vbo + assert type_ in self._GL_TYPES + self.type_ = type_ + self.size = size + assert 1 <= dimension <= 4 + self.dimension = dimension + self.offset = offset + self.stride = stride + self.normalisation = bool(normalisation) + + @property + def itemsize(self): + """Size in bytes of a vertex buffer element (int)""" + return self.dimension * sizeofGLType(self.type_) + + itemSize = itemsize # Backward compatibility + + def setVertexAttrib(self, attribute): + """Call glVertexAttribPointer with objects information""" + normalisation = gl.GL_TRUE if self.normalisation else gl.GL_FALSE + with self.vbo: + gl.glVertexAttribPointer(attribute, + self.dimension, + self.type_, + normalisation, + self.stride, + c_void_p(self.offset)) + + def copy(self): + return VertexBufferAttrib(self.vbo, + self.type_, + self.size, + self.dimension, + self.offset, + self.stride, + self.normalisation) + + +def vertexBuffer(arrays, prefix=None, suffix=None, usage=None): + """Create a single vertex buffer from multiple 1D or 2D numpy arrays. + + It is possible to reserve memory before and after each array in the VBO + + :param arrays: Arrays of data to store + :type arrays: Iterable of numpy.ndarray + :param prefix: If given, number of elements to reserve before each array + :type prefix: Iterable of int or None + :param suffix: If given, number of elements to reserve after each array + :type suffix: Iterable of int or None + :param int usage: vertex buffer expected usage or None for default + :returns: List of VertexBufferAttrib objects sharing the same vertex buffer + """ + info = [] + vbosize = 0 + + if prefix is None: + prefix = (0,) * len(arrays) + if suffix is None: + suffix = (0,) * len(arrays) + + for data, pre, post in zip(arrays, prefix, suffix): + data = numpy.array(data, copy=False, order='C') + shape = data.shape + assert len(shape) <= 2 + type_ = numpyToGLType(data.dtype) + size = shape[0] + pre + post + dimension = 1 if len(shape) == 1 else shape[1] + sizeinbytes = size * dimension * sizeofGLType(type_) + sizeinbytes = 4 * ((sizeinbytes + 3) >> 2) # 4 bytes alignment + copyoffset = vbosize + pre * dimension * sizeofGLType(type_) + info.append((data, type_, size, dimension, + vbosize, sizeinbytes, copyoffset)) + vbosize += sizeinbytes + + vbo = VertexBuffer(size=vbosize, usage=usage) + + result = [] + for data, type_, size, dimension, offset, sizeinbytes, copyoffset in info: + copysize = data.shape[0] * dimension * sizeofGLType(type_) + vbo.update(data, offset=copyoffset, size=copysize) + result.append( + VertexBufferAttrib(vbo, type_, size, dimension, offset, 0)) + return result diff --git a/silx/gui/_glutils/__init__.py b/silx/gui/_glutils/__init__.py new file mode 100644 index 0000000..e86a58f --- /dev/null +++ b/silx/gui/_glutils/__init__.py @@ -0,0 +1,41 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This package provides utility functions to handle OpenGL resources. + +The :mod:`gl` module provides a wrapper to OpenGL based on PyOpenGL. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "25/07/2016" + + +# OpenGL convenient functions +from .Context import getGLContext, setGLContextGetter # noqa +from .FramebufferTexture import FramebufferTexture # noqa +from .Program import Program # noqa +from .Texture import Texture # noqa +from .VertexBuffer import VertexBuffer, VertexBufferAttrib, vertexBuffer # noqa +from .utils import sizeofGLType, isSupportedGLType, numpyToGLType # noqa diff --git a/silx/gui/_glutils/font.py b/silx/gui/_glutils/font.py new file mode 100644 index 0000000..566ae49 --- /dev/null +++ b/silx/gui/_glutils/font.py @@ -0,0 +1,152 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Text rasterisation feature leveraging Qt font and text layout support.""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "13/10/2016" + + +import logging +import sys +import numpy +from .. import qt +from .._utils import convertQImageToArray + + +_logger = logging.getLogger(__name__) + + +def getDefaultFontFamily(): + """Returns the default font family of the application""" + return qt.QApplication.instance().font().family() + + +# Font weights +ULTRA_LIGHT = 0 +"""Lightest characters: Minimum font weight""" + +LIGHT = 25 +"""Light characters""" + +NORMAL = 50 +"""Normal characters""" + +SEMI_BOLD = 63 +"""Between normal and bold characters""" + +BOLD = 74 +"""Thicker characters""" + +BLACK = 87 +"""Really thick characters""" + +ULTRA_BLACK = 99 +"""Thickest characters: Maximum font weight""" + + +def rasterText(text, font, + size=-1, + weight=-1, + italic=False, + devicePixelRatio=1.0): + """Raster text using Qt. + + It supports multiple lines. + + :param str text: The text to raster + :param font: Font name or QFont to use + :type font: str or :class:`QFont` + :param int size: + Font size in points + Used only if font is given as name. + :param int weight: + Font weight in [0, 99], see QFont.Weight. + Used only if font is given as name. + :param bool italic: + True for italic font (default: False). + Used only if font is given as name. + :param float devicePixelRatio: + The current ratio between device and device-independent pixel + (default: 1.0) + :return: Corresponding image in gray scale and baseline offset from top + :rtype: (HxW numpy.ndarray of uint8, int) + """ + if not text: + _logger.info("Trying to raster empty text, replaced by white space") + text = ' ' # Replace empty text by white space to produce an image + + if not isinstance(font, qt.QFont): + font = qt.QFont(font, size, weight, italic) + + metrics = qt.QFontMetrics(font) + size = metrics.size(qt.Qt.TextExpandTabs, text) + bounds = metrics.boundingRect( + qt.QRect(0, 0, size.width(), size.height()), + qt.Qt.TextExpandTabs, + text) + + if (devicePixelRatio != 1.0 and + not hasattr(qt.QImage, 'setDevicePixelRatio')): # Qt 4 + _logger.error('devicePixelRatio not supported') + devicePixelRatio = 1.0 + + # Add extra border and handle devicePixelRatio + width = bounds.width() * devicePixelRatio + 2 + # align line size to 32 bits to ease conversion to numpy array + width = 4 * ((width + 3) // 4) + image = qt.QImage(width, + bounds.height() * devicePixelRatio, + qt.QImage.Format_RGB888) + if (devicePixelRatio != 1.0 and + hasattr(image, 'setDevicePixelRatio')): # Qt 5 + image.setDevicePixelRatio(devicePixelRatio) + + # TODO if Qt5 use Format_Grayscale8 instead + image.fill(0) + + # Raster text + painter = qt.QPainter() + painter.begin(image) + painter.setPen(qt.Qt.white) + painter.setFont(font) + painter.drawText(bounds, qt.Qt.TextExpandTabs, text) + painter.end() + + array = convertQImageToArray(image) + + # RGB to R + array = numpy.ascontiguousarray(array[:, :, 0]) + + # Remove leading and trailing empty columns but one on each side + column_cumsum = numpy.cumsum(numpy.sum(array, axis=0)) + array = array[:, column_cumsum.argmin():column_cumsum.argmax() + 2] + + # Remove leading and trailing empty rows but one on each side + row_cumsum = numpy.cumsum(numpy.sum(array, axis=1)) + min_row = row_cumsum.argmin() + array = array[min_row:row_cumsum.argmax() + 2, :] + + return array, metrics.ascent() - min_row diff --git a/silx/gui/_glutils/gl.py b/silx/gui/_glutils/gl.py new file mode 100644 index 0000000..4b9a7bb --- /dev/null +++ b/silx/gui/_glutils/gl.py @@ -0,0 +1,165 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module loads PyOpenGL and provides a namespace for OpenGL.""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "25/07/2016" + + +from contextlib import contextmanager as _contextmanager +from ctypes import c_uint +import logging + +_logger = logging.getLogger(__name__) + +import OpenGL +# Set the following to true for debugging +if _logger.getEffectiveLevel() <= logging.DEBUG: + _logger.debug('Enabling PyOpenGL debug flags') + OpenGL.ERROR_LOGGING = True + OpenGL.ERROR_CHECKING = True + OpenGL.ERROR_ON_COPY = True +else: + OpenGL.ERROR_LOGGING = False + OpenGL.ERROR_CHECKING = False + OpenGL.ERROR_ON_COPY = False + +import OpenGL.GL as _GL +from OpenGL.GL import * # noqa + +# Extentions core in OpenGL 3 +from OpenGL.GL.ARB import framebuffer_object as _FBO +from OpenGL.GL.ARB.framebuffer_object import * # noqa +from OpenGL.GL.ARB.texture_rg import GL_R32F, GL_R16F # noqa +from OpenGL.GL.ARB.texture_rg import GL_R16, GL_R8 # noqa + +# PyOpenGL 3.0.1 does not define it +try: + GLchar +except NameError: + from ctypes import c_char + GLchar = c_char + + +def testGL(): + """Test if required OpenGL version and extensions are available. + + This MUST be run with an active OpenGL context. + """ + version = glGetString(GL_VERSION).split()[0] # get version number + major, minor = int(version[0]), int(version[2]) + if major < 2 or (major == 2 and minor < 1): + raise RuntimeError( + "Requires at least OpenGL version 2.1, running with %s" % version) + + from OpenGL.GL.ARB.framebuffer_object import glInitFramebufferObjectARB + from OpenGL.GL.ARB.texture_rg import glInitTextureRgARB + + if not glInitFramebufferObjectARB(): + raise RuntimeError( + "OpenGL GL_ARB_framebuffer_object extension required !") + + if not glInitTextureRgARB(): + raise RuntimeError("OpenGL GL_ARB_texture_rg extension required !") + + +# Additional setup +if hasattr(glget, 'addGLGetConstant'): + glget.addGLGetConstant(GL_FRAMEBUFFER_BINDING, (1,)) + + +@_contextmanager +def enabled(capacity, enable=True): + """Context manager enabling an OpenGL capacity. + + This is not checking the current state of the capacity. + + :param capacity: The OpenGL capacity enum to enable/disable + :param bool enable: + True (default) to enable during context, False to disable + """ + if enable: + glEnable(capacity) + yield + glDisable(capacity) + else: + glDisable(capacity) + yield + glEnable(capacity) + + +def disabled(capacity, disable=True): + """Context manager disabling an OpenGL capacity. + + This is not checking the current state of the capacity. + + :param capacity: The OpenGL capacity enum to disable/enable + :param bool disable: + True (default) to disable during context, False to enable + """ + return enabled(capacity, not disable) + + +# Additional OpenGL wrapping + +def glGetActiveAttrib(program, index): + """Wrap PyOpenGL glGetActiveAttrib""" + bufsize = glGetProgramiv(program, GL_ACTIVE_ATTRIBUTE_MAX_LENGTH) + length = GLsizei() + size = GLint() + type_ = GLenum() + name = (GLchar * bufsize)() + + _GL.glGetActiveAttrib(program, index, bufsize, length, size, type_, name) + return name.value, size.value, type_.value + + +def glDeleteRenderbuffers(buffers): + if not hasattr(buffers, '__len__'): # Support single int argument + buffers = [buffers] + length = len(buffers) + _FBO.glDeleteRenderbuffers(length, (c_uint * length)(*buffers)) + + +def glDeleteFramebuffers(buffers): + if not hasattr(buffers, '__len__'): # Support single int argument + buffers = [buffers] + length = len(buffers) + _FBO.glDeleteFramebuffers(length, (c_uint * length)(*buffers)) + + +def glDeleteBuffers(buffers): + if not hasattr(buffers, '__len__'): # Support single int argument + buffers = [buffers] + length = len(buffers) + _GL.glDeleteBuffers(length, (c_uint * length)(*buffers)) + + +def glDeleteTextures(textures): + if not hasattr(textures, '__len__'): # Support single int argument + textures = [textures] + length = len(textures) + _GL.glDeleteTextures((c_uint * length)(*textures)) diff --git a/silx/gui/_glutils/utils.py b/silx/gui/_glutils/utils.py new file mode 100644 index 0000000..73af338 --- /dev/null +++ b/silx/gui/_glutils/utils.py @@ -0,0 +1,70 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides conversion functions between OpenGL and numpy types. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "10/01/2017" + +from . import gl +import numpy + + +_GL_TYPE_SIZES = { + gl.GL_FLOAT: 4, + gl.GL_BYTE: 1, + gl.GL_SHORT: 2, + gl.GL_INT: 4, + gl.GL_UNSIGNED_BYTE: 1, + gl.GL_UNSIGNED_SHORT: 2, + gl.GL_UNSIGNED_INT: 4, +} + + +def sizeofGLType(type_): + """Returns the size in bytes of an element of type `type_`""" + return _GL_TYPE_SIZES[type_] + + +_TYPE_CONVERTER = { + numpy.dtype(numpy.float32): gl.GL_FLOAT, + numpy.dtype(numpy.int8): gl.GL_BYTE, + numpy.dtype(numpy.int16): gl.GL_SHORT, + numpy.dtype(numpy.int32): gl.GL_INT, + numpy.dtype(numpy.uint8): gl.GL_UNSIGNED_BYTE, + numpy.dtype(numpy.uint16): gl.GL_UNSIGNED_SHORT, + numpy.dtype(numpy.uint32): gl.GL_UNSIGNED_INT, +} + + +def isSupportedGLType(type_): + """Test if a numpy type or dtype can be converted to a GL type.""" + return numpy.dtype(type_) in _TYPE_CONVERTER + + +def numpyToGLType(type_): + """Returns the GL type corresponding the provided numpy type or dtype.""" + return _TYPE_CONVERTER[numpy.dtype(type_)] diff --git a/silx/gui/_utils.py b/silx/gui/_utils.py new file mode 100644 index 0000000..e29141f --- /dev/null +++ b/silx/gui/_utils.py @@ -0,0 +1,102 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides convenient functions to use with Qt objects. + +It provides conversion between numpy and QImage. +""" + +from __future__ import division + + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "16/01/2017" + + +import sys +import numpy + +from . import qt + + +def convertArrayToQImage(image): + """Convert an array-like RGB888 image to a QImage. + + The created QImage is using a copy of the array data. + + Limitation: Only supports RGB888 format. + + :param image: Array-like image data + :type image: numpy.ndarray of uint8 of dimension HxWx3 + :return: Corresponding Qt image + :rtype: QImage + """ + # Possible extension: add a format argument to support more formats + + image = numpy.array(image, copy=False, order='C', dtype=numpy.uint8) + + height, width, depth = image.shape + assert depth == 3 + + qimage = qt.QImage( + image.data, + width, + height, + image.strides[0], # bytesPerLine + qt.QImage.Format_RGB888) + + return qimage.copy() # Making a copy of the image and its data + + +def convertQImageToArray(image): + """Convert a RGB888 QImage to a numpy array. + + Limitation: Only supports RGB888 format. + If QImage is not RGB888 it gets converted to this format. + + :param QImage: The QImage to convert. + :return: The image array + :rtype: numpy.ndarray of uint8 of shape HxWx3 + """ + # Possible extension: avoid conversion to support more formats + + if image.format() != qt.QImage.Format_RGB888: + # Convert to RGB888 if needed + image = image.convertToFormat(qt.QImage.Format_RGB888) + + ptr = image.bits() + if qt.BINDING != 'PySide': + ptr.setsize(image.byteCount()) + if qt.BINDING == 'PyQt4' and sys.version_info[0] == 2: + ptr = ptr.asstring() + elif sys.version_info[0] == 3: # PySide with Python3 + ptr = ptr.tobytes() + + array = numpy.fromstring(ptr, dtype=numpy.uint8) + + # Lines are 32 bits aligned: remove padding bytes + array = array.reshape(image.height(), -1)[:, :image.width() * 3] + array.shape = image.height(), image.width(), 3 + return array diff --git a/silx/gui/console.py b/silx/gui/console.py new file mode 100644 index 0000000..13760b4 --- /dev/null +++ b/silx/gui/console.py @@ -0,0 +1,214 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides an IPython console widget. + +You can push variables - any python object - to the +console's interactive namespace. This provides users with an advanced way +of interacting with your program. For instance, if your program has a +:class:`PlotWidget` or a :class:`PlotWindow`, you can push a reference to +these widgets to allow your users to add curves, save data to files… by using +the widgets' methods from the console. + +.. note:: + + This module has a dependency on + `IPython <https://pypi.python.org/pypi/ipython>`_ and + `qtconsole <https://pypi.python.org/pypi/qtconsole>`_ (or *ipython.qt* for + older versions of *IPython*). An ``ImportError`` will be raised if it is + imported while the dependencies are not satisfied. + +Basic usage example:: + + from silx.gui import qt + from silx.gui.console import IPythonWidget + + app = qt.QApplication([]) + + hello_button = qt.QPushButton("Hello World!", None) + hello_button.show() + + console = IPythonWidget() + console.show() + console.pushVariables({"the_button": hello_button}) + + app.exec_() + +This program will display a console widget and a push button in two separate +windows. You will be able to interact with the button from the console, +for example change its text:: + + >>> the_button.setText("Spam spam") + +An IPython interactive console is a powerful tool that enables you to work +with data and plot it. +See `this tutorial <https://plot.ly/python/ipython-notebook-tutorial/>`_ +for more information on some of the rich features of IPython. +""" +__authors__ = ["Tim Rae", "V.A. Sole", "P. Knobel"] +__license__ = "MIT" +__date__ = "24/05/2016" + +import logging + +from . import qt + +_logger = logging.getLogger(__name__) + +try: + import IPython +except ImportError as e: + raise ImportError("Failed to import IPython, required by " + __name__) + +# This widget cannot be used inside an interactive IPython shell. +# It would raise MultipleInstanceError("Multiple incompatible subclass +# instances of InProcessInteractiveShell are being created"). +try: + __IPYTHON__ +except NameError: + pass # Not in IPython +else: + msg = "Module " + __name__ + " cannot be used within an IPython shell" + raise ImportError(msg) + +# qtconsole is a separate module in recent versions of IPython/Jupyter +# http://blog.jupyter.org/2015/04/15/the-big-split/ +if IPython.__version__.startswith("2"): + qtconsole = None +else: + try: + import qtconsole + except ImportError: + qtconsole = None + +if qtconsole is not None: + try: + from qtconsole.rich_ipython_widget import RichJupyterWidget as \ + RichIPythonWidget + except ImportError: + try: + from qtconsole.rich_ipython_widget import RichIPythonWidget + except ImportError as e: + qtconsole = None + else: + from qtconsole.inprocess import QtInProcessKernelManager + else: + from qtconsole.inprocess import QtInProcessKernelManager + + +if qtconsole is None: + # Import the console machinery from ipython + + # The `has_binding` test of IPython does not find the Qt bindings + # in case silx is used in a frozen binary + import IPython.external.qt_loaders + + def has_binding(*var, **kw): + return True + + IPython.external.qt_loaders.has_binding = has_binding + + from IPython.qt.console.rich_ipython_widget import RichIPythonWidget + from IPython.qt.inprocess import QtInProcessKernelManager + + +class IPythonWidget(RichIPythonWidget): + """Live IPython console widget. + + :param custom_banner: Custom welcome message to be printed at the top of + the console. + """ + + def __init__(self, parent=None, custom_banner=None, *args, **kwargs): + if parent is not None: + kwargs["parent"] = parent + super(IPythonWidget, self).__init__(*args, **kwargs) + if custom_banner is not None: + self.banner = custom_banner + self.setWindowTitle(self.banner) + self.kernel_manager = kernel_manager = QtInProcessKernelManager() + kernel_manager.start_kernel() + self.kernel_client = kernel_client = self._kernel_manager.client() + kernel_client.start_channels() + + def stop(): + kernel_client.stop_channels() + kernel_manager.shutdown_kernel() + self.exit_requested.connect(stop) + + def sizeHint(self): + """Return a reasonable default size for usage in :class:`PlotWindow`""" + return qt.QSize(500, 300) + + def pushVariables(self, variable_dict): + """ Given a dictionary containing name / value pairs, push those + variables to the IPython console widget. + + :param variable_dict: Dictionary of variables to be pushed to the + console's interactive namespace (```{variable_name: object, …}```) + """ + self.kernel_manager.kernel.shell.push(variable_dict) + + +class IPythonDockWidget(qt.QDockWidget): + """Dock Widget including a :class:`IPythonWidget` inside + a vertical layout. + + :param available_vars: Dictionary of variables to be pushed to the + console's interactive namespace: ``{"variable_name": object, …}`` + :param custom_banner: Custom welcome message to be printed at the top of + the console + :param title: Dock widget title + :param parent: Parent :class:`qt.QMainWindow` containing this + :class:`qt.QDockWidget` + """ + def __init__(self, parent=None, available_vars=None, custom_banner=None, + title="Console"): + super(IPythonDockWidget, self).__init__(title, parent) + + self.ipyconsole = IPythonWidget(custom_banner=custom_banner) + + self.layout().setContentsMargins(0, 0, 0, 0) + self.setWidget(self.ipyconsole) + + if available_vars is not None: + self.ipyconsole.pushVariables(available_vars) + + def showEvent(self, event): + """Make sure this widget is raised when it is shown + (when it is first created as a tab in PlotWindow or when it is shown + again after hiding). + """ + self.raise_() + + +def main(): + """Run a Qt app with an IPython console""" + app = qt.QApplication([]) + widget = IPythonDockWidget() + widget.show() + app.exec_() + +if __name__ == '__main__': + main() diff --git a/silx/gui/data/ArrayTableModel.py b/silx/gui/data/ArrayTableModel.py new file mode 100644 index 0000000..87a2fc1 --- /dev/null +++ b/silx/gui/data/ArrayTableModel.py @@ -0,0 +1,610 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +""" +This module defines a data model for displaying and editing arrays of any +number of dimensions in a table view. +""" +from __future__ import division +import numpy +import logging +from silx.gui import qt +from silx.gui.data.TextFormatter import TextFormatter + +__authors__ = ["V.A. Sole"] +__license__ = "MIT" +__date__ = "24/01/2017" + + +_logger = logging.getLogger(__name__) + + +def _is_array(data): + """Return True if object implements all necessary attributes to be used + as a numpy array. + + :param object data: Array-like object (numpy array, h5py dataset...) + :return: boolean + """ + # add more required attribute if necessary + for attr in ("shape", "dtype"): + if not hasattr(data, attr): + return False + return True + + +class ArrayTableModel(qt.QAbstractTableModel): + """This data model provides access to 2D slices in a N-dimensional + array. + + A slice for a 3-D array is characterized by a perspective (the number of + the axis orthogonal to the slice) and an index at which the slice + intersects the orthogonal axis. + + In the n-D case, only slices parallel to the last two axes are handled. A + slice is therefore characterized by a list of indices locating the + slice on all the :math:`n - 2` orthogonal axes. + + :param parent: Parent QObject + :param data: Numpy array, or object implementing a similar interface + (e.g. h5py dataset) + :param str fmt: Format string for representing numerical values. + Default is ``"%g"``. + :param sequence[int] perspective: See documentation + of :meth:`setPerspective`. + """ + def __init__(self, parent=None, data=None, perspective=None): + qt.QAbstractTableModel.__init__(self, parent) + + self._array = None + """n-dimensional numpy array""" + + self._bgcolors = None + """(n+1)-dimensional numpy array containing RGB(A) color data + for the background color + """ + + self._fgcolors = None + """(n+1)-dimensional numpy array containing RGB(A) color data + for the foreground color + """ + + self._formatter = None + """Formatter for text representation of data""" + + formatter = TextFormatter(self) + formatter.setUseQuoteForText(False) + self.setFormatter(formatter) + + self._index = None + """This attribute stores the slice index, as a list of indices + where the frame intersects orthogonal axis.""" + + self._perspective = None + """Sequence of dimensions orthogonal to the frame to be viewed. + For an array with ``n`` dimensions, this is a sequence of ``n-2`` + integers. the first dimension is numbered ``0``. + By default, the data frames use the last two dimensions as their axes + and therefore the perspective is a sequence of the first ``n-2`` + dimensions. + For example, for a 5-D array, the default perspective is ``(0, 1, 2)`` + and the default frames axes are ``(3, 4)``.""" + + # set _data and _perspective + self.setArrayData(data, perspective=perspective) + + def _getRowDim(self): + """The row axis is the first axis parallel to the frames + (lowest dimension number) + + Return None for 0-D (scalar) or 1-D arrays + """ + n_dimensions = len(self._array.shape) + if n_dimensions < 2: + # scalar or 1D array: no row index + return None + # take all dimensions and remove the orthogonal ones + frame_axes = set(range(0, n_dimensions)) - set(self._perspective) + # sanity check + assert len(frame_axes) == 2 + return min(frame_axes) + + def _getColumnDim(self): + """The column axis is the second (highest dimension) axis parallel + to the frames + + Return None for 0-D (scalar) + """ + n_dimensions = len(self._array.shape) + if n_dimensions < 1: + # scalar: no column index + return None + frame_axes = set(range(0, n_dimensions)) - set(self._perspective) + # sanity check + assert (len(frame_axes) == 2) if n_dimensions > 1 else (len(frame_axes) == 1) + return max(frame_axes) + + def _getIndexTuple(self, table_row, table_col): + """Return the n-dimensional index of a value in the original array, + based on its row and column indices in the table view + + :param table_row: Row index (0-based) of a table cell + :param table_col: Column index (0-based) of a table cell + :return: Tuple of indices of the element in the numpy array + """ + row_dim = self._getRowDim() + col_dim = self._getColumnDim() + + # get indices on all orthogonal axes + selection = list(self._index) + # insert indices on parallel axes + if row_dim is not None: + selection.insert(row_dim, table_row) + if col_dim is not None: + selection.insert(col_dim, table_col) + return tuple(selection) + + # Methods to be implemented to subclass QAbstractTableModel + def rowCount(self, parent_idx=None): + """QAbstractTableModel method + Return number of rows to be displayed in table""" + row_dim = self._getRowDim() + if row_dim is None: + # 0-D and 1-D arrays + return 1 + return self._array.shape[row_dim] + + def columnCount(self, parent_idx=None): + """QAbstractTableModel method + Return number of columns to be displayed in table""" + col_dim = self._getColumnDim() + if col_dim is None: + # 0-D array + return 1 + return self._array.shape[col_dim] + + def data(self, index, role=qt.Qt.DisplayRole): + """QAbstractTableModel method to access data values + in the format ready to be displayed""" + if index.isValid(): + selection = self._getIndexTuple(index.row(), + index.column()) + if role == qt.Qt.DisplayRole: + return self._formatter.toString(self._array[selection]) + + if role == qt.Qt.BackgroundRole and self._bgcolors is not None: + r, g, b = self._bgcolors[selection][0:3] + if self._bgcolors.shape[-1] == 3: + return qt.QColor(r, g, b) + if self._bgcolors.shape[-1] == 4: + a = self._bgcolors[selection][3] + return qt.QColor(r, g, b, a) + + if role == qt.Qt.ForegroundRole: + if self._fgcolors is not None: + r, g, b = self._fgcolors[selection][0:3] + if self._fgcolors.shape[-1] == 3: + return qt.QColor(r, g, b) + if self._fgcolors.shape[-1] == 4: + a = self._fgcolors[selection][3] + return qt.QColor(r, g, b, a) + + # no fg color given, use black or white + # based on luminosity threshold + elif self._bgcolors is not None: + r, g, b = self._bgcolors[selection][0:3] + lum = 0.21 * r + 0.72 * g + 0.07 * b + if lum < 128: + return qt.QColor(qt.Qt.white) + else: + return qt.QColor(qt.Qt.black) + + def headerData(self, section, orientation, role=qt.Qt.DisplayRole): + """QAbstractTableModel method + Return the 0-based row or column index, for display in the + horizontal and vertical headers""" + if role == qt.Qt.DisplayRole: + if orientation == qt.Qt.Vertical: + return "%d" % section + if orientation == qt.Qt.Horizontal: + return "%d" % section + return None + + def flags(self, index): + """QAbstractTableModel method to inform the view whether data + is editable or not.""" + if not self._editable: + return qt.QAbstractTableModel.flags(self, index) + return qt.QAbstractTableModel.flags(self, index) | qt.Qt.ItemIsEditable + + def setData(self, index, value, role=None): + """QAbstractTableModel method to handle editing data. + Cast the new value into the same format as the array before editing + the array value.""" + if index.isValid() and role == qt.Qt.EditRole: + try: + # cast value to same type as array + v = numpy.asscalar( + numpy.array(value, dtype=self._array.dtype)) + except ValueError: + return False + + selection = self._getIndexTuple(index.row(), + index.column()) + self._array[selection] = v + self.dataChanged.emit(index, index) + return True + else: + return False + + # Public methods + def setArrayData(self, data, copy=True, + perspective=None, editable=False): + """Set the data array and the viewing perspective. + + You can set ``copy=False`` if you need more performances, when dealing + with a large numpy array. In this case, a simple reference to the data + is used to access the data, rather than a copy of the array. + + .. warning:: + + Any change to the data model will affect your original data + array, when using a reference rather than a copy.. + + :param data: n-dimensional numpy array, or any object that can be + converted to a numpy array using ``numpy.array(data)`` (e.g. + a nested sequence). + :param bool copy: If *True* (default), a copy of the array is stored + and the original array is not modified if the table is edited. + If *False*, then the behavior depends on the data type: + if possible (if the original array is a proper numpy array) + a reference to the original array is used. + :param perspective: See documentation of :meth:`setPerspective`. + If None, the default perspective is the list of the first ``n-2`` + dimensions, to view frames parallel to the last two axes. + :param bool editable: Flag to enable editing data. Default *False*. + """ + if qt.qVersion() > "4.6": + self.beginResetModel() + else: + self.reset() + + if data is None: + # empty array + self._array = numpy.array([]) + elif copy: + # copy requested (default) + self._array = numpy.array(data, copy=True) + elif not _is_array(data): + raise TypeError("data is not a proper array. Try setting" + + " copy=True to convert it into a numpy array" + + " (this will cause the data to be copied!)") + # # copy not requested, but necessary + # _logger.warning( + # "data is not an array-like object. " + + # "Data must be copied.") + # self._array = numpy.array(data, copy=True) + else: + # Copy explicitly disabled & data implements required attributes. + # We can use a reference. + self._array = data + + # reset colors to None if new data shape is inconsistent + valid_color_shapes = (self._array.shape + (3,), + self._array.shape + (4,)) + if self._bgcolors is not None: + if self._bgcolors.shape not in valid_color_shapes: + self._bgcolors = None + if self._fgcolors is not None: + if self._fgcolors.shape not in valid_color_shapes: + self._fgcolors = None + + self.setEditable(editable) + + self._index = [0 for _i in range((len(self._array.shape) - 2))] + self._perspective = tuple(perspective) if perspective is not None else\ + tuple(range(0, len(self._array.shape) - 2)) + + if qt.qVersion() > "4.6": + self.endResetModel() + + def setArrayColors(self, bgcolors=None, fgcolors=None): + """Set the colors for all table cells by passing an array + of RGB or RGBA values (integers between 0 and 255). + + The shape of the colors array must be consistent with the data shape. + + If the data array is n-dimensional, the colors array must be + (n+1)-dimensional, with the first n-dimensions identical to the data + array dimensions, and the last dimension length-3 (RGB) or + length-4 (RGBA). + + :param bgcolors: RGB or RGBA colors array, defining the background color + for each cell in the table. + :param fgcolors: RGB or RGBA colors array, defining the foreground color + (text color) for each cell in the table. + """ + # array must be RGB or RGBA + valid_shapes = (self._array.shape + (3,), self._array.shape + (4,)) + errmsg = "Inconsistent shape for color array, should be %s or %s" % valid_shapes + + if bgcolors is not None: + if not _is_array(bgcolors): + bgcolors = numpy.array(bgcolors) + assert bgcolors.shape in valid_shapes, errmsg + + self._bgcolors = bgcolors + + if fgcolors is not None: + if not _is_array(fgcolors): + fgcolors = numpy.array(fgcolors) + assert fgcolors.shape in valid_shapes, errmsg + + self._fgcolors = fgcolors + + def setEditable(self, editable): + """Set flags to make the data editable. + + .. warning:: + + If the data is a reference to a h5py dataset open in read-only + mode, setting *editable=True* will fail and print a warning. + + .. warning:: + + Making the data editable means that the underlying data structure + in this data model will be modified. + If the data is a reference to a public object (open with + ``copy=False``), this could have side effects. If it is a + reference to an HDF5 dataset, this means the file will be + modified. + + :param bool editable: Flag to enable editing data. + :return: True if setting desired flag succeeded, False if it failed. + """ + self._editable = editable + if hasattr(self._array, "file"): + if hasattr(self._array.file, "mode"): + if editable and self._array.file.mode == "r": + _logger.warning( + "Data is a HDF5 dataset open in read-only " + + "mode. Editing must be disabled.") + self._editable = False + return False + return True + + def getData(self, copy=True): + """Return a copy of the data array, or a reference to it + if *copy=False* is passed as parameter. + + In case the shape was modified, to convert 0-D or 1-D data + into 2-D data, the original shape is restored in the returned data. + + :param bool copy: If *True* (default), return a copy of the data. If + *False*, return a reference. + :return: numpy array of data, or reference to original data object + if *copy=False* + """ + data = self._array if not copy else numpy.array(self._array, copy=True) + return data + + def setFrameIndex(self, index): + """Set the active slice index. + + This method is only relevant to arrays with at least 3 dimensions. + + :param index: Index of the active slice in the array. + In the general n-D case, this is a sequence of :math:`n - 2` + indices where the slice intersects the respective orthogonal axes. + :raise IndexError: If any index in the index sequence is out of bound + on its respective axis. + """ + shape = self._array.shape + if len(shape) < 3: + # index is ignored + return + + if qt.qVersion() > "4.6": + self.beginResetModel() + else: + self.reset() + + if len(shape) == 3: + len_ = shape[self._perspective[0]] + # accept integers as index in the case of 3-D arrays + if not hasattr(index, "__len__"): + self._index = [index] + else: + self._index = index + if not 0 <= self._index[0] < len_: + raise ValueError("Index must be a positive integer " + + "lower than %d" % len_) + else: + # general n-D case + for i_, idx in enumerate(index): + if not 0 <= idx < shape[self._perspective[i_]]: + raise IndexError("Invalid index %d " % idx + + "not in range 0-%d" % (shape[i_] - 1)) + self._index = index + + if qt.qVersion() > "4.6": + self.endResetModel() + + def setFormatter(self, formatter): + """Set the formatter object to be used to display data from the model + + :param TextFormatter formatter: Formatter to use + """ + if formatter is self._formatter: + return + + if qt.qVersion() > "4.6": + self.beginResetModel() + + if self._formatter is not None: + self._formatter.formatChanged.disconnect(self.__formatChanged) + + self._formatter = formatter + if self._formatter is not None: + self._formatter.formatChanged.connect(self.__formatChanged) + + if qt.qVersion() > "4.6": + self.endResetModel() + else: + self.reset() + + def getFormatter(self): + """Returns the text formatter used. + + :rtype: TextFormatter + """ + return self._formatter + + def __formatChanged(self): + """Called when the format changed. + """ + self.reset() + + def setPerspective(self, perspective): + """Set the perspective by defining a sequence listing all axes + orthogonal to the frame or 2-D slice to be visualized. + + Alternatively, you can use :meth:`setFrameAxes` for the complementary + approach of specifying the two axes parallel to the frame. + + In the 1-D or 2-D case, this parameter is irrelevant. + + In the 3-D case, if the unit vectors describing + your axes are :math:`\vec{x}, \vec{y}, \vec{z}`, a perspective of 0 + means you slices are parallel to :math:`\vec{y}\vec{z}`, 1 means they + are parallel to :math:`\vec{x}\vec{z}` and 2 means they + are parallel to :math:`\vec{x}\vec{y}`. + + In the n-D case, this parameter is a sequence of :math:`n-2` axes + numbers. + For instance if you want to display 2-D frames whose axes are the + second and third dimensions of a 5-D array, set the perspective to + ``(0, 3, 4)``. + + :param perspective: Sequence of dimensions/axes orthogonal to the + frames. + :raise: IndexError if any value in perspective is higher than the + number of dimensions minus one (first dimension is 0), or + if the number of values is different from the number of dimensions + minus two. + """ + n_dimensions = len(self._array.shape) + if n_dimensions < 3: + _logger.warning( + "perspective is not relevant for 1D and 2D arrays") + return + + if not hasattr(perspective, "__len__"): + # we can tolerate an integer for 3-D array + if n_dimensions == 3: + perspective = [perspective] + else: + raise ValueError("perspective must be a sequence of integers") + + # ensure unicity of dimensions in perspective + perspective = tuple(set(perspective)) + + if len(perspective) != n_dimensions - 2 or\ + min(perspective) < 0 or max(perspective) >= n_dimensions: + raise IndexError( + "Invalid perspective " + str(perspective) + + " for %d-D array " % n_dimensions + + "with shape " + str(self._array.shape)) + + if qt.qVersion() > "4.6": + self.beginResetModel() + else: + self.reset() + + self._perspective = perspective + + # reset index + self._index = [0 for _i in range(n_dimensions - 2)] + + if qt.qVersion() > "4.6": + self.endResetModel() + + def setFrameAxes(self, row_axis, col_axis): + """Set the perspective by specifying the two axes parallel to the frame + to be visualised. + + The complementary approach of defining the orthogonal axes can be used + with :meth:`setPerspective`. + + :param int row_axis: Index (0-based) of the first dimension used as a frame + axis + :param int col_axis: Index (0-based) of the 2nd dimension used as a frame + axis + :raise: IndexError if axes are invalid + """ + if row_axis > col_axis: + _logger.warning("The dimension of the row axis must be lower " + + "than the dimension of the column axis. Swapping.") + row_axis, col_axis = min(row_axis, col_axis), max(row_axis, col_axis) + + n_dimensions = len(self._array.shape) + if n_dimensions < 3: + _logger.warning( + "Frame axes cannot be changed for 1D and 2D arrays") + return + + perspective = tuple(set(range(0, n_dimensions)) - {row_axis, col_axis}) + + if len(perspective) != n_dimensions - 2 or\ + min(perspective) < 0 or max(perspective) >= n_dimensions: + raise IndexError( + "Invalid perspective " + str(perspective) + + " for %d-D array " % n_dimensions + + "with shape " + str(self._array.shape)) + + if qt.qVersion() > "4.6": + self.beginResetModel() + else: + self.reset() + + self._perspective = perspective + # reset index + self._index = [0 for _i in range(n_dimensions - 2)] + + if qt.qVersion() > "4.6": + self.endResetModel() + + +if __name__ == "__main__": + app = qt.QApplication([]) + w = qt.QTableView() + d = numpy.random.normal(0, 1, (5, 1000, 1000)) + for i in range(5): + d[i, :, :] += i * 10 + m = ArrayTableModel(data=d) + w.setModel(m) + m.setFrameIndex(3) + # m.setArrayData(numpy.ones((100,))) + w.show() + app.exec_() diff --git a/silx/gui/data/ArrayTableWidget.py b/silx/gui/data/ArrayTableWidget.py new file mode 100644 index 0000000..ba3fa11 --- /dev/null +++ b/silx/gui/data/ArrayTableWidget.py @@ -0,0 +1,490 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module defines a widget designed to display data arrays with any +number of dimensions as 2D frames (images, slices) in a table view. +The dimensions not displayed in the table can be browsed using improved +sliders. + +The widget uses a TableView that relies on a custom abstract item +model: :class:`silx.gui.data.ArrayTableModel`. +""" +from __future__ import division +import sys + +from silx.gui import qt +from silx.gui.widgets.TableWidget import TableView +from .ArrayTableModel import ArrayTableModel +from silx.gui.widgets.FrameBrowser import HorizontalSliderWithBrowser + +__authors__ = ["V.A. Sole", "P. Knobel"] +__license__ = "MIT" +__date__ = "24/01/2017" + + +class AxesSelector(qt.QWidget): + """Widget with two combo-boxes to select two dimensions among + all possible dimensions of an n-dimensional array. + + The first combobox contains values from :math:`0` to :math:`n-2`. + + The choices in the 2nd CB depend on the value selected in the first one. + If the value selected in the first CB is :math:`m`, the second one lets you + select values from :math:`m+1` to :math:`n-1`. + + The two axes can be used to select the row axis and the column axis t + display a slice of the array data in a table view. + """ + sigDimensionsChanged = qt.Signal(int, int) + """Signal emitted whenever one of the comboboxes is changed. + The signal carries the two selected dimensions.""" + + def __init__(self, parent=None, n=None): + qt.QWidget.__init__(self, parent) + self.layout = qt.QHBoxLayout(self) + self.layout.setContentsMargins(0, 2, 0, 2) + self.layout.setSpacing(10) + + self.rowsCB = qt.QComboBox(self) + self.columnsCB = qt.QComboBox(self) + + self.layout.addWidget(qt.QLabel("Rows dimension", self)) + self.layout.addWidget(self.rowsCB) + self.layout.addWidget(qt.QLabel(" ", self)) + self.layout.addWidget(qt.QLabel("Columns dimension", self)) + self.layout.addWidget(self.columnsCB) + self.layout.addStretch(1) + + self._slotsAreConnected = False + if n is not None: + self.setNDimensions(n) + + def setNDimensions(self, n): + """Initialize combo-boxes depending on number of dimensions of array. + Initially, the rows dimension is the second-to-last one, and the + columns dimension is the last one. + + Link the CBs together. MAke them emit a signal when their value is + changed. + + :param int n: Number of dimensions of array + """ + # remember the number of dimensions and the rows dimension + self.n = n + self._rowsDim = n - 2 + + # ensure slots are disconnected before (re)initializing widget + if self._slotsAreConnected: + self.rowsCB.currentIndexChanged.disconnect(self._rowDimChanged) + self.columnsCB.currentIndexChanged.disconnect(self._colDimChanged) + + self._clear() + self.rowsCB.addItems([str(i) for i in range(n - 1)]) + self.rowsCB.setCurrentIndex(n - 2) + if n >= 1: + self.columnsCB.addItem(str(n - 1)) + self.columnsCB.setCurrentIndex(0) + + # reconnect slots + self.rowsCB.currentIndexChanged.connect(self._rowDimChanged) + self.columnsCB.currentIndexChanged.connect(self._colDimChanged) + self._slotsAreConnected = True + + # emit new dimensions + if n > 2: + self.sigDimensionsChanged.emit(n - 2, n - 1) + + def setDimensions(self, row_dim, col_dim): + """Set the rows and columns dimensions. + + The rows dimension must be lower than the columns dimension. + + :param int row_dim: Rows dimension + :param int col_dim: Columns dimension + """ + if row_dim >= col_dim: + raise IndexError("Row dimension must be lower than column dimension") + if not (0 <= row_dim < self.n - 1): + raise IndexError("Row dimension must be between 0 and %d" % (self.n - 2)) + if not (row_dim < col_dim <= self.n - 1): + raise IndexError("Col dimension must be between %d and %d" % (row_dim + 1, self.n - 1)) + + # set the rows dimension; this triggers an update of columnsCB + self.rowsCB.setCurrentIndex(row_dim) + # columnsCB first item is "row_dim + 1". So index of "col_dim" is + # col_dim - (row_dim + 1) + self.columnsCB.setCurrentIndex(col_dim - row_dim - 1) + + def getDimensions(self): + """Return a 2-tuple of the rows dimension and the columns dimension. + + :return: 2-tuple of axes numbers (row_dimension, col_dimension) + """ + return self._getRowDim(), self._getColDim() + + def _clear(self): + """Empty the combo-boxes""" + self.rowsCB.clear() + self.columnsCB.clear() + + def _getRowDim(self): + """Get rows dimension, selected in :attr:`rowsCB` + """ + # rows combobox contains elements "0", ..."n-2", + # so the selected dim is always equal to the index + return self.rowsCB.currentIndex() + + def _getColDim(self): + """Get columns dimension, selected in :attr:`columnsCB`""" + # columns combobox contains elements "row_dim+1", "row_dim+2", ..., "n-1" + # so the selected dim is equal to row_dim + 1 + index + return self._rowsDim + 1 + self.columnsCB.currentIndex() + + def _rowDimChanged(self): + """Update columns combobox when the rows dimension is changed. + + Emit :attr:`sigDimensionsChanged`""" + old_col_dim = self._getColDim() + new_row_dim = self._getRowDim() + + # clear cols CB + self.columnsCB.currentIndexChanged.disconnect(self._colDimChanged) + self.columnsCB.clear() + # refill cols CB + for i in range(new_row_dim + 1, self.n): + self.columnsCB.addItem(str(i)) + + # keep previous col dimension if possible + new_col_cb_idx = old_col_dim - (new_row_dim + 1) + if new_col_cb_idx < 0: + # if row_dim is now greater than the previous col_dim, + # we select a new col_dim = row_dim + 1 (first element in cols CB) + new_col_cb_idx = 0 + self.columnsCB.setCurrentIndex(new_col_cb_idx) + + # reconnect slot + self.columnsCB.currentIndexChanged.connect(self._colDimChanged) + + self._rowsDim = new_row_dim + + self.sigDimensionsChanged.emit(self._getRowDim(), self._getColDim()) + + def _colDimChanged(self): + """Emit :attr:`sigDimensionsChanged`""" + self.sigDimensionsChanged.emit(self._getRowDim(), self._getColDim()) + + +def _get_shape(array_like): + """Return shape of an array like object. + + In case the object is a nested sequence (list of lists, tuples...), + the size of each dimension is assumed to be uniform, and is deduced from + the length of the first sequence. + + :param array_like: Array like object: numpy array, hdf5 dataset, + multi-dimensional sequence + :return: Shape of array, as a tuple of integers + """ + if hasattr(array_like, "shape"): + return array_like.shape + + shape = [] + subsequence = array_like + while hasattr(subsequence, "__len__"): + shape.append(len(subsequence)) + subsequence = subsequence[0] + + return tuple(shape) + + +class ArrayTableWidget(qt.QWidget): + """This widget is designed to display data of 2D frames (images, slices) + in a table view. The widget can load any n-dimensional array, and display + any 2-D frame/slice in the array. + + The index of the dimensions orthogonal to the displayed frame can be set + interactively using a browser widget (sliders, buttons and text entries). + + To set the data, use :meth:`setArrayData`. + To select the perspective, use :meth:`setPerspective` or + use :meth:`setFrameAxes`. + To select the frame, use :meth:`setFrameIndex`. + """ + def __init__(self, parent=None): + """ + + :param parent: parent QWidget + :param labels: list of labels for each dimension of the array + """ + qt.QWidget.__init__(self, parent) + self.mainLayout = qt.QVBoxLayout(self) + self.mainLayout.setContentsMargins(0, 0, 0, 0) + self.mainLayout.setSpacing(0) + + self.browserContainer = qt.QWidget(self) + self.browserLayout = qt.QGridLayout(self.browserContainer) + self.browserLayout.setContentsMargins(0, 0, 0, 0) + self.browserLayout.setSpacing(0) + + self._dimensionLabelsText = [] + """List of text labels sorted in the increasing order of the dimension + they apply to.""" + self._browserLabels = [] + """List of QLabel widgets.""" + self._browserWidgets = [] + """List of HorizontalSliderWithBrowser widgets.""" + + self.axesSelector = AxesSelector(self) + + self.view = TableView(self) + + self.mainLayout.addWidget(self.browserContainer) + self.mainLayout.addWidget(self.axesSelector) + self.mainLayout.addWidget(self.view) + + self.model = ArrayTableModel(self) + self.view.setModel(self.model) + + def setArrayData(self, data, labels=None, copy=True, editable=False): + """Set the data array. Update frame browsers and labels. + + :param data: Numpy array or similar object (e.g. nested sequence, + h5py dataset...) + :param labels: list of labels for each dimension of the array, or + boolean ``True`` to use default labels ("dimension 0", + "dimension 1", ...). `None` to disable labels (default). + :param bool copy: If *True*, store a copy of *data* in the model. If + *False*, store a reference to *data* if possible (only possible if + *data* is a proper numpy array or an object that implements the + same methods). + :param bool editable: Flag to enable editing data. Default is *False* + """ + self._data_shape = _get_shape(data) + + n_widgets = len(self._browserWidgets) + n_dimensions = len(self._data_shape) + + # Reset text of labels + self._dimensionLabelsText = [] + for i in range(n_dimensions): + if labels in [True, 1]: + label_text = "Dimension %d" % i + elif labels is None or i >= len(labels): + label_text = "" + else: + label_text = labels[i] + self._dimensionLabelsText.append(label_text) + + # not enough widgets, create new ones (we need n_dim - 2) + for i in range(n_widgets, n_dimensions - 2): + browser = HorizontalSliderWithBrowser(self.browserContainer) + self.browserLayout.addWidget(browser, i, 1) + self._browserWidgets.append(browser) + browser.valueChanged.connect(self._browserSlot) + browser.setEnabled(False) + browser.hide() + + label = qt.QLabel(self.browserContainer) + self._browserLabels.append(label) + self.browserLayout.addWidget(label, i, 0) + label.hide() + + n_widgets = len(self._browserWidgets) + for i in range(n_widgets): + label = self._browserLabels[i] + browser = self._browserWidgets[i] + + if (i + 2) < n_dimensions: + label.setText(self._dimensionLabelsText[i]) + browser.setRange(0, self._data_shape[i] - 1) + browser.setEnabled(True) + browser.show() + if labels is not None: + label.show() + else: + label.hide() + else: + browser.setEnabled(False) + browser.hide() + label.hide() + + # set model + self.model.setArrayData(data, copy=copy, editable=editable) + # some linux distributions need this call + self.view.setModel(self.model) + if editable: + self.view.enableCut() + self.view.enablePaste() + + # initialize & connect axesSelector + self.axesSelector.setNDimensions(n_dimensions) + self.axesSelector.sigDimensionsChanged.connect(self.setFrameAxes) + + def setArrayColors(self, bgcolors=None, fgcolors=None): + """Set the colors for all table cells by passing an array + of RGB or RGBA values (integers between 0 and 255). + + The shape of the colors array must be consistent with the data shape. + + If the data array is n-dimensional, the colors array must be + (n+1)-dimensional, with the first n-dimensions identical to the data + array dimensions, and the last dimension length-3 (RGB) or + length-4 (RGBA). + + :param bgcolors: RGB or RGBA colors array, defining the background color + for each cell in the table. + :param fgcolors: RGB or RGBA colors array, defining the foreground color + (text color) for each cell in the table. + """ + self.model.setArrayColors(bgcolors, fgcolors) + + def displayAxesSelector(self, isVisible): + """Allow to display or hide the axes selector. + + :param bool isVisible: True to display the axes selector. + """ + self.axesSelector.setVisible(isVisible) + + def setFrameIndex(self, index): + """Set the active slice/image index in the n-dimensional array. + + A frame is a 2D array extracted from an array. This frame is + necessarily parallel to 2 axes, and orthogonal to all other axes. + + The index of a frame is a sequence of indices along the orthogonal + axes, where the frame intersects the respective axis. The indices + are listed in the same order as the corresponding dimensions of the + data array. + + For example, it the data array has 5 dimensions, and we are + considering frames whose parallel axes are the 2nd and 4th dimensions + of the array, the frame index will be a sequence of length 3 + corresponding to the indices where the frame intersects the 1st, 3rd + and 5th axes. + + :param index: Sequence of indices defining the active data slice in + a n-dimensional array. The sequence length is :math:`n-2` + :raise: IndexError if any index in the index sequence is out of bound + on its respective axis. + """ + self.model.setFrameIndex(index) + + def _resetBrowsers(self, perspective): + """Adjust limits for browsers based on the perspective and the + size of the corresponding dimensions. Reset the index to 0. + Update the dimension in the labels. + + :param perspective: Sequence of axes/dimensions numbers (0-based) + defining the axes orthogonal to the frame. + """ + # for 3D arrays we can accept an int rather than a 1-tuple + if not hasattr(perspective, "__len__"): + perspective = [perspective] + + # perspective must be sorted + perspective = sorted(perspective) + + n_dimensions = len(self._data_shape) + for i in range(n_dimensions - 2): + browser = self._browserWidgets[i] + label = self._browserLabels[i] + browser.setRange(0, self._data_shape[perspective[i]] - 1) + browser.setValue(0) + label.setText(self._dimensionLabelsText[perspective[i]]) + + def setPerspective(self, perspective): + """Set the *perspective* by specifying which axes are orthogonal + to the frame. + + For the opposite approach (defining parallel axes), use + :meth:`setFrameAxes` instead. + + :param perspective: Sequence of unique axes numbers (0-based) defining + the orthogonal axes. For a n-dimensional array, the sequence + length is :math:`n-2`. The order is of the sequence is not taken + into account (the dimensions are displayed in increasing order + in the widget). + """ + self.model.setPerspective(perspective) + self._resetBrowsers(perspective) + + def setFrameAxes(self, row_axis, col_axis): + """Set the *perspective* by specifying which axes are parallel + to the frame. + + For the opposite approach (defining orthogonal axes), use + :meth:`setPerspective` instead. + + :param int row_axis: Index (0-based) of the first dimension used as a frame + axis + :param int col_axis: Index (0-based) of the 2nd dimension used as a frame + axis + """ + self.model.setFrameAxes(row_axis, col_axis) + n_dimensions = len(self._data_shape) + perspective = tuple(set(range(0, n_dimensions)) - {row_axis, col_axis}) + self._resetBrowsers(perspective) + + def _browserSlot(self, value): + index = [] + for browser in self._browserWidgets: + if browser.isEnabled(): + index.append(browser.value()) + self.setFrameIndex(index) + self.view.reset() + + def getData(self, copy=True): + """Return a copy of the data array, or a reference to it if + *copy=False* is passed as parameter. + + :param bool copy: If *True* (default), return a copy of the data. If + *False*, return a reference. + :return: Numpy array of data, or reference to original data object + if *copy=False* + """ + return self.model.getData(copy=copy) + + +def main(): + import numpy + a = qt.QApplication([]) + d = numpy.random.normal(0, 1, (4, 5, 1000, 1000)) + for j in range(4): + for i in range(5): + d[j, i, :, :] += i + 10 * j + w = ArrayTableWidget() + if "2" in sys.argv: + print("sending a single image") + w.setArrayData(d[0, 0]) + elif "3" in sys.argv: + print("sending 5 images") + w.setArrayData(d[0]) + else: + print("sending 4 * 5 images ") + w.setArrayData(d, labels=True) + w.show() + a.exec_() + +if __name__ == "__main__": + main() diff --git a/silx/gui/data/DataViewer.py b/silx/gui/data/DataViewer.py new file mode 100644 index 0000000..3a3ac64 --- /dev/null +++ b/silx/gui/data/DataViewer.py @@ -0,0 +1,464 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module defines a widget designed to display data using to most adapted +view from available ones from silx. +""" +from __future__ import division + +from silx.gui.data import DataViews +from silx.gui.data.DataViews import _normalizeData +import logging +from silx.gui import qt +from silx.gui.data.NumpyAxesSelector import NumpyAxesSelector + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "26/04/2017" + + +_logger = logging.getLogger(__name__) + + +class DataViewer(qt.QFrame): + """Widget to display any kind of data + + .. image:: img/DataViewer.png + + The method :meth:`setData` allows to set any data to the widget. Mostly + `numpy.array` and `h5py.Dataset` are supported with adapted views. Other + data types are displayed using a text viewer. + + A default view is automatically selected when a data is set. The method + :meth:`setDisplayMode` allows to change the view. To have a graphical tool + to select the view, prefer using the widget :class:`DataViewerFrame`. + + The dimension of the input data and the expected dimension of the selected + view can differ. For example you can display an image (2D) from 4D + data. In this case a :class:`NumpyAxesSelector` is displayed to allow the + user to select the axis mapping and the slicing of other axes. + + .. code-block:: python + + import numpy + data = numpy.random.rand(500,500) + viewer = DataViewer() + viewer.setData(data) + viewer.setVisible(True) + """ + + EMPTY_MODE = 0 + PLOT1D_MODE = 10 + PLOT2D_MODE = 20 + PLOT3D_MODE = 30 + RAW_MODE = 40 + RAW_ARRAY_MODE = 41 + RAW_RECORD_MODE = 42 + RAW_SCALAR_MODE = 43 + STACK_MODE = 50 + HDF5_MODE = 60 + + displayedViewChanged = qt.Signal(object) + """Emitted when the displayed view changes""" + + dataChanged = qt.Signal() + """Emitted when the data changes""" + + currentAvailableViewsChanged = qt.Signal() + """Emitted when the current available views (which support the current + data) change""" + + def __init__(self, parent=None): + """Constructor + + :param QWidget parent: The parent of the widget + """ + super(DataViewer, self).__init__(parent) + + self.__stack = qt.QStackedWidget(self) + self.__numpySelection = NumpyAxesSelector(self) + self.__numpySelection.selectedAxisChanged.connect(self.__numpyAxisChanged) + self.__numpySelection.selectionChanged.connect(self.__numpySelectionChanged) + self.__numpySelection.customAxisChanged.connect(self.__numpyCustomAxisChanged) + + self.setLayout(qt.QVBoxLayout(self)) + self.layout().addWidget(self.__stack, 1) + + group = qt.QGroupBox(self) + group.setLayout(qt.QVBoxLayout()) + group.layout().addWidget(self.__numpySelection) + group.setTitle("Axis selection") + self.__axisSelection = group + + self.layout().addWidget(self.__axisSelection) + + self.__currentAvailableViews = [] + self.__currentView = None + self.__data = None + self.__useAxisSelection = False + self.__userSelectedView = None + + self.__views = [] + self.__index = {} + """store stack index for each views""" + + self._initializeViews() + + def _initializeViews(self): + """Inisialize the available views""" + views = self.createDefaultViews(self.__stack) + self.__views = list(views) + self.setDisplayMode(self.EMPTY_MODE) + + def createDefaultViews(self, parent=None): + """Create and returns available views which can be displayed by default + by the data viewer. It is called internally by the widget. It can be + overwriten to provide a different set of viewers. + + :param QWidget parent: QWidget parent of the views + :rtype: list[silx.gui.data.DataViews.DataView] + """ + viewClasses = [ + DataViews._EmptyView, + DataViews._Hdf5View, + DataViews._NXdataView, + DataViews._Plot1dView, + DataViews._Plot2dView, + DataViews._Plot3dView, + DataViews._RawView, + DataViews._StackView, + ] + views = [] + for viewClass in viewClasses: + try: + view = viewClass(parent) + views.append(view) + except Exception: + _logger.warning("%s instantiation failed. View is ignored" % viewClass.__name__) + _logger.debug("Backtrace", exc_info=True) + + return views + + def clear(self): + """Clear the widget""" + self.setData(None) + + def normalizeData(self, data): + """Returns a normalized data if the embed a numpy or a dataset. + Else returns the data.""" + return _normalizeData(data) + + def __getStackIndex(self, view): + """Get the stack index containing the view. + + :param silx.gui.data.DataViews.DataView view: The view + """ + if view not in self.__index: + widget = view.getWidget() + index = self.__stack.addWidget(widget) + self.__index[view] = index + else: + index = self.__index[view] + return index + + def __clearCurrentView(self): + """Clear the current selected view""" + view = self.__currentView + if view is not None: + view.clear() + + def __numpyCustomAxisChanged(self, name, value): + view = self.__currentView + if view is not None: + view.setCustomAxisValue(name, value) + + def __updateNumpySelectionAxis(self): + """ + Update the numpy-selector according to the needed axis names + """ + previous = self.__numpySelection.blockSignals(True) + self.__numpySelection.clear() + info = DataViews.DataInfo(self.__data) + axisNames = self.__currentView.axesNames(self.__data, info) + if info.isArray and self.__data is not None and len(axisNames) > 0: + self.__useAxisSelection = True + self.__numpySelection.setAxisNames(axisNames) + self.__numpySelection.setCustomAxis(self.__currentView.customAxisNames()) + data = self.normalizeData(self.__data) + self.__numpySelection.setData(data) + if hasattr(data, "shape"): + isVisible = not (len(axisNames) == 1 and len(data.shape) == 1) + else: + isVisible = True + self.__axisSelection.setVisible(isVisible) + else: + self.__useAxisSelection = False + self.__axisSelection.setVisible(False) + self.__numpySelection.blockSignals(previous) + + def __updateDataInView(self): + """ + Update the views using the current data + """ + if self.__useAxisSelection: + self.__displayedData = self.__numpySelection.selectedData() + else: + self.__displayedData = self.__data + + qt.QTimer.singleShot(10, self.__setDataInView) + + def __setDataInView(self): + self.__currentView.setData(self.__displayedData) + + def setDisplayedView(self, view): + """Set the displayed view. + + Change the displayed view according to the view itself. + + :param silx.gui.data.DataViews.DataView view: The DataView to use to display the data + """ + self.__userSelectedView = view + self._setDisplayedView(view) + + def _setDisplayedView(self, view): + """Internal set of the displayed view. + + Change the displayed view according to the view itself. + + :param silx.gui.data.DataViews.DataView view: The DataView to use to display the data + """ + if self.__currentView is view: + return + self.__clearCurrentView() + self.__currentView = view + self.__updateNumpySelectionAxis() + self.__updateDataInView() + stackIndex = self.__getStackIndex(self.__currentView) + if self.__currentView is not None: + self.__currentView.select() + self.__stack.setCurrentIndex(stackIndex) + self.displayedViewChanged.emit(view) + + def getViewFromModeId(self, modeId): + """Returns the first available view which have the requested modeId. + + :param int modeId: Requested mode id + :rtype: silx.gui.data.DataViews.DataView + """ + for view in self.__views: + if view.modeId() == modeId: + return view + return view + + def setDisplayMode(self, modeId): + """Set the displayed view using display mode. + + Change the displayed view according to the requested mode. + + :param int modeId: Display mode, one of + + - `EMPTY_MODE`: display nothing + - `PLOT1D_MODE`: display the data as a curve + - `PLOT2D_MODE`: display the data as an image + - `PLOT3D_MODE`: display the data as an isosurface + - `RAW_MODE`: display the data as a table + - `STACK_MODE`: display the data as a stack of images + - `HDF5_MODE`: display the data as a table + """ + try: + view = self.getViewFromModeId(modeId) + except KeyError: + raise ValueError("Display mode %s is unknown" % modeId) + self._setDisplayedView(view) + + def displayedView(self): + """Returns the current displayed view. + + :rtype: silx.gui.data.DataViews.DataView + """ + return self.__currentView + + def addView(self, view): + """Allow to add a view to the dataview. + + If the current data support this view, it will be displayed. + + :param DataView view: A dataview + """ + self.__views.append(view) + # TODO It can be skipped if the view do not support the data + self.__updateAvailableViews() + + def removeView(self, view): + """Allow to remove a view which was available from the dataview. + + If the view was displayed, the widget will be updated. + + :param DataView view: A dataview + """ + self.__views.remove(view) + self.__stack.removeWidget(view.getWidget()) + # invalidate the full index. It will be updated as expected + self.__index = {} + + if self.__userSelectedView is view: + self.__userSelectedView = None + + if view is self.__currentView: + self.__updateView() + else: + # TODO It can be skipped if the view is not part of the + # available views + self.__updateAvailableViews() + + def __updateAvailableViews(self): + """ + Update available views from the current data. + """ + data = self.__data + # sort available views according to priority + info = DataViews.DataInfo(data) + priorities = [v.getDataPriority(data, info) for v in self.__views] + views = zip(priorities, self.__views) + views = filter(lambda t: t[0] > DataViews.DataView.UNSUPPORTED, views) + views = sorted(views, reverse=True) + + # store available views + if len(views) == 0: + self.__setCurrentAvailableViews([]) + available = [] + else: + available = [v[1] for v in views] + self.__setCurrentAvailableViews(available) + + def __updateView(self): + """Display the data using the widget which fit the best""" + data = self.__data + + # update available views for this data + self.__updateAvailableViews() + available = self.__currentAvailableViews + + # display the view with the most priority (the default view) + view = self.getDefaultViewFromAvailableViews(data, available) + self.__clearCurrentView() + try: + self._setDisplayedView(view) + except Exception as e: + # in case there is a problem to read the data, try to use a safe + # view + view = self.getSafeViewFromAvailableViews(data, available) + self._setDisplayedView(view) + raise e + + def getSafeViewFromAvailableViews(self, data, available): + """Returns a view which is sure to display something without failing + on rendering. + + :param object data: data which will be displayed + :param list[view] available: List of available views, from highest + priority to lowest. + :rtype: DataView + """ + hdf5View = self.getViewFromModeId(DataViewer.HDF5_MODE) + if hdf5View in available: + return hdf5View + return self.getViewFromModeId(DataViewer.EMPTY_MODE) + + def getDefaultViewFromAvailableViews(self, data, available): + """Returns the default view which will be used according to available + views. + + :param object data: data which will be displayed + :param list[view] available: List of available views, from highest + priority to lowest. + :rtype: DataView + """ + if len(available) > 0: + # returns the view with the highest priority + if self.__userSelectedView in available: + return self.__userSelectedView + self.__userSelectedView = None + view = available[0] + else: + # else returns the empty view + view = self.getViewFromModeId(DataViewer.EMPTY_MODE) + return view + + def __setCurrentAvailableViews(self, availableViews): + """Set the current available viewa + + :param List[DataView] availableViews: Current available viewa + """ + self.__currentAvailableViews = availableViews + self.currentAvailableViewsChanged.emit() + + def currentAvailableViews(self): + """Returns the list of available views for the current data + + :rtype: List[DataView] + """ + return self.__currentAvailableViews + + def availableViews(self): + """Returns the list of registered views + + :rtype: List[DataView] + """ + return self.__views + + def setData(self, data): + """Set the data to view. + + It mostly can be a h5py.Dataset or a numpy.ndarray. Other kind of + objects will be displayed as text rendering. + + :param numpy.ndarray data: The data. + """ + self.__data = data + self.__displayedData = None + self.__updateView() + self.__updateNumpySelectionAxis() + self.__updateDataInView() + self.dataChanged.emit() + + def __numpyAxisChanged(self): + """ + Called when axis selection of the numpy-selector changed + """ + self.__clearCurrentView() + + def __numpySelectionChanged(self): + """ + Called when data selection of the numpy-selector changed + """ + self.__updateDataInView() + + def data(self): + """Returns the data""" + return self.__data + + def displayMode(self): + """Returns the current display mode""" + return self.__currentView.modeId() diff --git a/silx/gui/data/DataViewerFrame.py b/silx/gui/data/DataViewerFrame.py new file mode 100644 index 0000000..b48fa7b --- /dev/null +++ b/silx/gui/data/DataViewerFrame.py @@ -0,0 +1,186 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module contains a DataViewer with a view selector. +""" + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "10/04/2017" + +from silx.gui import qt +from .DataViewer import DataViewer +from .DataViewerSelector import DataViewerSelector + + +class DataViewerFrame(qt.QWidget): + """ + A :class:`DataViewer` with a view selector. + + .. image:: img/DataViewerFrame.png + + This widget provides the same API as :class:`DataViewer`. Therefore, for more + documentation, take a look at the documentation of the class + :class:`DataViewer`. + + .. code-block:: python + + import numpy + data = numpy.random.rand(500,500) + viewer = DataViewerFrame() + viewer.setData(data) + viewer.setVisible(True) + + """ + + displayedViewChanged = qt.Signal(object) + """Emitted when the displayed view changes""" + + dataChanged = qt.Signal() + """Emitted when the data changes""" + + def __init__(self, parent=None): + """ + Constructor + + :param qt.QWidget parent: + """ + super(DataViewerFrame, self).__init__(parent) + + class _DataViewer(DataViewer): + """Overwrite methods to avoid to create views while the instance + is not created. `initializeViews` have to be called manually.""" + + def _initializeViews(self): + pass + + def initializeViews(self): + """Avoid to create views while the instance is not created.""" + super(_DataViewer, self)._initializeViews() + + self.__dataViewer = _DataViewer(self) + # initialize views when `self.__dataViewer` is set + self.__dataViewer.initializeViews() + self.__dataViewer.setFrameShape(qt.QFrame.StyledPanel) + self.__dataViewer.setFrameShadow(qt.QFrame.Sunken) + self.__dataViewerSelector = DataViewerSelector(self, self.__dataViewer) + self.__dataViewerSelector.setFlat(True) + + layout = qt.QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + layout.addWidget(self.__dataViewer, 1) + layout.addWidget(self.__dataViewerSelector) + self.setLayout(layout) + + self.__dataViewer.dataChanged.connect(self.__dataChanged) + self.__dataViewer.displayedViewChanged.connect(self.__displayedViewChanged) + + def __dataChanged(self): + """Called when the data is changed""" + self.dataChanged.emit() + + def __displayedViewChanged(self, view): + """Called when the displayed view changes""" + self.displayedViewChanged.emit(view) + + def availableViews(self): + """Returns the list of registered views + + :rtype: List[DataView] + """ + return self.__dataViewer.availableViews() + + def currentAvailableViews(self): + """Returns the list of available views for the current data + + :rtype: List[DataView] + """ + return self.__dataViewer.currentAvailableViews() + + def createDefaultViews(self, parent=None): + """Create and returns available views which can be displayed by default + by the data viewer. It is called internally by the widget. It can be + overwriten to provide a different set of viewers. + + :param QWidget parent: QWidget parent of the views + :rtype: list[silx.gui.data.DataViews.DataView] + """ + return self.__dataViewer.createDefaultViews(parent) + + def addView(self, view): + """Allow to add a view to the dataview. + + If the current data support this view, it will be displayed. + + :param DataView view: A dataview + """ + return self.__dataViewer.addView(view) + + def removeView(self, view): + """Allow to remove a view which was available from the dataview. + + If the view was displayed, the widget will be updated. + + :param DataView view: A dataview + """ + return self.__dataViewer.removeView(view) + + def setData(self, data): + """Set the data to view. + + It mostly can be a h5py.Dataset or a numpy.ndarray. Other kind of + objects will be displayed as text rendering. + + :param numpy.ndarray data: The data. + """ + self.__dataViewer.setData(data) + + def data(self): + """Returns the data""" + return self.__dataViewer.data() + + def setDisplayedView(self, view): + self.__dataViewer.setDisplayedView(view) + + def displayedView(self): + return self.__dataViewer.displayedView() + + def displayMode(self): + return self.__dataViewer.displayMode() + + def setDisplayMode(self, modeId): + """Set the displayed view using display mode. + + Change the displayed view according to the requested mode. + + :param int modeId: Display mode, one of + + - `EMPTY_MODE`: display nothing + - `PLOT1D_MODE`: display the data as a curve + - `PLOT2D_MODE`: display the data as an image + - `TEXT_MODE`: display the data as a text + - `ARRAY_MODE`: display the data as a table + """ + return self.__dataViewer.setDisplayMode(modeId) diff --git a/silx/gui/data/DataViewerSelector.py b/silx/gui/data/DataViewerSelector.py new file mode 100644 index 0000000..32cc636 --- /dev/null +++ b/silx/gui/data/DataViewerSelector.py @@ -0,0 +1,153 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module defines a widget to be able to select the available view +of the DataViewer. +""" +from __future__ import division + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "26/01/2017" + +import weakref +import functools +from silx.gui import qt +from silx.gui.data.DataViewer import DataViewer +import silx.utils.weakref + + +class DataViewerSelector(qt.QWidget): + """Widget to be able to select a custom view from the DataViewer""" + + def __init__(self, parent=None, dataViewer=None): + """Constructor + + :param QWidget parent: The parent of the widget + :param DataViewer dataViewer: The connected `DataViewer` + """ + super(DataViewerSelector, self).__init__(parent) + + self.__group = None + self.__buttons = {} + self.__buttonDummy = None + self.__dataViewer = None + + if dataViewer is not None: + self.setDataViewer(dataViewer) + + def __updateButtons(self): + if self.__group is not None: + self.__group.deleteLater() + self.__buttons = {} + self.__buttonDummy = None + + self.__group = qt.QButtonGroup(self) + self.setLayout(qt.QHBoxLayout()) + self.layout().setContentsMargins(0, 0, 0, 0) + if self.__dataViewer is None: + return + + iconSize = qt.QSize(16, 16) + + for view in self.__dataViewer.availableViews(): + label = view.label() + icon = view.icon() + button = qt.QPushButton(label) + button.setIcon(icon) + button.setIconSize(iconSize) + button.setCheckable(True) + # the weak objects are needed to be able to destroy the widget safely + weakView = weakref.ref(view) + weakMethod = silx.utils.weakref.WeakMethodProxy(self.__setDisplayedView) + callback = functools.partial(weakMethod, weakView) + button.clicked.connect(callback) + self.layout().addWidget(button) + self.__group.addButton(button) + self.__buttons[view] = button + + button = qt.QPushButton("Dummy") + button.setCheckable(True) + button.setVisible(False) + self.layout().addWidget(button) + self.__group.addButton(button) + self.__buttonDummy = button + + self.layout().addStretch(1) + + self.__updateButtonsVisibility() + self.__displayedViewChanged(self.__dataViewer.displayedView()) + + def setDataViewer(self, dataViewer): + """Define the dataviewer connected to this status bar + + :param DataViewer dataViewer: The connected `DataViewer` + """ + if self.__dataViewer is dataViewer: + return + if self.__dataViewer is not None: + self.__dataViewer.dataChanged.disconnect(self.__updateButtonsVisibility) + self.__dataViewer.displayedViewChanged.disconnect(self.__displayedViewChanged) + self.__dataViewer = dataViewer + if self.__dataViewer is not None: + self.__dataViewer.dataChanged.connect(self.__updateButtonsVisibility) + self.__dataViewer.displayedViewChanged.connect(self.__displayedViewChanged) + self.__updateButtons() + + def setFlat(self, isFlat): + """Set the flat state of all the buttons. + + :param bool isFlat: True to display the buttons flatten. + """ + for b in self.__buttons.values(): + b.setFlat(isFlat) + self.__buttonDummy.setFlat(isFlat) + + def __displayedViewChanged(self, view): + """Called on displayed view changeS""" + selectedButton = self.__buttons.get(view, self.__buttonDummy) + selectedButton.setChecked(True) + + def __setDisplayedView(self, refView, clickEvent=None): + """Display a data using the requested view + + :param DataView view: Requested view + :param clickEvent: Event sent by the clicked event + """ + if self.__dataViewer is None: + return + view = refView() + if view is None: + return + self.__dataViewer.setDisplayedView(view) + + def __updateButtonsVisibility(self): + """Called on data changed""" + if self.__dataViewer is None: + for b in self.__buttons.values(): + b.setVisible(False) + else: + availableViews = set(self.__dataViewer.currentAvailableViews()) + for view, button in self.__buttons.items(): + button.setVisible(view in availableViews) diff --git a/silx/gui/data/DataViews.py b/silx/gui/data/DataViews.py new file mode 100644 index 0000000..d8d605a --- /dev/null +++ b/silx/gui/data/DataViews.py @@ -0,0 +1,988 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module defines a views used by :class:`silx.gui.data.DataViewer`. +""" + +import logging +import numbers +import numpy + +import silx.io +from silx.gui import qt, icons +from silx.gui.data.TextFormatter import TextFormatter +from silx.io import nxdata +from silx.gui.hdf5 import H5Node +from silx.io.nxdata import NXdata + +__authors__ = ["V. Valls", "P. Knobel"] +__license__ = "MIT" +__date__ = "07/04/2017" + +_logger = logging.getLogger(__name__) + + +# DataViewer modes +EMPTY_MODE = 0 +PLOT1D_MODE = 10 +PLOT2D_MODE = 20 +PLOT3D_MODE = 30 +RAW_MODE = 40 +RAW_ARRAY_MODE = 41 +RAW_RECORD_MODE = 42 +RAW_SCALAR_MODE = 43 +STACK_MODE = 50 +HDF5_MODE = 60 + + +def _normalizeData(data): + """Returns a normalized data. + + If the data embed a numpy data or a dataset it is returned. + Else returns the input data.""" + if isinstance(data, H5Node): + return data.h5py_object + return data + + +def _normalizeComplex(data): + """Returns a normalized complex data. + + If the data is a numpy data with complex, returns the + absolute value. + Else returns the input data.""" + if hasattr(data, "dtype"): + isComplex = numpy.issubdtype(data.dtype, numpy.complex) + else: + isComplex = isinstance(data, numbers.Complex) + if isComplex: + data = numpy.absolute(data) + return data + + +class DataInfo(object): + """Store extracted information from a data""" + + def __init__(self, data): + data = self.normalizeData(data) + self.isArray = False + self.interpretation = None + self.isNumeric = False + self.isComplex = False + self.isRecord = False + self.isNXdata = False + self.shape = tuple() + self.dim = 0 + + if data is None: + return + + if silx.io.is_group(data) and nxdata.is_valid_nxdata(data): + self.isNXdata = True + nxd = nxdata.NXdata(data) + + if isinstance(data, numpy.ndarray): + self.isArray = True + elif silx.io.is_dataset(data) and data.shape != tuple(): + self.isArray = True + else: + self.isArray = False + + if silx.io.is_dataset(data): + self.interpretation = data.attrs.get("interpretation", None) + elif self.isNXdata: + self.interpretation = nxd.interpretation + else: + self.interpretation = None + + if hasattr(data, "dtype"): + self.isNumeric = numpy.issubdtype(data.dtype, numpy.number) + self.isRecord = data.dtype.fields is not None + self.isComplex = numpy.issubdtype(data.dtype, numpy.complex) + elif self.isNXdata: + self.isNumeric = numpy.issubdtype(nxd.signal.dtype, + numpy.number) + self.isComplex = numpy.issubdtype(nxd.signal.dtype, numpy.complex) + else: + self.isNumeric = isinstance(data, numbers.Number) + self.isComplex = isinstance(data, numbers.Complex) + self.isRecord = False + + if hasattr(data, "shape"): + self.shape = data.shape + elif self.isNXdata: + self.shape = nxd.signal.shape + else: + self.shape = tuple() + self.dim = len(self.shape) + + def normalizeData(self, data): + """Returns a normalized data if the embed a numpy or a dataset. + Else returns the data.""" + return _normalizeData(data) + + +class DataView(object): + """Holder for the data view.""" + + UNSUPPORTED = -1 + """Priority returned when the requested data can't be displayed by the + view.""" + + def __init__(self, parent, modeId=None, icon=None, label=None): + """Constructor + + :param qt.QWidget parent: Parent of the hold widget + """ + self.__parent = parent + self.__widget = None + self.__modeId = modeId + if label is None: + label = self.__class__.__name__ + self.__label = label + if icon is None: + icon = qt.QIcon() + self.__icon = icon + + def icon(self): + """Returns the default icon""" + return self.__icon + + def label(self): + """Returns the default label""" + return self.__label + + def modeId(self): + """Returns the mode id""" + return self.__modeId + + def normalizeData(self, data): + """Returns a normalized data if the embed a numpy or a dataset. + Else returns the data.""" + return _normalizeData(data) + + def customAxisNames(self): + """Returns names of axes which can be custom by the user and provided + to the view.""" + return [] + + def setCustomAxisValue(self, name, value): + """ + Set the value of a custom axis + + :param str name: Name of the custom axis + :param int value: Value of the custom axis + """ + pass + + def isWidgetInitialized(self): + """Returns true if the widget is already initialized. + """ + return self.__widget is not None + + def select(self): + """Called when the view is selected to display the data. + """ + return + + def getWidget(self): + """Returns the widget hold in the view and displaying the data. + + :returns: qt.QWidget + """ + if self.__widget is None: + self.__widget = self.createWidget(self.__parent) + return self.__widget + + def createWidget(self, parent): + """Create the the widget displaying the data + + :param qt.QWidget parent: Parent of the widget + :returns: qt.QWidget + """ + raise NotImplementedError() + + def clear(self): + """Clear the data from the view""" + return None + + def setData(self, data): + """Set the data displayed by the view + + :param data: Data to display + :type data: numpy.ndarray or h5py.Dataset + """ + return None + + def axesNames(self, data, info): + """Returns names of the expected axes of the view, according to the + input data. + + :param data: Data to display + :type data: numpy.ndarray or h5py.Dataset + :param DataInfo info: Pre-computed information on the data + :rtype: list[str] + """ + return [] + + def getDataPriority(self, data, info): + """ + Returns the priority of using this view according to a data. + + - `UNSUPPORTED` means this view can't display this data + - `1` means this view can display the data + - `100` means this view should be used for this data + - `1000` max value used by the views provided by silx + - ... + + :param object data: The data to check + :param DataInfo info: Pre-computed information on the data + :rtype: int + """ + return DataView.UNSUPPORTED + + def __lt__(self, other): + return str(self) < str(other) + + +class CompositeDataView(DataView): + """Data view which can display a data using different view according to + the kind of the data.""" + + def __init__(self, parent, modeId=None, icon=None, label=None): + """Constructor + + :param qt.QWidget parent: Parent of the hold widget + """ + super(CompositeDataView, self).__init__(parent, modeId, icon, label) + self.__views = {} + self.__currentView = None + + def addView(self, dataView): + """Add a new dataview to the available list.""" + self.__views[dataView] = None + + def getBestView(self, data, info): + """Returns the best view according to priorities.""" + info = DataInfo(data) + views = [(v.getDataPriority(data, info), v) for v in self.__views.keys()] + views = filter(lambda t: t[0] > DataView.UNSUPPORTED, views) + views = sorted(views, reverse=True) + + if len(views) == 0: + return None + elif views[0][0] == DataView.UNSUPPORTED: + return None + else: + return views[0][1] + + def customAxisNames(self): + if self.__currentView is None: + return + return self.__currentView.customAxisNames() + + def setCustomAxisValue(self, name, value): + if self.__currentView is None: + return + self.__currentView.setCustomAxisValue(name, value) + + def __updateDisplayedView(self): + widget = self.getWidget() + if self.__currentView is None: + return + + # load the widget if it is not yet done + index = self.__views[self.__currentView] + if index is None: + w = self.__currentView.getWidget() + index = widget.addWidget(w) + self.__views[self.__currentView] = index + if widget.currentIndex() != index: + widget.setCurrentIndex(index) + self.__currentView.select() + + def select(self): + self.__updateDisplayedView() + if self.__currentView is not None: + self.__currentView.select() + + def createWidget(self, parent): + return qt.QStackedWidget() + + def clear(self): + for v in self.__views.keys(): + v.clear() + + def setData(self, data): + if self.__currentView is None: + return + self.__updateDisplayedView() + self.__currentView.setData(data) + + def axesNames(self, data, info): + view = self.getBestView(data, info) + self.__currentView = view + return view.axesNames(data, info) + + def getDataPriority(self, data, info): + view = self.getBestView(data, info) + self.__currentView = view + if view is None: + return DataView.UNSUPPORTED + else: + return view.getDataPriority(data, info) + + +class _EmptyView(DataView): + """Dummy view to display nothing""" + + def __init__(self, parent): + DataView.__init__(self, parent, modeId=EMPTY_MODE) + + def axesNames(self, data, info): + return [] + + def createWidget(self, parent): + return qt.QLabel(parent) + + def getDataPriority(self, data, info): + return DataView.UNSUPPORTED + + +class _Plot1dView(DataView): + """View displaying data using a 1d plot""" + + def __init__(self, parent): + super(_Plot1dView, self).__init__( + parent=parent, + modeId=PLOT1D_MODE, + label="Curve", + icon=icons.getQIcon("view-1d")) + self.__resetZoomNextTime = True + + def createWidget(self, parent): + from silx.gui import plot + return plot.Plot1D(parent=parent) + + def clear(self): + self.getWidget().clear() + self.__resetZoomNextTime = True + + def normalizeData(self, data): + data = DataView.normalizeData(self, data) + data = _normalizeComplex(data) + return data + + def setData(self, data): + data = self.normalizeData(data) + self.getWidget().addCurve(legend="data", + x=range(len(data)), + y=data, + resetzoom=self.__resetZoomNextTime) + self.__resetZoomNextTime = True + + def axesNames(self, data, info): + return ["y"] + + def getDataPriority(self, data, info): + if data is None or not info.isArray or not info.isNumeric: + return DataView.UNSUPPORTED + if info.dim < 1: + return DataView.UNSUPPORTED + if info.interpretation == "spectrum": + return 1000 + if info.dim == 2 and info.shape[0] == 1: + return 210 + if info.dim == 1: + return 100 + else: + return 10 + + +class _Plot2dView(DataView): + """View displaying data using a 2d plot""" + + def __init__(self, parent): + super(_Plot2dView, self).__init__( + parent=parent, + modeId=PLOT2D_MODE, + label="Image", + icon=icons.getQIcon("view-2d")) + self.__resetZoomNextTime = True + + def createWidget(self, parent): + from silx.gui import plot + widget = plot.Plot2D(parent=parent) + widget.setKeepDataAspectRatio(True) + widget.setGraphXLabel('X') + widget.setGraphYLabel('Y') + return widget + + def clear(self): + self.getWidget().clear() + self.__resetZoomNextTime = True + + def normalizeData(self, data): + data = DataView.normalizeData(self, data) + data = _normalizeComplex(data) + return data + + def setData(self, data): + data = self.normalizeData(data) + self.getWidget().addImage(legend="data", + data=data, + resetzoom=self.__resetZoomNextTime) + self.__resetZoomNextTime = False + + def axesNames(self, data, info): + return ["y", "x"] + + def getDataPriority(self, data, info): + if data is None or not info.isArray or not info.isNumeric: + return DataView.UNSUPPORTED + if info.dim < 2: + return DataView.UNSUPPORTED + if info.interpretation == "image": + return 1000 + if info.dim == 2: + return 200 + else: + return 190 + + +class _Plot3dView(DataView): + """View displaying data using a 3d plot""" + + def __init__(self, parent): + super(_Plot3dView, self).__init__( + parent=parent, + modeId=PLOT3D_MODE, + label="Cube", + icon=icons.getQIcon("view-3d")) + try: + import silx.gui.plot3d #noqa + except ImportError: + _logger.warning("Plot3dView is not available") + _logger.debug("Backtrace", exc_info=True) + raise + self.__resetZoomNextTime = True + + def createWidget(self, parent): + from silx.gui.plot3d import ScalarFieldView + from silx.gui.plot3d import SFViewParamTree + + plot = ScalarFieldView.ScalarFieldView(parent) + plot.setAxesLabels(*reversed(self.axesNames(None, None))) + plot.addIsosurface( + lambda data: numpy.mean(data) + numpy.std(data), '#FF0000FF') + + # Create a parameter tree for the scalar field view + options = SFViewParamTree.TreeView(plot) + options.setSfView(plot) + + # Add the parameter tree to the main window in a dock widget + dock = qt.QDockWidget() + dock.setWidget(options) + plot.addDockWidget(qt.Qt.RightDockWidgetArea, dock) + + return plot + + def clear(self): + self.getWidget().setData(None) + self.__resetZoomNextTime = True + + def normalizeData(self, data): + data = DataView.normalizeData(self, data) + data = _normalizeComplex(data) + return data + + def setData(self, data): + data = self.normalizeData(data) + plot = self.getWidget() + plot.setData(data) + self.__resetZoomNextTime = False + + def axesNames(self, data, info): + return ["z", "y", "x"] + + def getDataPriority(self, data, info): + if data is None or not info.isArray or not info.isNumeric: + return DataView.UNSUPPORTED + if info.dim < 3: + return DataView.UNSUPPORTED + if min(data.shape) < 2: + return DataView.UNSUPPORTED + if info.dim == 3: + return 100 + else: + return 10 + + +class _ArrayView(DataView): + """View displaying data using a 2d table""" + + def __init__(self, parent): + DataView.__init__(self, parent, modeId=RAW_ARRAY_MODE) + + def createWidget(self, parent): + from silx.gui.data.ArrayTableWidget import ArrayTableWidget + widget = ArrayTableWidget(parent) + widget.displayAxesSelector(False) + return widget + + def clear(self): + self.getWidget().setArrayData(numpy.array([[]])) + + def setData(self, data): + data = self.normalizeData(data) + self.getWidget().setArrayData(data) + + def axesNames(self, data, info): + return ["col", "row"] + + def getDataPriority(self, data, info): + if data is None or not info.isArray or info.isRecord: + return DataView.UNSUPPORTED + if info.dim < 2: + return DataView.UNSUPPORTED + if info.interpretation in ["scalar", "scaler"]: + return 1000 + return 500 + + +class _StackView(DataView): + """View displaying data using a stack of images""" + + def __init__(self, parent): + super(_StackView, self).__init__( + parent=parent, + modeId=STACK_MODE, + label="Image stack", + icon=icons.getQIcon("view-2d-stack")) + self.__resetZoomNextTime = True + + def customAxisNames(self): + return ["depth"] + + def setCustomAxisValue(self, name, value): + if name == "depth": + self.getWidget().setFrameNumber(value) + else: + raise Exception("Unsupported axis") + + def createWidget(self, parent): + from silx.gui import plot + widget = plot.StackView(parent=parent) + widget.setKeepDataAspectRatio(True) + widget.setLabels(self.axesNames(None, None)) + # hide default option panel + widget.setOptionVisible(False) + return widget + + def clear(self): + self.getWidget().clear() + self.__resetZoomNextTime = True + + def normalizeData(self, data): + data = DataView.normalizeData(self, data) + data = _normalizeComplex(data) + return data + + def setData(self, data): + data = self.normalizeData(data) + self.getWidget().setStack(stack=data, reset=self.__resetZoomNextTime) + self.__resetZoomNextTime = False + + def axesNames(self, data, info): + return ["depth", "y", "x"] + + def getDataPriority(self, data, info): + if data is None or not info.isArray or not info.isNumeric: + return DataView.UNSUPPORTED + if info.dim < 3: + return DataView.UNSUPPORTED + if info.interpretation == "image": + return 500 + return 90 + + +class _ScalarView(DataView): + """View displaying data using text""" + + def __init__(self, parent): + DataView.__init__(self, parent, modeId=RAW_SCALAR_MODE) + + def createWidget(self, parent): + widget = qt.QTextEdit(parent) + widget.setTextInteractionFlags(qt.Qt.TextSelectableByMouse) + widget.setAlignment(qt.Qt.AlignLeft | qt.Qt.AlignTop) + self.__formatter = TextFormatter(parent) + return widget + + def clear(self): + self.getWidget().setText("") + + def setData(self, data): + data = self.normalizeData(data) + if silx.io.is_dataset(data): + data = data[()] + text = self.__formatter.toString(data) + self.getWidget().setText(text) + + def axesNames(self, data, info): + return [] + + def getDataPriority(self, data, info): + data = self.normalizeData(data) + if data is None: + return DataView.UNSUPPORTED + if silx.io.is_group(data): + return DataView.UNSUPPORTED + return 2 + + +class _RecordView(DataView): + """View displaying data using text""" + + def __init__(self, parent): + DataView.__init__(self, parent, modeId=RAW_RECORD_MODE) + + def createWidget(self, parent): + from .RecordTableView import RecordTableView + widget = RecordTableView(parent) + widget.setWordWrap(False) + return widget + + def clear(self): + self.getWidget().setArrayData(None) + + def setData(self, data): + data = self.normalizeData(data) + widget = self.getWidget() + widget.setArrayData(data) + widget.resizeRowsToContents() + widget.resizeColumnsToContents() + + def axesNames(self, data, info): + return ["data"] + + def getDataPriority(self, data, info): + if info.isRecord: + return 40 + if data is None or not info.isArray: + return DataView.UNSUPPORTED + if info.dim == 1: + if info.interpretation in ["scalar", "scaler"]: + return 1000 + if info.shape[0] == 1: + return 510 + return 500 + elif info.isRecord: + return 40 + return DataView.UNSUPPORTED + + +class _Hdf5View(DataView): + """View displaying data using text""" + + def __init__(self, parent): + super(_Hdf5View, self).__init__( + parent=parent, + modeId=HDF5_MODE, + label="HDF5", + icon=icons.getQIcon("view-hdf5")) + + def createWidget(self, parent): + from .Hdf5TableView import Hdf5TableView + widget = Hdf5TableView(parent) + return widget + + def clear(self): + widget = self.getWidget() + widget.setData(None) + + def setData(self, data): + widget = self.getWidget() + widget.setData(data) + + def axesNames(self, data, info): + return [] + + def getDataPriority(self, data, info): + widget = self.getWidget() + if widget.isSupportedData(data): + return 1 + else: + return DataView.UNSUPPORTED + + +class _RawView(CompositeDataView): + """View displaying data as raw data. + + This implementation use a 2d-array view, or a record array view, or a + raw text output. + """ + + def __init__(self, parent): + super(_RawView, self).__init__( + parent=parent, + modeId=RAW_MODE, + label="Raw", + icon=icons.getQIcon("view-raw")) + self.addView(_ScalarView(parent)) + self.addView(_ArrayView(parent)) + self.addView(_RecordView(parent)) + + +class _NXdataScalarView(DataView): + """DataView using a table view for displaying NXdata scalars: + 0-D signal or n-D signal with *@interpretation=scalar*""" + def __init__(self, parent): + DataView.__init__(self, parent) + + def createWidget(self, parent): + from silx.gui.data.ArrayTableWidget import ArrayTableWidget + widget = ArrayTableWidget(parent) + # widget.displayAxesSelector(False) + return widget + + def axesNames(self, data, info): + return ["col", "row"] + + def clear(self): + self.getWidget().setArrayData(numpy.array([[]]), + labels=True) + + def setData(self, data): + data = self.normalizeData(data) + signal = NXdata(data).signal + self.getWidget().setArrayData(signal, + labels=True) + + def getDataPriority(self, data, info): + data = self.normalizeData(data) + if info.isNXdata: + nxd = NXdata(data) + if nxd.signal_is_0d or nxd.interpretation in ["scalar", "scaler"]: + return 100 + return DataView.UNSUPPORTED + + +class _NXdataCurveView(DataView): + """DataView using a Plot1D for displaying NXdata curves: + 1-D signal or n-D signal with *@interpretation=spectrum*. + + It also handles basic scatter plots: + a 1-D signal with one axis whose values are not monotonically increasing. + """ + def __init__(self, parent): + DataView.__init__(self, parent) + + def createWidget(self, parent): + from silx.gui.data.NXdataWidgets import ArrayCurvePlot + widget = ArrayCurvePlot(parent) + return widget + + def axesNames(self, data, info): + # disabled (used by default axis selector widget in Hdf5Viewer) + return [] + + def clear(self): + self.getWidget().clear() + + def setData(self, data): + data = self.normalizeData(data) + nxd = NXdata(data) + signal_name = data.attrs["signal"] + group_name = data.name + if nxd.axes_names[-1] is not None: + x_errors = nxd.get_axis_errors(nxd.axes_names[-1]) + else: + x_errors = None + + self.getWidget().setCurveData(nxd.signal, nxd.axes[-1], + yerror=nxd.errors, xerror=x_errors, + ylabel=signal_name, xlabel=nxd.axes_names[-1], + title="NXdata group " + group_name) + + def getDataPriority(self, data, info): + data = self.normalizeData(data) + if info.isNXdata: + nxd = NXdata(data) + if nxd.is_x_y_value_scatter or nxd.is_unsupported_scatter: + return DataView.UNSUPPORTED + if nxd.signal_is_1d and \ + not nxd.interpretation in ["scalar", "scaler"]: + return 100 + if nxd.interpretation == "spectrum": + return 100 + return DataView.UNSUPPORTED + + +class _NXdataXYVScatterView(DataView): + """DataView using a Plot1D for displaying NXdata 3D scatters as + a scatter of coloured points (1-D signal with 2 axes)""" + def __init__(self, parent): + DataView.__init__(self, parent) + + def createWidget(self, parent): + from silx.gui.data.NXdataWidgets import ArrayCurvePlot + widget = ArrayCurvePlot(parent) + return widget + + def axesNames(self, data, info): + # disabled (used by default axis selector widget in Hdf5Viewer) + return [] + + def clear(self): + self.getWidget().clear() + + def setData(self, data): + data = self.normalizeData(data) + nxd = NXdata(data) + signal_name = data.attrs["signal"] + # signal_errors = nx.errors # not supported + group_name = data.name + x_axis, y_axis = nxd.axes[-2:] + + x_label, y_label = nxd.axes_names[-2:] + if x_label is not None: + x_errors = nxd.get_axis_errors(x_label) + else: + x_errors = None + + if y_label is not None: + y_errors = nxd.get_axis_errors(y_label) + else: + y_errors = None + + self.getWidget().setCurveData(y_axis, x_axis, values=nxd.signal, + yerror=y_errors, xerror=x_errors, + ylabel=signal_name, xlabel=x_label, + title="NXdata group " + group_name) + + def getDataPriority(self, data, info): + data = self.normalizeData(data) + if info.isNXdata: + if NXdata(data).is_x_y_value_scatter: + return 100 + return DataView.UNSUPPORTED + + +class _NXdataImageView(DataView): + """DataView using a Plot2D for displaying NXdata images: + 2-D signal or n-D signals with *@interpretation=spectrum*.""" + def __init__(self, parent): + DataView.__init__(self, parent) + + def createWidget(self, parent): + from silx.gui.data.NXdataWidgets import ArrayImagePlot + widget = ArrayImagePlot(parent) + return widget + + def axesNames(self, data, info): + return [] + + def clear(self): + self.getWidget().clear() + + def setData(self, data): + data = self.normalizeData(data) + nxd = NXdata(data) + signal_name = data.attrs["signal"] + group_name = data.name + y_axis, x_axis = nxd.axes[-2:] + y_label, x_label = nxd.axes_names[-2:] + + self.getWidget().setImageData( + nxd.signal, x_axis=x_axis, y_axis=y_axis, + signal_name=signal_name, xlabel=x_label, ylabel=y_label, + title="NXdata group %s: %s" % (group_name, signal_name)) + + def getDataPriority(self, data, info): + data = self.normalizeData(data) + if info.isNXdata: + nxd = NXdata(data) + if nxd.signal_is_2d: + if nxd.interpretation not in ["scalar", "spectrum", "scaler"]: + return 100 + if nxd.interpretation == "image": + return 100 + return DataView.UNSUPPORTED + + +class _NXdataStackView(DataView): + def __init__(self, parent): + DataView.__init__(self, parent) + + def createWidget(self, parent): + from silx.gui.data.NXdataWidgets import ArrayStackPlot + widget = ArrayStackPlot(parent) + return widget + + def axesNames(self, data, info): + return [] + + def clear(self): + self.getWidget().clear() + + def setData(self, data): + data = self.normalizeData(data) + nxd = NXdata(data) + signal_name = data.attrs["signal"] + group_name = data.name + z_axis, y_axis, x_axis = nxd.axes[-3:] + z_label, y_label, x_label = nxd.axes_names[-3:] + + self.getWidget().setStackData( + nxd.signal, x_axis=x_axis, y_axis=y_axis, z_axis=z_axis, + signal_name=signal_name, + xlabel=x_label, ylabel=y_label, zlabel=z_label, + title="NXdata group %s: %s" % (group_name, signal_name)) + + def getDataPriority(self, data, info): + data = self.normalizeData(data) + if info.isNXdata: + nxd = NXdata(data) + if nxd.signal_ndim >= 3: + if nxd.interpretation not in ["scalar", "scaler", + "spectrum", "image"]: + return 100 + return DataView.UNSUPPORTED + + +class _NXdataView(CompositeDataView): + """Composite view displaying NXdata groups using the most adequate + widget depending on the dimensionality.""" + def __init__(self, parent): + super(_NXdataView, self).__init__( + parent=parent, + label="NXdata", + icon=icons.getQIcon("view-nexus")) + + self.addView(_NXdataScalarView(parent)) + self.addView(_NXdataCurveView(parent)) + self.addView(_NXdataXYVScatterView(parent)) + self.addView(_NXdataImageView(parent)) + self.addView(_NXdataStackView(parent)) diff --git a/silx/gui/data/Hdf5TableView.py b/silx/gui/data/Hdf5TableView.py new file mode 100644 index 0000000..5d79907 --- /dev/null +++ b/silx/gui/data/Hdf5TableView.py @@ -0,0 +1,414 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +""" +This module define model and widget to display 1D slices from numpy +array using compound data types or hdf5 databases. +""" +from __future__ import division + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "07/04/2017" + +import functools +import os.path +import logging +from silx.gui import qt +import silx.io +from .TextFormatter import TextFormatter +import silx.gui.hdf5 +from silx.gui.widgets import HierarchicalTableView + +_logger = logging.getLogger(__name__) + + +class _CellData(object): + """Store a table item + """ + def __init__(self, value=None, isHeader=False, span=None): + """ + Constructor + + :param str value: Label of this property + :param bool isHeader: True if the cell is an header + :param tuple span: Tuple of row, column span + """ + self.__value = value + self.__isHeader = isHeader + self.__span = span + + def isHeader(self): + """Returns true if the property is a sub-header title. + + :rtype: bool + """ + return self.__isHeader + + def value(self): + """Returns the value of the item. + """ + return self.__value + + def span(self): + """Returns the span size of the cell. + + :rtype: tuple + """ + return self.__span + + +class _TableData(object): + """Modelize a table with header, row and column span. + + It is mostly defined as a row based table. + """ + + def __init__(self, columnCount): + """Constructor. + + :param int columnCount: Define the number of column of the table + """ + self.__colCount = columnCount + self.__data = [] + + def rowCount(self): + """Returns the number of rows. + + :rtype: int + """ + return len(self.__data) + + def columnCount(self): + """Returns the number of columns. + + :rtype: int + """ + return self.__colCount + + def clear(self): + """Remove all the cells of the table""" + self.__data = [] + + def cellAt(self, row, column): + """Returns the cell at the row column location. Else None if there is + nothing. + + :rtype: _CellData + """ + if row < 0: + return None + if column < 0: + return None + if row >= len(self.__data): + return None + cells = self.__data[row] + if column >= len(cells): + return None + return cells[column] + + def addHeaderRow(self, headerLabel): + """Append the table with header on the full row. + + :param str headerLabel: label of the header. + """ + item = _CellData(value=headerLabel, isHeader=True, span=(1, self.__colCount)) + self.__data.append([item]) + + def addHeaderValueRow(self, headerLabel, value): + """Append the table with a row using the first column as an header and + other cells as a single cell for the value. + + :param str headerLabel: label of the header. + :param object value: value to store. + """ + header = _CellData(value=headerLabel, isHeader=True) + value = _CellData(value=value, span=(1, self.__colCount)) + self.__data.append([header, value]) + + def addRow(self, *args): + """Append the table with a row using arguments for each cells + + :param list[object] args: List of cell values for the row + """ + row = [] + for value in args: + if not isinstance(value, _CellData): + value = _CellData(value=value) + row.append(value) + self.__data.append(row) + + +class Hdf5TableModel(HierarchicalTableView.HierarchicalTableModel): + """This data model provides access to HDF5 node content (File, Group, + Dataset). Main info, like name, file, attributes... are displayed + """ + + def __init__(self, parent=None, data=None): + """ + Constructor + + :param qt.QObject parent: Parent object + :param object data: An h5py-like object (file, group or dataset) + """ + super(Hdf5TableModel, self).__init__(parent) + + self.__obj = None + self.__data = _TableData(columnCount=4) + self.__formatter = None + formatter = TextFormatter(self) + self.setFormatter(formatter) + self.setObject(data) + + def rowCount(self, parent_idx=None): + """Returns number of rows to be displayed in table""" + return self.__data.rowCount() + + def columnCount(self, parent_idx=None): + """Returns number of columns to be displayed in table""" + return self.__data.columnCount() + + def data(self, index, role=qt.Qt.DisplayRole): + """QAbstractTableModel method to access data values + in the format ready to be displayed""" + if not index.isValid(): + return None + + cell = self.__data.cellAt(index.row(), index.column()) + if cell is None: + return None + + if role == self.SpanRole: + return cell.span() + elif role == self.IsHeaderRole: + return cell.isHeader() + elif role == qt.Qt.DisplayRole: + value = cell.value() + if callable(value): + value = value(self.__obj) + return str(value) + return None + + def flags(self, index): + """QAbstractTableModel method to inform the view whether data + is editable or not. + """ + return qt.QAbstractTableModel.flags(self, index) + + def isSupportedObject(self, h5pyObject): + """ + Returns true if the provided object can be modelized using this model. + """ + isSupported = False + isSupported = isSupported or silx.io.is_group(h5pyObject) + isSupported = isSupported or silx.io.is_dataset(h5pyObject) + isSupported = isSupported or isinstance(h5pyObject, silx.gui.hdf5.H5Node) + return isSupported + + def setObject(self, h5pyObject): + """Set the h5py-like object exposed by the model + + :param h5pyObject: A h5py-like object. It can be a `h5py.Dataset`, + a `h5py.File`, a `h5py.Group`. It also can be a, + `silx.gui.hdf5.H5Node` which is needed to display some local path + information. + """ + if qt.qVersion() > "4.6": + self.beginResetModel() + + if h5pyObject is None or self.isSupportedObject(h5pyObject): + self.__obj = h5pyObject + else: + _logger.warning("Object class %s unsupported. Object ignored.", type(h5pyObject)) + self.__initProperties() + + if qt.qVersion() > "4.6": + self.endResetModel() + else: + self.reset() + + def __initProperties(self): + """Initialize the list of available properties according to the defined + h5py-like object.""" + self.__data.clear() + if self.__obj is None: + return + + obj = self.__obj + + hdf5obj = obj + if isinstance(obj, silx.gui.hdf5.H5Node): + hdf5obj = obj.h5py_object + + if silx.io.is_file(hdf5obj): + objectType = "File" + elif silx.io.is_group(hdf5obj): + objectType = "Group" + elif silx.io.is_dataset(hdf5obj): + objectType = "Dataset" + else: + objectType = obj.__class__.__name__ + self.__data.addHeaderRow(headerLabel="HDF5 %s" % objectType) + self.__data.addHeaderRow(headerLabel="Path info") + + self.__data.addHeaderValueRow("basename", lambda x: os.path.basename(x.name)) + self.__data.addHeaderValueRow("name", lambda x: x.name) + if silx.io.is_file(obj): + self.__data.addHeaderValueRow("filename", lambda x: x.filename) + + if isinstance(obj, silx.gui.hdf5.H5Node): + # helpful informations if the object come from an HDF5 tree + self.__data.addHeaderValueRow("local_basename", lambda x: x.local_basename) + self.__data.addHeaderValueRow("local_name", lambda x: x.local_name) + self.__data.addHeaderValueRow("local_filename", lambda x: x.local_file.filename) + + if hasattr(obj, "dtype"): + self.__data.addHeaderRow(headerLabel="Data info") + self.__data.addHeaderValueRow("dtype", lambda x: x.dtype) + if hasattr(obj, "shape"): + self.__data.addHeaderValueRow("shape", lambda x: x.shape) + if hasattr(obj, "size"): + self.__data.addHeaderValueRow("size", lambda x: x.size) + if hasattr(obj, "chunks") and obj.chunks is not None: + self.__data.addHeaderValueRow("chunks", lambda x: x.chunks) + + # relative to compression + # h5py expose compression, compression_opts but are not initialized + # for external plugins, then we use id + # h5py also expose fletcher32 and shuffle attributes, but it is also + # part of the filters + if hasattr(obj, "shape") and hasattr(obj, "id"): + dcpl = obj.id.get_create_plist() + if dcpl.get_nfilters() > 0: + self.__data.addHeaderRow(headerLabel="Compression info") + pos = _CellData(value="Position", isHeader=True) + hdf5id = _CellData(value="HDF5 ID", isHeader=True) + name = _CellData(value="Name", isHeader=True) + options = _CellData(value="Options", isHeader=True) + self.__data.addRow(pos, hdf5id, name, options) + for index in range(dcpl.get_nfilters()): + callback = lambda index, dataIndex, x: self.__get_filter_info(x, index)[dataIndex] + pos = _CellData(value=functools.partial(callback, index, 0)) + hdf5id = _CellData(value=functools.partial(callback, index, 1)) + name = _CellData(value=functools.partial(callback, index, 2)) + options = _CellData(value=functools.partial(callback, index, 3)) + self.__data.addRow(pos, hdf5id, name, options) + + if hasattr(obj, "attrs"): + if len(obj.attrs) > 0: + self.__data.addHeaderRow(headerLabel="Attributes") + for key in sorted(obj.attrs.keys()): + callback = lambda key, x: self.__formatter.toString(x.attrs[key]) + self.__data.addHeaderValueRow(headerLabel=key, value=functools.partial(callback, key)) + + def __get_filter_info(self, dataset, filterIndex): + """Get a tuple of readable info from dataset filters + + :param h5py.Dataset dataset: A h5py dataset + :param int filterId: + """ + try: + dcpl = dataset.id.get_create_plist() + info = dcpl.get_filter(filterIndex) + filterId, _flags, cdValues, name = info + name = self.__formatter.toString(name) + options = " ".join([self.__formatter.toString(i) for i in cdValues]) + return (filterIndex, filterId, name, options) + except Exception: + _logger.debug("Backtrace", exc_info=True) + return [filterIndex, None, None, None] + + def object(self): + """Returns the internal object modelized. + + :rtype: An h5py-like object + """ + return self.__obj + + def setFormatter(self, formatter): + """Set the formatter object to be used to display data from the model + + :param TextFormatter formatter: Formatter to use + """ + if formatter is self.__formatter: + return + + if qt.qVersion() > "4.6": + self.beginResetModel() + + if self.__formatter is not None: + self.__formatter.formatChanged.disconnect(self.__formatChanged) + + self.__formatter = formatter + if self.__formatter is not None: + self.__formatter.formatChanged.connect(self.__formatChanged) + + if qt.qVersion() > "4.6": + self.endResetModel() + else: + self.reset() + + def getFormatter(self): + """Returns the text formatter used. + + :rtype: TextFormatter + """ + return self.__formatter + + def __formatChanged(self): + """Called when the format changed. + """ + self.reset() + + +class Hdf5TableView(HierarchicalTableView.HierarchicalTableView): + """A widget to display metadata about a HDF5 node using a table.""" + + def __init__(self, parent=None): + super(Hdf5TableView, self).__init__(parent) + self.setModel(Hdf5TableModel(self)) + + def isSupportedData(self, data): + """ + Returns true if the provided object can be modelized using this model. + """ + return self.model().isSupportedObject(data) + + def setData(self, data): + """Set the h5py-like object exposed by the model + + :param h5pyObject: A h5py-like object. It can be a `h5py.Dataset`, + a `h5py.File`, a `h5py.Group`. It also can be a, + `silx.gui.hdf5.H5Node` which is needed to display some local path + information. + """ + self.model().setObject(data) + header = self.horizontalHeader() + if qt.qVersion() < "5.0": + setResizeMode = header.setResizeMode + else: + setResizeMode = header.setSectionResizeMode + setResizeMode(0, qt.QHeaderView.Fixed) + setResizeMode(1, qt.QHeaderView.Stretch) + header.setStretchLastSection(True) diff --git a/silx/gui/data/NXdataWidgets.py b/silx/gui/data/NXdataWidgets.py new file mode 100644 index 0000000..343c7f9 --- /dev/null +++ b/silx/gui/data/NXdataWidgets.py @@ -0,0 +1,523 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module defines widgets used by _NXdataView. +""" +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "20/03/2017" + +import numpy + +from silx.gui import qt +from silx.gui.data.NumpyAxesSelector import NumpyAxesSelector +from silx.gui.plot import Plot1D, Plot2D, StackView + +from silx.math.calibration import ArrayCalibration, NoCalibration, LinearCalibration + + +class ArrayCurvePlot(qt.QWidget): + """ + Widget for plotting a curve from a multi-dimensional signal array + and a 1D axis array. + + The signal array can have an arbitrary number of dimensions, the only + limitation being that the last dimension must have the same length as + the axis array. + + The widget provides sliders to select indices on the first (n - 1) + dimensions of the signal array, and buttons to add/replace selected + curves to the plot. + + This widget also handles simple 2D or 3D scatter plots (third dimension + displayed as colour of points). + """ + def __init__(self, parent=None): + """ + + :param parent: Parent QWidget + """ + super(ArrayCurvePlot, self).__init__(parent) + + self.__signal = None + self.__signal_name = None + self.__signal_errors = None + self.__axis = None + self.__axis_name = None + self.__axis_errors = None + self.__values = None + + self.__first_curve_added = False + + self._plot = Plot1D(self) + self._plot.setDefaultColormap( # for scatters + {"name": "viridis", + "vmin": 0., "vmax": 1., # ignored (autoscale) but mandatory + "normalization": "linear", + "autoscale": True}) + + self.selectorDock = qt.QDockWidget("Data selector", self._plot) + # not closable + self.selectorDock.setFeatures(qt.QDockWidget.DockWidgetMovable | + qt.QDockWidget.DockWidgetFloatable) + self._selector = NumpyAxesSelector(self.selectorDock) + self._selector.setNamedAxesSelectorVisibility(False) + self.__selector_is_connected = False + self.selectorDock.setWidget(self._selector) + self._plot.addTabbedDockWidget(self.selectorDock) + + layout = qt.QGridLayout() + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self._plot, 0, 0) + + self.setLayout(layout) + + def setCurveData(self, y, x=None, values=None, + yerror=None, xerror=None, + ylabel=None, xlabel=None, title=None): + """ + + :param y: dataset to be represented by the y (vertical) axis. + For a scatter, this must be a 1D array and x and values must be + 1-D arrays of the same size. + In other cases, it can be a n-D array whose last dimension must + have the same length as x (and values must be None) + :param x: 1-D dataset used as the curve's x values. If provided, + its lengths must be equal to the length of the last dimension of + ``y`` (and equal to the length of ``value``, for a scatter plot). + :param values: Values, to be provided for a x-y-value scatter plot. + This will be used to compute the color map and assign colors + to the points. + :param yerror: 1-D dataset of errors for y, or None + :param xerror: 1-D dataset of errors for x, or None + :param ylabel: Label for Y axis + :param xlabel: Label for X axis + :param title: Graph title + """ + self.__signal = y + self.__signal_name = ylabel + self.__signal_errors = yerror + self.__axis = x + self.__axis_name = xlabel + self.__axis_errors = xerror + self.__values = values + + if self.__selector_is_connected: + self._selector.selectionChanged.disconnect(self._updateCurve) + self.__selector_is_connected = False + self._selector.setData(y) + self._selector.setAxisNames([ylabel or "Y"]) + + if len(y.shape) < 2: + self.selectorDock.hide() + else: + self.selectorDock.show() + + self._plot.setGraphTitle(title or "") + self._plot.setGraphXLabel(self.__axis_name or "X") + self._plot.setGraphYLabel(self.__signal_name or "Y") + self._updateCurve() + + if not self.__selector_is_connected: + self._selector.selectionChanged.connect(self._updateCurve) + self.__selector_is_connected = True + + def _updateCurve(self): + y = self._selector.selectedData() + x = self.__axis + if x is None: + x = numpy.arange(len(y)) + elif numpy.isscalar(x) or len(x) == 1: + # constant axis + x = x * numpy.ones_like(y) + elif len(x) == 2 and len(y) != 2: + # linear calibration a + b * x + x = x[0] + x[1] * numpy.arange(len(y)) + legend = self.__signal_name + "[" + for sl in self._selector.selection(): + if sl == slice(None): + legend += ":, " + else: + legend += str(sl) + ", " + legend = legend[:-2] + "]" + if self.__signal_errors is not None: + y_errors = self.__signal_errors[self._selector.selection()] + else: + y_errors = None + + self._plot.remove(kind=("curve", "scatter")) + + # values: x-y-v scatter + if self.__values is not None: + self._plot.addScatter(x, y, self.__values, + legend=legend, + xerror=self.__axis_errors, + yerror=y_errors) + + # x monotonically increasing: curve + elif numpy.all(numpy.diff(x) > 0): + self._plot.addCurve(x, y, legend=legend, + xerror=self.__axis_errors, + yerror=y_errors) + + # scatter + else: + self._plot.addScatter(x, y, value=numpy.ones_like(y), + legend=legend, + xerror=self.__axis_errors, + yerror=y_errors) + self._plot.resetZoom() + self._plot.setGraphXLabel(self.__axis_name) + self._plot.setGraphYLabel(self.__signal_name) + + def clear(self): + self._plot.clear() + + +class ArrayImagePlot(qt.QWidget): + """ + Widget for plotting an image from a multi-dimensional signal array + and two 1D axes array. + + The signal array can have an arbitrary number of dimensions, the only + limitation being that the last two dimensions must have the same length as + the axes arrays. + + Sliders are provided to select indices on the first (n - 2) dimensions of + the signal array, and the plot is updated to show the image corresponding + to the selection. + + If one or both of the axes does not have regularly spaced values, the + the image is plotted as a coloured scatter plot. + """ + def __init__(self, parent=None): + """ + + :param parent: Parent QWidget + """ + super(ArrayImagePlot, self).__init__(parent) + + self.__signal = None + self.__signal_name = None + self.__x_axis = None + self.__x_axis_name = None + self.__y_axis = None + self.__y_axis_name = None + + self._plot = Plot2D(self) + self._plot.setDefaultColormap( + {"name": "viridis", + "vmin": 0., "vmax": 1., # ignored (autoscale) but mandatory + "normalization": "linear", + "autoscale": True}) + + self.selectorDock = qt.QDockWidget("Data selector", self._plot) + # not closable + self.selectorDock.setFeatures(qt.QDockWidget.DockWidgetMovable | + qt.QDockWidget.DockWidgetFloatable) + self._legend = qt.QLabel(self) + self._selector = NumpyAxesSelector(self.selectorDock) + self._selector.setNamedAxesSelectorVisibility(False) + self.__selector_is_connected = False + + layout = qt.QVBoxLayout() + layout.addWidget(self._plot) + layout.addWidget(self._legend) + self.selectorDock.setWidget(self._selector) + self._plot.addTabbedDockWidget(self.selectorDock) + + self.setLayout(layout) + + def setImageData(self, signal, + x_axis=None, y_axis=None, + signal_name=None, + xlabel=None, ylabel=None, + title=None): + """ + + :param signal: n-D dataset, whose last 2 dimensions are used as the + image's values. + :param x_axis: 1-D dataset used as the image's x coordinates. If + provided, its lengths must be equal to the length of the last + dimension of ``signal``. + :param y_axis: 1-D dataset used as the image's y. If provided, + its lengths must be equal to the length of the 2nd to last + dimension of ``signal``. + :param signal_name: Label used in the legend + :param xlabel: Label for X axis + :param ylabel: Label for Y axis + :param title: Graph title + """ + if self.__selector_is_connected: + self._selector.selectionChanged.disconnect(self._updateImage) + self.__selector_is_connected = False + + self.__signal = signal + self.__signal_name = signal_name or "" + self.__x_axis = x_axis + self.__x_axis_name = xlabel + self.__y_axis = y_axis + self.__y_axis_name = ylabel + + self._selector.setData(signal) + self._selector.setAxisNames([ylabel or "Y", xlabel or "X"]) + + if len(signal.shape) < 3: + self.selectorDock.hide() + else: + self.selectorDock.show() + + self._plot.setGraphTitle(title or "") + self._plot.setGraphXLabel(self.__x_axis_name or "X") + self._plot.setGraphYLabel(self.__y_axis_name or "Y") + + self._updateImage() + + if not self.__selector_is_connected: + self._selector.selectionChanged.connect(self._updateImage) + self.__selector_is_connected = True + + def _updateImage(self): + legend = self.__signal_name + "[" + for sl in self._selector.selection(): + if sl == slice(None): + legend += ":, " + else: + legend += str(sl) + ", " + legend = legend[:-2] + "]" + self._legend.setText("Displayed data: " + legend) + + img = self._selector.selectedData() + x_axis = self.__x_axis + y_axis = self.__y_axis + + if x_axis is None and y_axis is None: + xcalib = NoCalibration() + ycalib = NoCalibration() + else: + if x_axis is None: + # no calibration + x_axis = numpy.arange(img.shape[-1]) + elif numpy.isscalar(x_axis) or len(x_axis) == 1: + # constant axis + x_axis = x_axis * numpy.ones((img.shape[-1], )) + elif len(x_axis) == 2: + # linear calibration + x_axis = x_axis[0] * numpy.arange(img.shape[-1]) + x_axis[1] + + if y_axis is None: + y_axis = numpy.arange(img.shape[-2]) + elif numpy.isscalar(y_axis) or len(y_axis) == 1: + y_axis = y_axis * numpy.ones((img.shape[-2], )) + elif len(y_axis) == 2: + y_axis = y_axis[0] * numpy.arange(img.shape[-2]) + y_axis[1] + + xcalib = ArrayCalibration(x_axis) + ycalib = ArrayCalibration(y_axis) + + self._plot.remove(kind=("scatter", "image")) + if xcalib.is_affine() and ycalib.is_affine(): + # regular image + xorigin, xscale = xcalib(0), xcalib.get_slope() + yorigin, yscale = ycalib(0), ycalib.get_slope() + origin = (xorigin, yorigin) + scale = (xscale, yscale) + + self._plot.addImage(img, legend=legend, + origin=origin, scale=scale) + else: + scatterx, scattery = numpy.meshgrid(x_axis, y_axis) + self._plot.addScatter(numpy.ravel(scatterx), + numpy.ravel(scattery), + numpy.ravel(img), + legend=legend) + self._plot.setGraphXLabel(self.__x_axis_name) + self._plot.setGraphYLabel(self.__y_axis_name) + self._plot.resetZoom() + + def clear(self): + self._plot.clear() + + +class ArrayStackPlot(qt.QWidget): + """ + Widget for plotting a n-D array (n >= 3) as a stack of images. + Three axis arrays can be provided to calibrate the axes. + + The signal array can have an arbitrary number of dimensions, the only + limitation being that the last 3 dimensions must have the same length as + the axes arrays. + + Sliders are provided to select indices on the first (n - 3) dimensions of + the signal array, and the plot is updated to load the stack corresponding + to the selection. + """ + def __init__(self, parent=None): + """ + + :param parent: Parent QWidget + """ + super(ArrayStackPlot, self).__init__(parent) + + self.__signal = None + self.__signal_name = None + # the Z, Y, X axes apply to the last three dimensions of the signal + # (in that order) + self.__z_axis = None + self.__z_axis_name = None + self.__y_axis = None + self.__y_axis_name = None + self.__x_axis = None + self.__x_axis_name = None + + self._stack_view = StackView(self) + self._hline = qt.QFrame(self) + self._hline.setFrameStyle(qt.QFrame.HLine) + self._hline.setFrameShadow(qt.QFrame.Sunken) + self._legend = qt.QLabel(self) + self._selector = NumpyAxesSelector(self) + self._selector.setNamedAxesSelectorVisibility(False) + self.__selector_is_connected = False + + layout = qt.QVBoxLayout() + layout.addWidget(self._stack_view) + layout.addWidget(self._hline) + layout.addWidget(self._legend) + layout.addWidget(self._selector) + + self.setLayout(layout) + + def setStackData(self, signal, + x_axis=None, y_axis=None, z_axis=None, + signal_name=None, + xlabel=None, ylabel=None, zlabel=None, + title=None): + """ + + :param signal: n-D dataset, whose last 3 dimensions are used as the + 3D stack values. + :param x_axis: 1-D dataset used as the image's x coordinates. If + provided, its lengths must be equal to the length of the last + dimension of ``signal``. + :param y_axis: 1-D dataset used as the image's y. If provided, + its lengths must be equal to the length of the 2nd to last + dimension of ``signal``. + :param z_axis: 1-D dataset used as the image's z. If provided, + its lengths must be equal to the length of the 3rd to last + dimension of ``signal``. + :param signal_name: Label used in the legend + :param xlabel: Label for X axis + :param ylabel: Label for Y axis + :param zlabel: Label for Z axis + :param title: Graph title + """ + if self.__selector_is_connected: + self._selector.selectionChanged.disconnect(self._updateStack) + self.__selector_is_connected = False + + self.__signal = signal + self.__signal_name = signal_name or "" + self.__x_axis = x_axis + self.__x_axis_name = xlabel + self.__y_axis = y_axis + self.__y_axis_name = ylabel + self.__z_axis = z_axis + self.__z_axis_name = zlabel + + self._selector.setData(signal) + self._selector.setAxisNames([ylabel or "Y", xlabel or "X", zlabel or "Z"]) + + self._stack_view.setGraphTitle(title or "") + # by default, the z axis is the image position (dimension not plotted) + self._stack_view.setGraphXLabel(self.__x_axis_name or "X") + self._stack_view.setGraphYLabel(self.__y_axis_name or "Y") + + self._updateStack() + + ndims = len(signal.shape) + self._stack_view.setFirstStackDimension(ndims - 3) + + # the legend label shows the selection slice producing the volume + # (only interesting for ndim > 3) + if ndims > 3: + self._selector.setVisible(True) + self._legend.setVisible(True) + self._hline.setVisible(True) + else: + self._selector.setVisible(False) + self._legend.setVisible(False) + self._hline.setVisible(False) + + if not self.__selector_is_connected: + self._selector.selectionChanged.connect(self._updateStack) + self.__selector_is_connected = True + + @staticmethod + def _get_origin_scale(axis): + """Assuming axis is a regularly spaced 1D array, + return a tuple (origin, scale) where: + - origin = axis[0] + - scale = (axis[n-1] - axis[0]) / (n -1) + :param axis: 1D numpy array + :return: Tuple (axis[0], (axis[-1] - axis[0]) / (len(axis) - 1)) + """ + return axis[0], (axis[-1] - axis[0]) / (len(axis) - 1) + + def _updateStack(self): + """Update displayed stack according to the current axes selector + data.""" + stk = self._selector.selectedData() + x_axis = self.__x_axis + y_axis = self.__y_axis + z_axis = self.__z_axis + + calibrations = [] + for axis in [z_axis, y_axis, x_axis]: + + if axis is None: + calibrations.append(NoCalibration()) + elif len(axis) == 2: + calibrations.append( + LinearCalibration(y_intercept=axis[0], + slope=axis[1])) + else: + calibrations.append(ArrayCalibration(axis)) + + legend = self.__signal_name + "[" + for sl in self._selector.selection(): + if sl == slice(None): + legend += ":, " + else: + legend += str(sl) + ", " + legend = legend[:-2] + "]" + self._legend.setText("Displayed data: " + legend) + + self._stack_view.setStack(stk, calibrations=calibrations) + self._stack_view.setLabels( + labels=[self.__z_axis_name, + self.__y_axis_name, + self.__x_axis_name]) + + def clear(self): + self._stack_view.clear() diff --git a/silx/gui/data/NumpyAxesSelector.py b/silx/gui/data/NumpyAxesSelector.py new file mode 100644 index 0000000..f4641da --- /dev/null +++ b/silx/gui/data/NumpyAxesSelector.py @@ -0,0 +1,468 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module defines a widget able to convert a numpy array from n-dimensions +to a numpy array with less dimensions. +""" +from __future__ import division + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "16/01/2017" + +import numpy +import functools +from silx.gui.widgets.FrameBrowser import HorizontalSliderWithBrowser +from silx.gui import qt +import silx.utils.weakref + + +class _Axis(qt.QWidget): + """Widget displaying an axis. + + It allows to display and scroll in the axis, and provide a widget to + map the axis with a named axis (the one from the view). + """ + + valueChanged = qt.Signal(int) + """Emitted when the location on the axis change.""" + + axisNameChanged = qt.Signal(object) + """Emitted when the user change the name of the axis.""" + + def __init__(self, parent=None): + """Constructor + + :param parent: Parent of the widget + """ + super(_Axis, self).__init__(parent) + self.__axisNumber = None + self.__customAxisNames = set([]) + self.__label = qt.QLabel(self) + self.__axes = qt.QComboBox(self) + self.__axes.currentIndexChanged[int].connect(self.__axisMappingChanged) + self.__slider = HorizontalSliderWithBrowser(self) + self.__slider.valueChanged[int].connect(self.__sliderValueChanged) + layout = qt.QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self.__label) + layout.addWidget(self.__axes) + layout.addWidget(self.__slider, 10000) + layout.addStretch(1) + self.setLayout(layout) + + def slider(self): + """Returns the slider used to display axes location. + + :rtype: HorizontalSliderWithBrowser + """ + return self.__slider + + def setAxis(self, number, position, size): + """Set axis information. + + :param int number: The number of the axis (from the original numpy + array) + :param int position: The current position in the axis (for a slicing) + :param int size: The size of this axis (0..n) + """ + self.__label.setText("Dimension %s" % number) + self.__axisNumber = number + self.__slider.setMaximum(size - 1) + + def axisNumber(self): + """Returns the axis number. + + :rtype: int + """ + return self.__axisNumber + + def setAxisName(self, axisName): + """Set the current used axis name. + + If this name is not available an exception is raised. An empty string + means that no name is selected. + + :param str axisName: The new name of the axis + :raise ValueError: When the name is not available + """ + if axisName == "" and self.__axes.count() == 0: + self.__axes.setCurrentIndex(-1) + self.__updateSliderVisibility() + for index in range(self.__axes.count()): + name = self.__axes.itemData(index) + if name == axisName: + self.__axes.setCurrentIndex(index) + self.__updateSliderVisibility() + return + raise ValueError("Axis name '%s' not found", axisName) + + def axisName(self): + """Returns the selected axis name. + + If no names are selected, an empty string is retruned. + + :rtype: str + """ + index = self.__axes.currentIndex() + if index == -1: + return "" + return self.__axes.itemData(index) + + def setAxisNames(self, axesNames): + """Set the available list of names for the axis. + + :param list[str] axesNames: List of available names + """ + self.__axes.clear() + previous = self.__axes.blockSignals(True) + self.__axes.addItem(" ", "") + for axis in axesNames: + self.__axes.addItem(axis, axis) + self.__axes.blockSignals(previous) + self.__updateSliderVisibility() + + def setCustomAxis(self, axesNames): + """Set the available list of named axis which can be set to a value. + + :param list[str] axesNames: List of customable axis names + """ + self.__customAxisNames = set(axesNames) + self.__updateSliderVisibility() + + def __axisMappingChanged(self, index): + """Called when the selected name change. + + :param int index: Selected index + """ + self.__updateSliderVisibility() + name = self.axisName() + self.axisNameChanged.emit(name) + + def __updateSliderVisibility(self): + """Update the visibility of the slider according to axis names and + customable axis names.""" + name = self.axisName() + isVisible = name == "" or name in self.__customAxisNames + self.__slider.setVisible(isVisible) + + def value(self): + """Returns the current selected position in the axis. + + :rtype: int + """ + return self.__slider.value() + + def __sliderValueChanged(self, value): + """Called when the selected position in the axis change. + + :param int value: Position of the axis + """ + self.valueChanged.emit(value) + + def setNamedAxisSelectorVisibility(self, visible): + """Hide or show the named axis combobox. + If both the selector and the slider are hidden, + hide the entire widget. + + :param visible: boolean + """ + self.__axes.setVisible(visible) + name = self.axisName() + + if not visible and name != "": + self.setVisible(False) + else: + self.setVisible(True) + + +class NumpyAxesSelector(qt.QWidget): + """Widget to select a view from a numpy array. + + .. image:: img/NumpyAxesSelector.png + + The widget is set with an input data using :meth:`setData`, and a requested + output dimension using :meth:`setAxisNames`. + + Widgets are provided to selected expected input axis, and a slice on the + non-selected axis. + + The final selected array can be reached using the getter + :meth:`selectedData`, and the event `selectionChanged`. + + If the input data is a HDF5 Dataset, the selected output data will be a + new numpy array. + """ + + dataChanged = qt.Signal() + """Emitted when the input data change""" + + selectedAxisChanged = qt.Signal() + """Emitted when the selected axis change""" + + selectionChanged = qt.Signal() + """Emitted when the selected data change""" + + customAxisChanged = qt.Signal(str, int) + """Emitted when a custom axis change""" + + def __init__(self, parent=None): + """Constructor + + :param parent: Parent of the widget + """ + super(NumpyAxesSelector, self).__init__(parent) + + self.__data = None + self.__selectedData = None + self.__selection = tuple() + self.__axis = [] + self.__axisNames = [] + self.__customAxisNames = set([]) + self.__namedAxesVisibility = True + layout = qt.QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSizeConstraint(qt.QLayout.SetMinAndMaxSize) + self.setLayout(layout) + + def clear(self): + """Clear the widget.""" + self.setData(None) + + def setAxisNames(self, axesNames): + """Set the axis names of the output selected data. + + Axis names are defined from slower to faster axis. + + The size of the list will constrain the dimension of the resulting + array. + + :param list[str] axesNames: List of string identifying axis names + """ + self.__axisNames = list(axesNames) + delta = len(self.__axis) - len(self.__axisNames) + if delta < 0: + delta = 0 + for index, axis in enumerate(self.__axis): + previous = axis.blockSignals(True) + axis.setAxisNames(self.__axisNames) + if index >= delta and index - delta < len(self.__axisNames): + axis.setAxisName(self.__axisNames[index - delta]) + else: + axis.setAxisName("") + axis.blockSignals(previous) + self.__updateSelectedData() + + def setCustomAxis(self, axesNames): + """Set the available list of named axis which can be set to a value. + + :param list[str] axesNames: List of customable axis names + """ + self.__customAxisNames = set(axesNames) + for axis in self.__axis: + axis.setCustomAxis(self.__customAxisNames) + + def setData(self, data): + """Set the input data unsed by the widget. + + :param numpy.ndarray data: The input data + """ + if self.__data is not None: + # clean up + for widget in self.__axis: + self.layout().removeWidget(widget) + widget.deleteLater() + self.__axis = [] + + self.__data = data + + if data is not None: + # create expected axes + dimensionNumber = len(data.shape) + delta = dimensionNumber - len(self.__axisNames) + for index in range(dimensionNumber): + axis = _Axis(self) + axis.setAxis(index, 0, data.shape[index]) + axis.setAxisNames(self.__axisNames) + axis.setCustomAxis(self.__customAxisNames) + if index >= delta and index - delta < len(self.__axisNames): + axis.setAxisName(self.__axisNames[index - delta]) + # this weak method was expected to be able to delete sub widget + callback = functools.partial(silx.utils.weakref.WeakMethodProxy(self.__axisValueChanged), axis) + axis.valueChanged.connect(callback) + # this weak method was expected to be able to delete sub widget + callback = functools.partial(silx.utils.weakref.WeakMethodProxy(self.__axisNameChanged), axis) + axis.axisNameChanged.connect(callback) + axis.setNamedAxisSelectorVisibility(self.__namedAxesVisibility) + self.layout().addWidget(axis) + self.__axis.append(axis) + self.__normalizeAxisGeometry() + + self.dataChanged.emit() + self.__updateSelectedData() + + def __normalizeAxisGeometry(self): + """Update axes geometry to align all axes components together.""" + if len(self.__axis) <= 0: + return + lineEditWidth = max([a.slider().lineEdit().minimumSize().width() for a in self.__axis]) + limitWidth = max([a.slider().limitWidget().minimumSizeHint().width() for a in self.__axis]) + for a in self.__axis: + a.slider().lineEdit().setFixedWidth(lineEditWidth) + a.slider().limitWidget().setFixedWidth(limitWidth) + + def __axisValueChanged(self, axis, value): + name = axis.axisName() + if name in self.__customAxisNames: + self.customAxisChanged.emit(name, value) + else: + self.__updateSelectedData() + + def __axisNameChanged(self, axis, name): + """Called when an axis name change. + + :param _Axis axis: The changed axis + :param str name: The new name of the axis + """ + names = [x.axisName() for x in self.__axis] + missingName = set(self.__axisNames) - set(names) - set("") + if len(missingName) == 0: + missingName = None + elif len(missingName) == 1: + missingName = list(missingName)[0] + else: + raise Exception("Unexpected state") + + axisChanged = True + + if axis.axisName() == "": + # set the removed label to another widget if it is possible + availableWidget = None + for widget in self.__axis: + if widget is axis: + continue + if widget.axisName() == "": + availableWidget = widget + break + if availableWidget is None: + # If there is no other solution we set the name at the same place + axisChanged = False + availableWidget = axis + previous = availableWidget.blockSignals(True) + availableWidget.setAxisName(missingName) + availableWidget.blockSignals(previous) + else: + # there is a duplicated name somewhere + # we swap it with the missing name or with nothing + dupWidget = None + for widget in self.__axis: + if widget is axis: + continue + if widget.axisName() == axis.axisName(): + dupWidget = widget + break + if missingName is None: + missingName = "" + previous = dupWidget.blockSignals(True) + dupWidget.setAxisName(missingName) + dupWidget.blockSignals(previous) + + if self.__data is None: + return + if axisChanged: + self.selectedAxisChanged.emit() + self.__updateSelectedData() + + def __updateSelectedData(self): + """Update the selected data according to the state of the widget. + + It fires a `selectionChanged` event. + """ + if self.__data is None: + if self.__selectedData is not None: + self.__selectedData = None + self.__selection = tuple() + self.selectionChanged.emit() + return + + selection = [] + axisNames = [] + for slider in self.__axis: + name = slider.axisName() + if name == "": + selection.append(slider.value()) + else: + selection.append(slice(None)) + axisNames.append(name) + + self.__selection = tuple(selection) + # get a view with few fixed dimensions + # with a h5py dataset, it create a copy + # TODO we can reuse the same memory in case of a copy + view = self.__data[self.__selection] + + # order axis as expected + source = [] + destination = [] + order = [] + for index, name in enumerate(self.__axisNames): + destination.append(index) + source.append(axisNames.index(name)) + for _, s in sorted(zip(destination, source)): + order.append(s) + view = numpy.transpose(view, order) + + self.__selectedData = view + self.selectionChanged.emit() + + def data(self): + """Returns the input data. + + :rtype: numpy.ndarray + """ + return self.__data + + def selectedData(self): + """Returns the output data. + + :rtype: numpy.ndarray + """ + return self.__selectedData + + def selection(self): + """Returns the selection tuple used to slice the data. + + :rtype: tuple + """ + return self.__selection + + def setNamedAxesSelectorVisibility(self, visible): + """Show or hide the combo-boxes allowing to map the plot axes + to the data dimension. + + :param visible: Boolean + """ + self.__namedAxesVisibility = visible + for axis in self.__axis: + axis.setNamedAxisSelectorVisibility(visible) diff --git a/silx/gui/data/RecordTableView.py b/silx/gui/data/RecordTableView.py new file mode 100644 index 0000000..ce6a178 --- /dev/null +++ b/silx/gui/data/RecordTableView.py @@ -0,0 +1,405 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +""" +This module define model and widget to display 1D slices from numpy +array using compound data types or hdf5 databases. +""" +from __future__ import division + +import itertools +import numpy +from silx.gui import qt +import silx.io +from .TextFormatter import TextFormatter +from silx.gui.widgets.TableWidget import CopySelectedCellsAction + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "27/01/2017" + + +class _MultiLineItem(qt.QItemDelegate): + """Draw a multiline text without hiding anything. + + The paint method display a cell without any wrap. And an editor is + available to scroll into the selected cell. + """ + + def __init__(self, parent=None): + """ + Constructor + + :param qt.QWidget parent: Parent of the widget + """ + qt.QItemDelegate.__init__(self, parent) + self.__textOptions = qt.QTextOption() + self.__textOptions.setFlags(qt.QTextOption.IncludeTrailingSpaces | + qt.QTextOption.ShowTabsAndSpaces) + self.__textOptions.setWrapMode(qt.QTextOption.NoWrap) + self.__textOptions.setAlignment(qt.Qt.AlignTop | qt.Qt.AlignLeft) + + def paint(self, painter, option, index): + """ + Write multiline text without using any wrap or any alignment according + to the cell size. + + :param qt.QPainter painter: Painter context used to displayed the cell + :param qt.QStyleOptionViewItem option: Control how the editor is shown + :param qt.QIndex index: Index of the data to display + """ + painter.save() + + # set colors + painter.setPen(qt.QPen(qt.Qt.NoPen)) + if option.state & qt.QStyle.State_Selected: + brush = option.palette.highlight() + painter.setBrush(brush) + else: + brush = index.data(qt.Qt.BackgroundRole) + if brush is None: + # default background color for a cell + brush = qt.Qt.white + painter.setBrush(brush) + painter.drawRect(option.rect) + + if index.isValid(): + if option.state & qt.QStyle.State_Selected: + brush = option.palette.highlightedText() + else: + brush = index.data(qt.Qt.ForegroundRole) + if brush is None: + brush = option.palette.text() + painter.setPen(qt.QPen(brush.color())) + text = index.data(qt.Qt.DisplayRole) + painter.drawText(qt.QRectF(option.rect), text, self.__textOptions) + + painter.restore() + + def createEditor(self, parent, option, index): + """ + Returns the widget used to edit the item specified by index for editing. + + We use it not to edit the content but to show the content with a + convenient scroll bar. + + :param qt.QWidget parent: Parent of the widget + :param qt.QStyleOptionViewItem option: Control how the editor is shown + :param qt.QIndex index: Index of the data to display + """ + if not index.isValid(): + return super(_MultiLineItem, self).createEditor(parent, option, index) + + editor = qt.QTextEdit(parent) + editor.setReadOnly(True) + return editor + + def setEditorData(self, editor, index): + """ + Read data from the model and feed the editor. + + :param qt.QWidget editor: Editor widget + :param qt.QIndex index: Index of the data to display + """ + text = index.model().data(index, qt.Qt.EditRole) + editor.setText(text) + + def updateEditorGeometry(self, editor, option, index): + """ + Update the geometry of the editor according to the changes of the view. + + :param qt.QWidget editor: Editor widget + :param qt.QStyleOptionViewItem option: Control how the editor is shown + :param qt.QIndex index: Index of the data to display + """ + editor.setGeometry(option.rect) + + +class RecordTableModel(qt.QAbstractTableModel): + """This data model provides access to 1D slices from numpy array using + compound data types or hdf5 databases. + + Each entries are displayed in a single row, and each columns contain a + specific field of the compound type. + + It also allows to display 1D arrays of simple data types. + array. + + :param qt.QObject parent: Parent object + :param numpy.ndarray data: A numpy array or a h5py dataset + """ + def __init__(self, parent=None, data=None): + qt.QAbstractTableModel.__init__(self, parent) + + self.__data = None + self.__is_array = False + self.__fields = None + self.__formatter = None + self.__editFormatter = None + self.setFormatter(TextFormatter(self)) + + # set _data + self.setArrayData(data) + + # Methods to be implemented to subclass QAbstractTableModel + def rowCount(self, parent_idx=None): + """Returns number of rows to be displayed in table""" + if self.__data is None: + return 0 + elif not self.__is_array: + return 1 + else: + return len(self.__data) + + def columnCount(self, parent_idx=None): + """Returns number of columns to be displayed in table""" + if self.__fields is None: + return 1 + else: + return len(self.__fields) + + def data(self, index, role=qt.Qt.DisplayRole): + """QAbstractTableModel method to access data values + in the format ready to be displayed""" + if not index.isValid(): + return None + + if self.__data is None: + return None + + if self.__is_array: + if index.row() >= len(self.__data): + return None + data = self.__data[index.row()] + else: + if index.row() > 0: + return None + data = self.__data + + if self.__fields is not None: + if index.column() >= len(self.__fields): + return None + key = self.__fields[index.column()][1] + data = data[key[0]] + if len(key) > 1: + data = data[key[1]] + + if role == qt.Qt.DisplayRole: + return self.__formatter.toString(data) + elif role == qt.Qt.EditRole: + return self.__editFormatter.toString(data) + return None + + def headerData(self, section, orientation, role=qt.Qt.DisplayRole): + """Returns the 0-based row or column index, for display in the + horizontal and vertical headers""" + if section == -1: + # PyQt4 send -1 when there is columns but no rows + return None + + if role == qt.Qt.DisplayRole: + if orientation == qt.Qt.Vertical: + if not self.__is_array: + return "Scalar" + else: + return str(section) + if orientation == qt.Qt.Horizontal: + if self.__fields is None: + if section == 0: + return "Data" + else: + return None + else: + if section < len(self.__fields): + return self.__fields[section][0] + else: + return None + return None + + def flags(self, index): + """QAbstractTableModel method to inform the view whether data + is editable or not. + """ + return qt.QAbstractTableModel.flags(self, index) + + def setArrayData(self, data): + """Set the data array and the viewing perspective. + + You can set ``copy=False`` if you need more performances, when dealing + with a large numpy array. In this case, a simple reference to the data + is used to access the data, rather than a copy of the array. + + .. warning:: + + Any change to the data model will affect your original data + array, when using a reference rather than a copy.. + + :param data: 1D numpy array, or any object that can be + converted to a numpy array using ``numpy.array(data)`` (e.g. + a nested sequence). + """ + if qt.qVersion() > "4.6": + self.beginResetModel() + + self.__data = data + if isinstance(data, numpy.ndarray): + self.__is_array = True + elif silx.io.is_dataset(data) and data.shape != tuple(): + self.__is_array = True + else: + self.__is_array = False + + + self.__fields = [] + if data is not None: + if data.dtype.fields is not None: + for name, (dtype, _index) in data.dtype.fields.items(): + if dtype.shape != tuple(): + keys = itertools.product(*[range(x) for x in dtype.shape]) + for key in keys: + label = "%s%s" % (name, list(key)) + array_key = (name, key) + self.__fields.append((label, array_key)) + else: + self.__fields.append((name, (name,))) + else: + self.__fields = None + + if qt.qVersion() > "4.6": + self.endResetModel() + else: + self.reset() + + def arrayData(self): + """Returns the internal data. + + :rtype: numpy.ndarray of h5py.Dataset + """ + return self.__data + + def setFormatter(self, formatter): + """Set the formatter object to be used to display data from the model + + :param TextFormatter formatter: Formatter to use + """ + if formatter is self.__formatter: + return + + if qt.qVersion() > "4.6": + self.beginResetModel() + + if self.__formatter is not None: + self.__formatter.formatChanged.disconnect(self.__formatChanged) + + self.__formatter = formatter + self.__editFormatter = TextFormatter(formatter) + self.__editFormatter.setUseQuoteForText(False) + + if self.__formatter is not None: + self.__formatter.formatChanged.connect(self.__formatChanged) + + if qt.qVersion() > "4.6": + self.endResetModel() + else: + self.reset() + + def getFormatter(self): + """Returns the text formatter used. + + :rtype: TextFormatter + """ + return self.__formatter + + def __formatChanged(self): + """Called when the format changed. + """ + self.__editFormatter = TextFormatter(self, self.getFormatter()) + self.__editFormatter.setUseQuoteForText(False) + self.reset() + + +class _ShowEditorProxyModel(qt.QIdentityProxyModel): + """ + Allow to custom the flag edit of the model + """ + + def __init__(self, parent=None): + """ + Constructor + + :param qt.QObject arent: parent object + """ + super(_ShowEditorProxyModel, self).__init__(parent) + self.__forceEditable = False + + def flags(self, index): + flag = qt.QIdentityProxyModel.flags(self, index) + if self.__forceEditable: + flag = flag | qt.Qt.ItemIsEditable + return flag + + def forceCellEditor(self, show): + """ + Enable the editable flag to allow to display cell editor. + """ + if self.__forceEditable == show: + return + self.beginResetModel() + self.__forceEditable = show + self.endResetModel() + + +class RecordTableView(qt.QTableView): + """TableView using DatabaseTableModel as default model. + """ + def __init__(self, parent=None): + """ + Constructor + + :param qt.QWidget parent: parent QWidget + """ + qt.QTableView.__init__(self, parent) + + model = _ShowEditorProxyModel(self) + model.setSourceModel(RecordTableModel()) + self.setModel(model) + self.__multilineView = _MultiLineItem(self) + self.setEditTriggers(qt.QAbstractItemView.AllEditTriggers) + self._copyAction = CopySelectedCellsAction(self) + self.addAction(self._copyAction) + + def copy(self): + self._copyAction.trigger() + + def setArrayData(self, data): + self.model().sourceModel().setArrayData(data) + if data is not None: + if issubclass(data.dtype.type, (numpy.string_, numpy.unicode_)): + # TODO it would be nice to also fix fields + # but using it only for string array is already very useful + self.setItemDelegateForColumn(0, self.__multilineView) + self.model().forceCellEditor(True) + else: + self.setItemDelegateForColumn(0, None) + self.model().forceCellEditor(False) diff --git a/silx/gui/data/TextFormatter.py b/silx/gui/data/TextFormatter.py new file mode 100644 index 0000000..f074de5 --- /dev/null +++ b/silx/gui/data/TextFormatter.py @@ -0,0 +1,222 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This package provides a class sharred by widget from the +data module to format data as text in the same way.""" + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "26/04/2017" + +import numpy +import numbers +import binascii +from silx.third_party import six +from silx.gui import qt + + +class TextFormatter(qt.QObject): + """Formatter to convert data to string. + + The method :meth:`toString` returns a formatted string from an input data + using parameters set to this object. + + It support most python and numpy data, expecting dictionary. Unsupported + data are displayed using the string representation of the object (`str`). + + It provides a set of parameters to custom the formatting of integer and + float values (:meth:`setIntegerFormat`, :meth:`setFloatFormat`). + + It also allows to custom the use of quotes to display text data + (:meth:`setUseQuoteForText`), and custom unit used to display imaginary + numbers (:meth:`setImaginaryUnit`). + + The object emit an event `formatChanged` every time a parametter is + changed. + """ + + formatChanged = qt.Signal() + """Emitted when properties of the formatter change.""" + + def __init__(self, parent=None, formatter=None): + """ + Constructor + + :param qt.QObject parent: Owner of the object + :param TextFormatter formatter: Instantiate this object from the + formatter + """ + qt.QObject.__init__(self, parent) + if formatter is not None: + self.__integerFormat = formatter.integerFormat() + self.__floatFormat = formatter.floatFormat() + self.__useQuoteForText = formatter.useQuoteForText() + self.__imaginaryUnit = formatter.imaginaryUnit() + else: + self.__integerFormat = "%d" + self.__floatFormat = "%g" + self.__useQuoteForText = True + self.__imaginaryUnit = u"j" + + def integerFormat(self): + """Returns the format string controlling how the integer data + are formated by this object. + + This is the C-style format string used by python when formatting + strings with the modulus operator. + + :rtype: str + """ + return self.__integerFormat + + def setIntegerFormat(self, value): + """Set format string controlling how the integer data are + formated by this object. + + :param str value: Format string (e.g. "%d", "%i", "%08i"). + This is the C-style format string used by python when formatting + strings with the modulus operator. + """ + if self.__integerFormat == value: + return + self.__integerFormat = value + self.formatChanged.emit() + + def floatFormat(self): + """Returns the format string controlling how the floating-point data + are formated by this object. + + This is the C-style format string used by python when formatting + strings with the modulus operator. + + :rtype: str + """ + return self.__floatFormat + + def setFloatFormat(self, value): + """Set format string controlling how the floating-point data are + formated by this object. + + :param str value: Format string (e.g. "%.3f", "%d", "%-10.2f", + "%10.3e"). + This is the C-style format string used by python when formatting + strings with the modulus operator. + """ + if self.__floatFormat == value: + return + self.__floatFormat = value + self.formatChanged.emit() + + def useQuoteForText(self): + """Returns true if the string data are formatted using double quotes. + + Else, no quotes are used. + """ + return self.__integerFormat + + def setUseQuoteForText(self, useQuote): + """Set the use of quotes to delimit string data. + + :param bool useQuote: True to use quotes. + """ + if self.__useQuoteForText == useQuote: + return + self.__useQuoteForText = useQuote + self.formatChanged.emit() + + def imaginaryUnit(self): + """Returns the unit display for imaginary numbers. + + :rtype: str + """ + return self.__imaginaryUnit + + def setImaginaryUnit(self, imaginaryUnit): + """Set the unit display for imaginary numbers. + + :param str imaginaryUnit: Unit displayed after imaginary numbers + """ + if self.__imaginaryUnit == imaginaryUnit: + return + self.__imaginaryUnit = imaginaryUnit + self.formatChanged.emit() + + def toString(self, data): + """Format a data into a string using formatter options + + :param object data: Data to render + :rtype: str + """ + if isinstance(data, tuple): + text = [self.toString(d) for d in data] + return "(" + " ".join(text) + ")" + elif isinstance(data, (list, numpy.ndarray)): + text = [self.toString(d) for d in data] + return "[" + " ".join(text) + "]" + elif isinstance(data, numpy.void): + dtype = data.dtype + if data.dtype.fields is not None: + text = [self.toString(data[f]) for f in dtype.fields] + return "(" + " ".join(text) + ")" + return "0x" + binascii.hexlify(data).decode("ascii") + elif isinstance(data, (numpy.string_, numpy.object_, bytes)): + # This have to be done before checking python string inheritance + try: + text = "%s" % data.decode("utf-8") + if self.__useQuoteForText: + text = "\"%s\"" % text.replace("\"", "\\\"") + return text + except UnicodeDecodeError: + pass + return "0x" + binascii.hexlify(data).decode("ascii") + elif isinstance(data, six.string_types): + text = "%s" % data + if self.__useQuoteForText: + text = "\"%s\"" % text.replace("\"", "\\\"") + return text + elif isinstance(data, (numpy.integer, numbers.Integral)): + return self.__integerFormat % data + elif isinstance(data, (numbers.Real, numpy.floating)): + # It have to be done before complex checking + return self.__floatFormat % data + elif isinstance(data, (numpy.complex_, numbers.Complex)): + text = "" + if data.real != 0: + text += self.__floatFormat % data.real + if data.real != 0 and data.imag != 0: + if data.imag < 0: + template = self.__floatFormat + " - " + self.__floatFormat + self.__imaginaryUnit + params = (data.real, -data.imag) + else: + template = self.__floatFormat + " + " + self.__floatFormat + self.__imaginaryUnit + params = (data.real, data.imag) + else: + if data.imag != 0: + template = self.__floatFormat + self.__imaginaryUnit + params = (data.imag) + else: + template = self.__floatFormat + params = (data.real) + return template % params + return str(data) diff --git a/silx/gui/data/__init__.py b/silx/gui/data/__init__.py new file mode 100644 index 0000000..560062d --- /dev/null +++ b/silx/gui/data/__init__.py @@ -0,0 +1,35 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This package provides a set of Qt widgets for displaying data arrays using +table views and plot widgets. + +.. note:: + + Widgets in this package may rely on additional dependencies that are + not mandatory for *silx*. + :class:`DataViewer.DataViewer` relies on :mod:`silx.gui.plot` which + depends on *matplotlib*. It also optionally depends on *PyOpenGL* for 3D + visualization. +""" diff --git a/silx/gui/data/setup.py b/silx/gui/data/setup.py new file mode 100644 index 0000000..23ccbdd --- /dev/null +++ b/silx/gui/data/setup.py @@ -0,0 +1,41 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "16/01/2017" + + +from numpy.distutils.misc_util import Configuration + + +def configuration(parent_package='', top_path=None): + config = Configuration('data', parent_package, top_path) + config.add_subpackage('test') + return config + + +if __name__ == "__main__": + from numpy.distutils.core import setup + setup(configuration=configuration) diff --git a/silx/gui/data/test/__init__.py b/silx/gui/data/test/__init__.py new file mode 100644 index 0000000..08c044b --- /dev/null +++ b/silx/gui/data/test/__init__.py @@ -0,0 +1,45 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +import unittest + +from . import test_arraywidget +from . import test_numpyaxesselector +from . import test_dataviewer +from . import test_textformatter + +__authors__ = ["V. Valls", "P. Knobel"] +__license__ = "MIT" +__date__ = "24/01/2017" + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTests( + [test_arraywidget.suite(), + test_numpyaxesselector.suite(), + test_dataviewer.suite(), + test_textformatter.suite(), + ]) + return test_suite diff --git a/silx/gui/data/test/test_arraywidget.py b/silx/gui/data/test/test_arraywidget.py new file mode 100644 index 0000000..bbd7ee5 --- /dev/null +++ b/silx/gui/data/test/test_arraywidget.py @@ -0,0 +1,320 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "05/12/2016" + +import os +import tempfile +import unittest + +import numpy + +from silx.gui import qt +from silx.gui.data import ArrayTableWidget +from silx.gui.test.utils import TestCaseQt + +try: + import h5py +except ImportError: + h5py = None + + +class TestArrayWidget(TestCaseQt): + """Basic test for ArrayTableWidget with a numpy array""" + def setUp(self): + super(TestArrayWidget, self).setUp() + self.aw = ArrayTableWidget.ArrayTableWidget() + + def tearDown(self): + del self.aw + super(TestArrayWidget, self).tearDown() + + def testShow(self): + """test for errors""" + self.aw.show() + self.qWaitForWindowExposed(self.aw) + + def testSetData0D(self): + a = 1 + self.aw.setArrayData(a) + b = self.aw.getData(copy=True) + + self.assertTrue(numpy.array_equal(a, b)) + + # scalar/0D data has no frame index + self.assertEqual(len(self.aw.model._index), 0) + # and no perspective + self.assertEqual(len(self.aw.model._perspective), 0) + + def testSetData1D(self): + a = [1, 2] + self.aw.setArrayData(a) + b = self.aw.getData(copy=True) + + self.assertTrue(numpy.array_equal(a, b)) + + # 1D data has no frame index + self.assertEqual(len(self.aw.model._index), 0) + # and no perspective + self.assertEqual(len(self.aw.model._perspective), 0) + + def testSetData4D(self): + a = numpy.reshape(numpy.linspace(0.213, 1.234, 1250), + (5, 5, 5, 10)) + self.aw.setArrayData(a) + + # default perspective (0, 1) + self.assertEqual(list(self.aw.model._perspective), + [0, 1]) + self.aw.setPerspective((1, 3)) + self.assertEqual(list(self.aw.model._perspective), + [1, 3]) + + b = self.aw.getData(copy=True) + self.assertTrue(numpy.array_equal(a, b)) + + # 4D data has a 2-tuple as frame index + self.assertEqual(len(self.aw.model._index), 2) + # default index is (0, 0) + self.assertEqual(list(self.aw.model._index), + [0, 0]) + self.aw.setFrameIndex((3, 1)) + + self.assertEqual(list(self.aw.model._index), + [3, 1]) + + def testColors(self): + a = numpy.arange(256, dtype=numpy.uint8) + self.aw.setArrayData(a) + + bgcolor = numpy.empty(a.shape + (3,), dtype=numpy.uint8) + # Black & white palette + bgcolor[..., 0] = a + bgcolor[..., 1] = a + bgcolor[..., 2] = a + + fgcolor = numpy.bitwise_xor(bgcolor, 255) + + self.aw.setArrayColors(bgcolor, fgcolor) + + # test colors are as expected in model + for i in range(256): + # all RGB channels for BG equal to data value + self.assertEqual( + self.aw.model.data(self.aw.model.index(0, i), + role=qt.Qt.BackgroundRole), + qt.QColor(i, i, i), + "Unexpected background color" + ) + + # all RGB channels for FG equal to XOR(data value, 255) + self.assertEqual( + self.aw.model.data(self.aw.model.index(0, i), + role=qt.Qt.ForegroundRole), + qt.QColor(i ^ 255, i ^ 255, i ^ 255), + "Unexpected text color" + ) + + # test colors are reset to None when a new data array is loaded + # with different shape + self.aw.setArrayData(numpy.arange(300)) + + for i in range(300): + # all RGB channels for BG equal to data value + self.assertIsNone( + self.aw.model.data(self.aw.model.index(0, i), + role=qt.Qt.BackgroundRole)) + + def testDefaultFlagNotEditable(self): + """editable should be False by default, in setArrayData""" + self.aw.setArrayData([[0]]) + idx = self.aw.model.createIndex(0, 0) + # model is editable + self.assertFalse( + self.aw.model.flags(idx) & qt.Qt.ItemIsEditable) + + def testFlagEditable(self): + self.aw.setArrayData([[0]], editable=True) + idx = self.aw.model.createIndex(0, 0) + # model is editable + self.assertTrue( + self.aw.model.flags(idx) & qt.Qt.ItemIsEditable) + + def testFlagNotEditable(self): + self.aw.setArrayData([[0]], editable=False) + idx = self.aw.model.createIndex(0, 0) + # model is editable + self.assertFalse( + self.aw.model.flags(idx) & qt.Qt.ItemIsEditable) + + def testReferenceReturned(self): + """when setting the data with copy=False and + retrieving it with getData(copy=False), we should recover + the same original object. + """ + # n-D (n >=2) + a0 = numpy.reshape(numpy.linspace(0.213, 1.234, 1000), + (10, 10, 10)) + self.aw.setArrayData(a0, copy=False) + a1 = self.aw.getData(copy=False) + + self.assertIs(a0, a1) + + # 1D + b0 = numpy.linspace(0.213, 1.234, 1000) + self.aw.setArrayData(b0, copy=False) + b1 = self.aw.getData(copy=False) + self.assertIs(b0, b1) + + +@unittest.skipIf(h5py is None, "Could not import h5py") +class TestH5pyArrayWidget(TestCaseQt): + """Basic test for ArrayTableWidget with a dataset. + + Test flags, for dataset open in read-only or read-write modes""" + def setUp(self): + super(TestH5pyArrayWidget, self).setUp() + self.aw = ArrayTableWidget.ArrayTableWidget() + self.data = numpy.reshape(numpy.linspace(0.213, 1.234, 1000), + (10, 10, 10)) + # create an h5py file with a dataset + self.tempdir = tempfile.mkdtemp() + self.h5_fname = os.path.join(self.tempdir, "array.h5") + h5f = h5py.File(self.h5_fname) + h5f["my_array"] = self.data + h5f["my_scalar"] = 3.14 + h5f["my_1D_array"] = numpy.array(numpy.arange(1000)) + h5f.close() + + def tearDown(self): + del self.aw + os.unlink(self.h5_fname) + os.rmdir(self.tempdir) + super(TestH5pyArrayWidget, self).tearDown() + + def testShow(self): + self.aw.show() + self.qWaitForWindowExposed(self.aw) + + def testReadOnly(self): + """Open H5 dataset in read-only mode, ensure the model is not editable.""" + h5f = h5py.File(self.h5_fname, "r") + a = h5f["my_array"] + # ArrayTableModel relies on following condition + self.assertTrue(a.file.mode == "r") + + self.aw.setArrayData(a, copy=False, editable=True) + + self.assertIsInstance(a, h5py.Dataset) # simple sanity check + # internal representation must be a reference to original data (copy=False) + self.assertIsInstance(self.aw.model._array, h5py.Dataset) + self.assertTrue(self.aw.model._array.file.mode == "r") + + b = self.aw.getData() + self.assertTrue(numpy.array_equal(self.data, b)) + + # model must have detected read-only dataset and disabled editing + self.assertFalse(self.aw.model._editable) + idx = self.aw.model.createIndex(0, 0) + self.assertFalse( + self.aw.model.flags(idx) & qt.Qt.ItemIsEditable) + + # force editing read-only datasets raises IOError + self.assertRaises(IOError, self.aw.model.setData, + idx, 123.4, role=qt.Qt.EditRole) + h5f.close() + + def testReadWrite(self): + h5f = h5py.File(self.h5_fname, "r+") + a = h5f["my_array"] + self.assertTrue(a.file.mode == "r+") + + self.aw.setArrayData(a, copy=False, editable=True) + b = self.aw.getData(copy=False) + self.assertTrue(numpy.array_equal(self.data, b)) + + idx = self.aw.model.createIndex(0, 0) + # model is editable + self.assertTrue( + self.aw.model.flags(idx) & qt.Qt.ItemIsEditable) + h5f.close() + + def testSetData0D(self): + h5f = h5py.File(self.h5_fname, "r+") + a = h5f["my_scalar"] + self.aw.setArrayData(a) + b = self.aw.getData(copy=True) + + self.assertTrue(numpy.array_equal(a, b)) + + h5f.close() + + def testSetData1D(self): + h5f = h5py.File(self.h5_fname, "r+") + a = h5f["my_1D_array"] + self.aw.setArrayData(a) + b = self.aw.getData(copy=True) + + self.assertTrue(numpy.array_equal(a, b)) + + h5f.close() + + def testReferenceReturned(self): + """when setting the data with copy=False and + retrieving it with getData(copy=False), we should recover + the same original object. + + This only works for array with at least 2D. For 1D and 0D + arrays, a view is created at some point, which in the case + of an hdf5 dataset creates a copy.""" + h5f = h5py.File(self.h5_fname, "r+") + + # n-D + a0 = h5f["my_array"] + self.aw.setArrayData(a0, copy=False) + a1 = self.aw.getData(copy=False) + self.assertIs(a0, a1) + + # 1D + b0 = h5f["my_1D_array"] + self.aw.setArrayData(b0, copy=False) + b1 = self.aw.getData(copy=False) + self.assertIs(b0, b1) + + h5f.close() + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestArrayWidget)) + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestH5pyArrayWidget)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/data/test/test_dataviewer.py b/silx/gui/data/test/test_dataviewer.py new file mode 100644 index 0000000..5a0de0b --- /dev/null +++ b/silx/gui/data/test/test_dataviewer.py @@ -0,0 +1,281 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "10/04/2017" + +import os +import tempfile +import unittest +from contextlib import contextmanager + +import numpy +from ..DataViewer import DataViewer +from ..DataViews import DataView +from .. import DataViews + +from silx.gui import qt + +from silx.gui.data.DataViewerFrame import DataViewerFrame +from silx.gui.test.utils import SignalListener +from silx.gui.test.utils import TestCaseQt + +from silx.gui.hdf5.test import _mock + +try: + import h5py +except ImportError: + h5py = None + + +class _DataViewMock(DataView): + """Dummy view to display nothing""" + + def __init__(self, parent): + DataView.__init__(self, parent) + + def axesNames(self, data, info): + return [] + + def createWidget(self, parent): + return qt.QLabel(parent) + + def getDataPriority(self, data, info): + return 0 + + +class AbstractDataViewerTests(TestCaseQt): + + def create_widget(self): + raise NotImplementedError() + + @contextmanager + def h5_temporary_file(self): + # create tmp file + fd, tmp_name = tempfile.mkstemp(suffix=".h5") + os.close(fd) + data = numpy.arange(3 * 3 * 3) + data.shape = 3, 3, 3 + # create h5 data + h5file = h5py.File(tmp_name, "w") + h5file["data"] = data + yield h5file + # clean up + h5file.close() + os.unlink(tmp_name) + + def test_text_data(self): + data_list = ["aaa", int, 8, self] + widget = self.create_widget() + for data in data_list: + widget.setData(data) + self.assertEqual(DataViewer.RAW_MODE, widget.displayMode()) + + def test_plot_1d_data(self): + data = numpy.arange(3 ** 1) + data.shape = [3] * 1 + widget = self.create_widget() + widget.setData(data) + availableModes = set([v.modeId() for v in widget.currentAvailableViews()]) + self.assertEqual(DataViewer.RAW_MODE, widget.displayMode()) + self.assertIn(DataViewer.PLOT1D_MODE, availableModes) + + def test_plot_2d_data(self): + data = numpy.arange(3 ** 2) + data.shape = [3] * 2 + widget = self.create_widget() + widget.setData(data) + availableModes = set([v.modeId() for v in widget.currentAvailableViews()]) + self.assertEqual(DataViewer.RAW_MODE, widget.displayMode()) + self.assertIn(DataViewer.PLOT2D_MODE, availableModes) + + def test_plot_3d_data(self): + data = numpy.arange(3 ** 3) + data.shape = [3] * 3 + widget = self.create_widget() + widget.setData(data) + availableModes = set([v.modeId() for v in widget.currentAvailableViews()]) + try: + import silx.gui.plot3d # noqa + self.assertIn(DataViewer.PLOT3D_MODE, availableModes) + except ImportError: + self.assertIn(DataViewer.STACK_MODE, availableModes) + self.assertEqual(DataViewer.RAW_MODE, widget.displayMode()) + + def test_array_1d_data(self): + data = numpy.array(["aaa"] * (3 ** 1)) + data.shape = [3] * 1 + widget = self.create_widget() + widget.setData(data) + self.assertEqual(DataViewer.RAW_MODE, widget.displayedView().modeId()) + + def test_array_2d_data(self): + data = numpy.array(["aaa"] * (3 ** 2)) + data.shape = [3] * 2 + widget = self.create_widget() + widget.setData(data) + self.assertEqual(DataViewer.RAW_MODE, widget.displayedView().modeId()) + + def test_array_4d_data(self): + data = numpy.array(["aaa"] * (3 ** 4)) + data.shape = [3] * 4 + widget = self.create_widget() + widget.setData(data) + self.assertEqual(DataViewer.RAW_MODE, widget.displayedView().modeId()) + + def test_record_4d_data(self): + data = numpy.zeros(3 ** 4, dtype='3int8, float32, (2,3)float64') + data.shape = [3] * 4 + widget = self.create_widget() + widget.setData(data) + self.assertEqual(DataViewer.RAW_MODE, widget.displayedView().modeId()) + + def test_3d_h5_dataset(self): + if h5py is None: + self.skipTest("h5py library is not available") + with self.h5_temporary_file() as h5file: + dataset = h5file["data"] + widget = self.create_widget() + widget.setData(dataset) + + def test_data_event(self): + listener = SignalListener() + widget = self.create_widget() + widget.dataChanged.connect(listener) + widget.setData(10) + widget.setData(None) + self.assertEquals(listener.callCount(), 2) + + def test_display_mode_event(self): + listener = SignalListener() + widget = self.create_widget() + widget.displayedViewChanged.connect(listener) + widget.setData(10) + widget.setData(None) + modes = [v.modeId() for v in listener.arguments(argumentIndex=0)] + self.assertEquals(modes, [DataViewer.RAW_MODE, DataViewer.EMPTY_MODE]) + listener.clear() + + def test_change_display_mode(self): + data = numpy.arange(10 ** 4) + data.shape = [10] * 4 + widget = self.create_widget() + widget.setData(data) + widget.setDisplayMode(DataViewer.PLOT1D_MODE) + self.assertEquals(widget.displayedView().modeId(), DataViewer.PLOT1D_MODE) + widget.setDisplayMode(DataViewer.PLOT2D_MODE) + self.assertEquals(widget.displayedView().modeId(), DataViewer.PLOT2D_MODE) + widget.setDisplayMode(DataViewer.RAW_MODE) + self.assertEquals(widget.displayedView().modeId(), DataViewer.RAW_MODE) + widget.setDisplayMode(DataViewer.EMPTY_MODE) + self.assertEquals(widget.displayedView().modeId(), DataViewer.EMPTY_MODE) + + def test_create_default_views(self): + widget = self.create_widget() + views = widget.createDefaultViews() + self.assertTrue(len(views) > 0) + + def test_add_view(self): + widget = self.create_widget() + view = _DataViewMock(widget) + widget.addView(view) + self.assertTrue(view in widget.availableViews()) + self.assertTrue(view in widget.currentAvailableViews()) + + def test_remove_view(self): + widget = self.create_widget() + widget.setData("foobar") + view = widget.currentAvailableViews()[0] + widget.removeView(view) + self.assertTrue(view not in widget.availableViews()) + self.assertTrue(view not in widget.currentAvailableViews()) + +class TestDataViewer(AbstractDataViewerTests): + def create_widget(self): + return DataViewer() + + +class TestDataViewerFrame(AbstractDataViewerTests): + def create_widget(self): + return DataViewerFrame() + + +class TestDataView(TestCaseQt): + + def createComplexData(self): + line = [1, 2j, 3+3j, 4] + image = [line, line, line, line] + cube = [image, image, image, image] + data = numpy.array(cube, + dtype=numpy.complex) + return data + + def createDataViewWithData(self, dataViewClass, data): + viewer = dataViewClass(None) + widget = viewer.getWidget() + viewer.setData(data) + return widget + + def testCurveWithComplex(self): + data = self.createComplexData() + dataViewClass = DataViews._Plot1dView + widget = self.createDataViewWithData(dataViewClass, data[0, 0]) + self.qWaitForWindowExposed(widget) + + def testImageWithComplex(self): + data = self.createComplexData() + dataViewClass = DataViews._Plot2dView + widget = self.createDataViewWithData(dataViewClass, data[0]) + self.qWaitForWindowExposed(widget) + + def testCubeWithComplex(self): + self.skipTest("OpenGL widget not yet tested") + try: + import silx.gui.plot3d # noqa + except ImportError: + self.skipTest("OpenGL not available") + data = self.createComplexData() + dataViewClass = DataViews._Plot3dView + widget = self.createDataViewWithData(dataViewClass, data) + self.qWaitForWindowExposed(widget) + + def testImageStackWithComplex(self): + data = self.createComplexData() + dataViewClass = DataViews._StackView + widget = self.createDataViewWithData(dataViewClass, data) + self.qWaitForWindowExposed(widget) + + +def suite(): + test_suite = unittest.TestSuite() + loadTestsFromTestCase = unittest.defaultTestLoader.loadTestsFromTestCase + test_suite.addTest(loadTestsFromTestCase(TestDataViewer)) + test_suite.addTest(loadTestsFromTestCase(TestDataViewerFrame)) + test_suite.addTest(loadTestsFromTestCase(TestDataView)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/data/test/test_numpyaxesselector.py b/silx/gui/data/test/test_numpyaxesselector.py new file mode 100644 index 0000000..cc15f83 --- /dev/null +++ b/silx/gui/data/test/test_numpyaxesselector.py @@ -0,0 +1,152 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "15/12/2016" + +import os +import tempfile +import unittest +from contextlib import contextmanager + +import numpy + +from silx.gui.data.NumpyAxesSelector import NumpyAxesSelector +from silx.gui.test.utils import SignalListener +from silx.gui.test.utils import TestCaseQt + +try: + import h5py +except ImportError: + h5py = None + + +class TestNumpyAxesSelector(TestCaseQt): + + def test_creation(self): + data = numpy.arange(3 * 3 * 3) + data.shape = 3, 3, 3 + widget = NumpyAxesSelector() + widget.setVisible(True) + + def test_none(self): + data = numpy.arange(3 * 3 * 3) + widget = NumpyAxesSelector() + widget.setData(data) + widget.setData(None) + result = widget.selectedData() + self.assertIsNone(result) + + def test_output_samedim(self): + data = numpy.arange(3 * 3 * 3) + data.shape = 3, 3, 3 + expectedResult = data + + widget = NumpyAxesSelector() + widget.setAxisNames(["x", "y", "z"]) + widget.setData(data) + result = widget.selectedData() + self.assertTrue(numpy.array_equal(result, expectedResult)) + + def test_output_lessdim(self): + data = numpy.arange(3 * 3 * 3) + data.shape = 3, 3, 3 + expectedResult = data[0] + + widget = NumpyAxesSelector() + widget.setAxisNames(["y", "x"]) + widget.setData(data) + result = widget.selectedData() + self.assertTrue(numpy.array_equal(result, expectedResult)) + + def test_output_1dim(self): + data = numpy.arange(3 * 3 * 3) + data.shape = 3, 3, 3 + expectedResult = data[0, 0, 0] + + widget = NumpyAxesSelector() + widget.setData(data) + result = widget.selectedData() + self.assertTrue(numpy.array_equal(result, expectedResult)) + + @contextmanager + def h5_temporary_file(self): + # create tmp file + fd, tmp_name = tempfile.mkstemp(suffix=".h5") + os.close(fd) + data = numpy.arange(3 * 3 * 3) + data.shape = 3, 3, 3 + # create h5 data + h5file = h5py.File(tmp_name, "w") + h5file["data"] = data + yield h5file + # clean up + h5file.close() + os.unlink(tmp_name) + + def test_h5py_dataset(self): + if h5py is None: + self.skipTest("h5py library is not available") + with self.h5_temporary_file() as h5file: + dataset = h5file["data"] + expectedResult = dataset[0] + + widget = NumpyAxesSelector() + widget.setData(dataset) + widget.setAxisNames(["y", "x"]) + result = widget.selectedData() + self.assertTrue(numpy.array_equal(result, expectedResult)) + + def test_data_event(self): + data = numpy.arange(3 * 3 * 3) + widget = NumpyAxesSelector() + listener = SignalListener() + widget.dataChanged.connect(listener) + widget.setData(data) + widget.setData(None) + self.assertEqual(listener.callCount(), 2) + + def test_selected_data_event(self): + data = numpy.arange(3 * 3 * 3) + data.shape = 3, 3, 3 + widget = NumpyAxesSelector() + listener = SignalListener() + widget.selectionChanged.connect(listener) + widget.setData(data) + widget.setAxisNames(["x"]) + widget.setData(None) + self.assertEqual(listener.callCount(), 3) + listener.clear() + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestNumpyAxesSelector)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/data/test/test_textformatter.py b/silx/gui/data/test/test_textformatter.py new file mode 100644 index 0000000..f21e033 --- /dev/null +++ b/silx/gui/data/test/test_textformatter.py @@ -0,0 +1,94 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "24/01/2017" + +import unittest + +from silx.gui.test.utils import TestCaseQt +from silx.gui.test.utils import SignalListener +from ..TextFormatter import TextFormatter + + +class TestTextFormatter(TestCaseQt): + + def test_copy(self): + formatter = TextFormatter() + copy = TextFormatter(formatter=formatter) + self.assertIsNot(formatter, copy) + copy.setFloatFormat("%.3f") + self.assertEquals(formatter.integerFormat(), copy.integerFormat()) + self.assertNotEquals(formatter.floatFormat(), copy.floatFormat()) + self.assertEquals(formatter.useQuoteForText(), copy.useQuoteForText()) + self.assertEquals(formatter.imaginaryUnit(), copy.imaginaryUnit()) + + def test_event(self): + listener = SignalListener() + formatter = TextFormatter() + formatter.formatChanged.connect(listener) + formatter.setFloatFormat("%.3f") + formatter.setIntegerFormat("%03i") + formatter.setUseQuoteForText(False) + formatter.setImaginaryUnit("z") + self.assertEquals(listener.callCount(), 4) + + def test_int(self): + formatter = TextFormatter() + formatter.setIntegerFormat("%05i") + result = formatter.toString(512) + self.assertEquals(result, "00512") + + def test_float(self): + formatter = TextFormatter() + formatter.setFloatFormat("%.3f") + result = formatter.toString(1.3) + self.assertEquals(result, "1.300") + + def test_complex(self): + formatter = TextFormatter() + formatter.setFloatFormat("%.1f") + formatter.setImaginaryUnit("i") + result = formatter.toString(1.0 + 5j) + result = result.replace(" ", "") + self.assertEquals(result, "1.0+5.0i") + + def test_string(self): + formatter = TextFormatter() + formatter.setIntegerFormat("%.1f") + formatter.setImaginaryUnit("z") + result = formatter.toString("toto") + self.assertEquals(result, '"toto"') + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestTextFormatter)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/fit/BackgroundWidget.py b/silx/gui/fit/BackgroundWidget.py new file mode 100644 index 0000000..577a8c7 --- /dev/null +++ b/silx/gui/fit/BackgroundWidget.py @@ -0,0 +1,530 @@ +# coding: utf-8 +#/*########################################################################## +# Copyright (C) 2004-2017 V.A. Sole, European Synchrotron Radiation Facility +# +# This file is part of the PyMca X-ray Fluorescence Toolkit developed at +# the ESRF by the Software group. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# #########################################################################*/ +"""This module provides a background configuration widget +:class:`BackgroundWidget` and a corresponding dialog window +:class:`BackgroundDialog`.""" +import sys +import numpy +from silx.gui import qt +from silx.gui.plot import PlotWidget +from silx.math.fit import filters + +__authors__ = ["V.A. Sole", "P. Knobel"] +__license__ = "MIT" +__date__ = "24/01/2017" + + +class HorizontalSpacer(qt.QWidget): + def __init__(self, *args): + qt.QWidget.__init__(self, *args) + self.setSizePolicy(qt.QSizePolicy(qt.QSizePolicy.Expanding, + qt.QSizePolicy.Fixed)) + + +class BackgroundParamWidget(qt.QWidget): + """Background configuration composite widget. + + Strip and snip filters parameters can be adjusted using input widgets. + + Updating the widgets causes :attr:`sigBackgroundParamWidgetSignal` to + be emitted. + """ + sigBackgroundParamWidgetSignal = qt.pyqtSignal(object) + + def __init__(self, parent=None): + qt.QWidget.__init__(self, parent) + + self.mainLayout = qt.QGridLayout(self) + self.mainLayout.setColumnStretch(1, 1) + + # Algorithm choice --------------------------------------------------- + self.algorithmComboLabel = qt.QLabel(self) + self.algorithmComboLabel.setText("Background algorithm") + self.algorithmCombo = qt.QComboBox(self) + self.algorithmCombo.addItem("Strip") + self.algorithmCombo.addItem("Snip") + self.algorithmCombo.activated[int].connect( + self._algorithmComboActivated) + + # Strip parameters --------------------------------------------------- + self.stripWidthLabel = qt.QLabel(self) + self.stripWidthLabel.setText("Strip Width") + + self.stripWidthSpin = qt.QSpinBox(self) + self.stripWidthSpin.setMaximum(100) + self.stripWidthSpin.setMinimum(1) + self.stripWidthSpin.valueChanged[int].connect(self._emitSignal) + + self.stripIterLabel = qt.QLabel(self) + self.stripIterLabel.setText("Strip Iterations") + self.stripIterValue = qt.QLineEdit(self) + validator = qt.QIntValidator(self.stripIterValue) + self.stripIterValue._v = validator + self.stripIterValue.setText("0") + self.stripIterValue.editingFinished[()].connect(self._emitSignal) + self.stripIterValue.setToolTip( + "Number of iterations for strip algorithm.\n" + + "If greater than 999, an 2nd pass of strip filter is " + + "applied to remove artifacts created by first pass.") + + # Snip parameters ---------------------------------------------------- + self.snipWidthLabel = qt.QLabel(self) + self.snipWidthLabel.setText("Snip Width") + + self.snipWidthSpin = qt.QSpinBox(self) + self.snipWidthSpin.setMaximum(300) + self.snipWidthSpin.setMinimum(0) + self.snipWidthSpin.valueChanged[int].connect(self._emitSignal) + + + # Smoothing parameters ----------------------------------------------- + self.smoothingFlagCheck = qt.QCheckBox(self) + self.smoothingFlagCheck.setText("Smoothing Width (Savitsky-Golay)") + self.smoothingFlagCheck.toggled.connect(self._smoothingToggled) + + self.smoothingSpin = qt.QSpinBox(self) + self.smoothingSpin.setMinimum(3) + #self.smoothingSpin.setMaximum(40) + self.smoothingSpin.setSingleStep(2) + self.smoothingSpin.valueChanged[int].connect(self._emitSignal) + + # Anchors ------------------------------------------------------------ + + self.anchorsGroup = qt.QWidget(self) + anchorsLayout = qt.QHBoxLayout(self.anchorsGroup) + anchorsLayout.setSpacing(2) + anchorsLayout.setContentsMargins(0, 0, 0, 0) + + self.anchorsFlagCheck = qt.QCheckBox(self.anchorsGroup) + self.anchorsFlagCheck.setText("Use anchors") + self.anchorsFlagCheck.setToolTip( + "Define X coordinates of points that must remain fixed") + self.anchorsFlagCheck.stateChanged[int].connect( + self._anchorsToggled) + anchorsLayout.addWidget(self.anchorsFlagCheck) + + maxnchannel = 16384 * 4 # Fixme ? + self.anchorsList = [] + num_anchors = 4 + for i in range(num_anchors): + anchorSpin = qt.QSpinBox(self.anchorsGroup) + anchorSpin.setMinimum(0) + anchorSpin.setMaximum(maxnchannel) + anchorSpin.valueChanged[int].connect(self._emitSignal) + anchorsLayout.addWidget(anchorSpin) + self.anchorsList.append(anchorSpin) + + # Layout ------------------------------------------------------------ + self.mainLayout.addWidget(self.algorithmComboLabel, 0, 0) + self.mainLayout.addWidget(self.algorithmCombo, 0, 2) + self.mainLayout.addWidget(self.stripWidthLabel, 1, 0) + self.mainLayout.addWidget(self.stripWidthSpin, 1, 2) + self.mainLayout.addWidget(self.stripIterLabel, 2, 0) + self.mainLayout.addWidget(self.stripIterValue, 2, 2) + self.mainLayout.addWidget(self.snipWidthLabel, 3, 0) + self.mainLayout.addWidget(self.snipWidthSpin, 3, 2) + self.mainLayout.addWidget(self.smoothingFlagCheck, 4, 0) + self.mainLayout.addWidget(self.smoothingSpin, 4, 2) + self.mainLayout.addWidget(self.anchorsGroup, 5, 0, 1, 4) + + # Initialize interface ----------------------------------------------- + self._setAlgorithm("strip") + self.smoothingFlagCheck.setChecked(False) + self._smoothingToggled(is_checked=False) + self.anchorsFlagCheck.setChecked(False) + self._anchorsToggled(is_checked=False) + + def _algorithmComboActivated(self, algorithm_index): + self._setAlgorithm("strip" if algorithm_index == 0 else "snip") + + def _setAlgorithm(self, algorithm): + """Enable/disable snip and snip input widgets, depending on the + chosen algorithm. + :param algorithm: "snip" or "strip" + """ + if algorithm not in ["strip", "snip"]: + raise ValueError( + "Unknown background filter algorithm %s" % algorithm) + + self.algorithm = algorithm + self.stripWidthSpin.setEnabled(algorithm == "strip") + self.stripIterValue.setEnabled(algorithm == "strip") + self.snipWidthSpin.setEnabled(algorithm == "snip") + + def _smoothingToggled(self, is_checked): + """Enable/disable smoothing input widgets, emit dictionary""" + self.smoothingSpin.setEnabled(is_checked) + self._emitSignal() + + def _anchorsToggled(self, is_checked): + """Enable/disable all spin widgets defining anchor X coordinates, + emit signal. + """ + for anchor_spin in self.anchorsList: + anchor_spin.setEnabled(is_checked) + self._emitSignal() + + def setParameters(self, ddict): + """Set values for all input widgets. + + :param dict ddict: Input dictionary, must have the same + keys as the dictionary output by :meth:`getParameters` + """ + if "algorithm" in ddict: + self._setAlgorithm(ddict["algorithm"]) + + if "SnipWidth" in ddict: + self.snipWidthSpin.setValue(int(ddict["SnipWidth"])) + + if "StripWidth" in ddict: + self.stripWidthSpin.setValue(int(ddict["StripWidth"])) + + if "StripIterations" in ddict: + self.stripIterValue.setText("%d" % int(ddict["StripIterations"])) + + if "SmoothingFlag" in ddict: + self.smoothingFlagCheck.setChecked(bool(ddict["SmoothingFlag"])) + + if "SmoothingWidth" in ddict: + self.smoothingSpin.setValue(int(ddict["SmoothingWidth"])) + + if "AnchorsFlag" in ddict: + self.anchorsFlagCheck.setChecked(bool(ddict["AnchorsFlag"])) + + if "AnchorsList" in ddict: + anchorslist = ddict["AnchorsList"] + if anchorslist in [None, 'None']: + anchorslist = [] + for spin in self.anchorsList: + spin.setValue(0) + + i = 0 + for value in anchorslist: + self.anchorsList[i].setValue(int(value)) + i += 1 + + def getParameters(self): + """Return dictionary of parameters defined in the GUI + + The returned dictionary contains following values: + + - *algorithm*: *"strip"* or *"snip"* + - *StripWidth*: width of strip iterator + - *StripIterations*: number of iterations + - *StripThreshold*: curvature parameter (currently fixed to 1.0) + - *SnipWidth*: width of snip algorithm + - *SmoothingFlag*: flag to enable/disable smoothing + - *SmoothingWidth*: width of Savitsky-Golay smoothing filter + - *AnchorsFlag*: flag to enable/disable anchors + - *AnchorsList*: list of anchors (X coordinates of fixed values) + """ + stripitertext = self.stripIterValue.text() + stripiter = int(stripitertext) if len(stripitertext) else 0 + + return {"algorithm": self.algorithm, + "StripThreshold": 1.0, + "SnipWidth": self.snipWidthSpin.value(), + "StripIterations": stripiter, + "StripWidth": self.stripWidthSpin.value(), + "SmoothingFlag": self.smoothingFlagCheck.isChecked(), + "SmoothingWidth": self.smoothingSpin.value(), + "AnchorsFlag": self.anchorsFlagCheck.isChecked(), + "AnchorsList": [spin.value() for spin in self.anchorsList]} + + def _emitSignal(self, dummy=None): + self.sigBackgroundParamWidgetSignal.emit( + {'event': 'ParametersChanged', + 'parameters': self.getParameters()}) + + +class BackgroundWidget(qt.QWidget): + """Background configuration widget, with a :class:`PlotWindow`. + + Strip and snip filters parameters can be adjusted using input widgets, + and the computed backgrounds are plotted next to the original data to + show the result.""" + def __init__(self, parent=None): + qt.QWidget.__init__(self, parent) + self.setWindowTitle("Strip and SNIP Configuration Window") + self.mainLayout = qt.QVBoxLayout(self) + self.mainLayout.setContentsMargins(0, 0, 0, 0) + self.mainLayout.setSpacing(2) + self.parametersWidget = BackgroundParamWidget(self) + self.graphWidget = PlotWidget(parent=self) + self.mainLayout.addWidget(self.parametersWidget) + self.mainLayout.addWidget(self.graphWidget) + self._x = None + self._y = None + self.parametersWidget.sigBackgroundParamWidgetSignal.connect(self._slot) + + def getParameters(self): + """Return dictionary of parameters defined in the GUI + + The returned dictionary contains following values: + + - *algorithm*: *"strip"* or *"snip"* + - *StripWidth*: width of strip iterator + - *StripIterations*: number of iterations + - *StripThreshold*: strip curvature (currently fixed to 1.0) + - *SnipWidth*: width of snip algorithm + - *SmoothingFlag*: flag to enable/disable smoothing + - *SmoothingWidth*: width of Savitsky-Golay smoothing filter + - *AnchorsFlag*: flag to enable/disable anchors + - *AnchorsList*: list of anchors (X coordinates of fixed values) + """ + return self.parametersWidget.getParameters() + + def setParameters(self, ddict): + """Set values for all input widgets. + + :param dict ddict: Input dictionary, must have the same + keys as the dictionary output by :meth:`getParameters` + """ + return self.parametersWidget.setParameters(ddict) + + def setData(self, x, y, xmin=None, xmax=None): + """Set data for the original curve, and _update strip and snip + curves accordingly. + + :param x: Array or sequence of curve abscissa values + :param y: Array or sequence of curve ordinate values + :param xmin: Min value to be displayed on the X axis + :param xmax: Max value to be displayed on the X axis + """ + self._x = x + self._y = y + self._xmin = xmin + self._xmax = xmax + self._update(resetzoom=True) + + def _slot(self, ddict): + self._update() + + def _update(self, resetzoom=False): + """Compute strip and snip backgrounds, update the curves + """ + if self._y is None: + return + + pars = self.getParameters() + + # smoothed data + y = numpy.ravel(numpy.array(self._y)).astype(numpy.float) + if pars["SmoothingFlag"]: + ysmooth = filters.savitsky_golay(y, pars['SmoothingWidth']) + f = [0.25, 0.5, 0.25] + ysmooth[1:-1] = numpy.convolve(ysmooth, f, mode=0) + ysmooth[0] = 0.5 * (ysmooth[0] + ysmooth[1]) + ysmooth[-1] = 0.5 * (ysmooth[-1] + ysmooth[-2]) + else: + ysmooth = y + + + # loop for anchors + x = self._x + niter = pars['StripIterations'] + anchors_indices = [] + if pars['AnchorsFlag'] and pars['AnchorsList'] is not None: + ravelled = x + for channel in pars['AnchorsList']: + if channel <= ravelled[0]: + continue + index = numpy.nonzero(ravelled >= channel)[0] + if len(index): + index = min(index) + if index > 0: + anchors_indices.append(index) + + stripBackground = filters.strip(ysmooth, + w=pars['StripWidth'], + niterations=niter, + factor=pars['StripThreshold'], + anchors=anchors_indices) + + if niter >= 1000: + # final smoothing + stripBackground = filters.strip(stripBackground, + w=1, + niterations=50*pars['StripWidth'], + factor=pars['StripThreshold'], + anchors=anchors_indices) + + if len(anchors_indices) == 0: + anchors_indices = [0, len(ysmooth)-1] + anchors_indices.sort() + snipBackground = 0.0 * ysmooth + lastAnchor = 0 + for anchor in anchors_indices: + if (anchor > lastAnchor) and (anchor < len(ysmooth)): + snipBackground[lastAnchor:anchor] =\ + filters.snip1d(ysmooth[lastAnchor:anchor], + pars['SnipWidth']) + lastAnchor = anchor + if lastAnchor < len(ysmooth): + snipBackground[lastAnchor:] =\ + filters.snip1d(ysmooth[lastAnchor:], + pars['SnipWidth']) + + self.graphWidget.addCurve(x, y, + legend='Input Data', + replace=True, + resetzoom=resetzoom) + self.graphWidget.addCurve(x, stripBackground, + legend='Strip Background', + resetzoom=False) + self.graphWidget.addCurve(x, snipBackground, + legend='SNIP Background', + resetzoom=False) + if self._xmin is not None and self._xmax is not None: + self.graphWidget.setGraphXLimits(xmin=self._xmin, xmax=self._xmax) + + +class BackgroundDialog(qt.QDialog): + """QDialog window featuring a :class:`BackgroundWidget`""" + def __init__(self, parent=None): + qt.QDialog.__init__(self, parent) + self.setWindowTitle("Strip and Snip Configuration Window") + self.mainLayout = qt.QVBoxLayout(self) + self.mainLayout.setContentsMargins(0, 0, 0, 0) + self.mainLayout.setSpacing(2) + self.parametersWidget = BackgroundWidget(self) + self.mainLayout.addWidget(self.parametersWidget) + hbox = qt.QWidget(self) + hboxLayout = qt.QHBoxLayout(hbox) + hboxLayout.setContentsMargins(0, 0, 0, 0) + hboxLayout.setSpacing(2) + self.okButton = qt.QPushButton(hbox) + self.okButton.setText("OK") + self.okButton.setAutoDefault(False) + self.dismissButton = qt.QPushButton(hbox) + self.dismissButton.setText("Cancel") + self.dismissButton.setAutoDefault(False) + hboxLayout.addWidget(HorizontalSpacer(hbox)) + hboxLayout.addWidget(self.okButton) + hboxLayout.addWidget(self.dismissButton) + self.mainLayout.addWidget(hbox) + self.dismissButton.clicked.connect(self.reject) + self.okButton.clicked.connect(self.accept) + + self.output = {} + """Configuration dictionary containing following fields: + + - *SmoothingFlag* + - *SmoothingWidth* + - *StripWidth* + - *StripIterations* + - *StripThreshold* + - *SnipWidth* + - *AnchorsFlag* + - *AnchorsList* + """ + + # self.parametersWidget.parametersWidget.sigBackgroundParamWidgetSignal.connect(self.updateOutput) + + # def updateOutput(self, ddict): + # self.output = ddict + + def accept(self): + """Update :attr:`output`, then call :meth:`QDialog.accept` + """ + self.output = self.getParameters() + super(BackgroundDialog, self).accept() + + def sizeHint(self): + return qt.QSize(int(1.5*qt.QDialog.sizeHint(self).width()), + qt.QDialog.sizeHint(self).height()) + + def setData(self, x, y, xmin=None, xmax=None): + """See :meth:`BackgroundWidget.setData`""" + return self.parametersWidget.setData(x, y, xmin, xmax) + + def getParameters(self): + """See :meth:`BackgroundWidget.getParameters`""" + return self.parametersWidget.getParameters() + + def setParameters(self, ddict): + """See :meth:`BackgroundWidget.setParameters`""" + return self.parametersWidget.setParameters(ddict) + + def setDefault(self, ddict): + """Alias for :meth:`setParameters`""" + return self.setParameters(ddict) + + +def getBgDialog(parent=None, default=None, modal=True): + """Instantiate and return a bg configuration dialog, adapted + for configuring standard background theories from + :mod:`silx.math.fit.bgtheories`. + + :return: Instance of :class:`BackgroundDialog` + """ + bgd = BackgroundDialog(parent=parent) + # apply default to newly added pages + bgd.setParameters(default) + + return bgd + + +def main(): + # synthetic data + from silx.math.fit.functions import sum_gauss + + x = numpy.arange(5000) + # (height1, center1, fwhm1, ...) 5 peaks + params1 = (50, 500, 100, + 20, 2000, 200, + 50, 2250, 100, + 40, 3000, 75, + 23, 4000, 150) + y0 = sum_gauss(x, *params1) + + # random values between [-1;1] + noise = 2 * numpy.random.random(5000) - 1 + # make it +- 5% + noise *= 0.05 + + # 2 gaussians with very large fwhm, as background signal + actual_bg = sum_gauss(x, 15, 3500, 3000, 5, 1000, 1500) + + # Add 5% random noise to gaussians and add background + y = y0 + numpy.average(y0) * noise + actual_bg + + # Open widget + a = qt.QApplication(sys.argv) + a.lastWindowClosed.connect(a.quit) + + def mySlot(ddict): + print(ddict) + + w = BackgroundDialog() + w.parametersWidget.parametersWidget.sigBackgroundParamWidgetSignal.connect(mySlot) + w.setData(x, y) + w.exec_() + #a.exec_() + +if __name__ == "__main__": + main() diff --git a/silx/gui/fit/FitConfig.py b/silx/gui/fit/FitConfig.py new file mode 100644 index 0000000..70b6fbe --- /dev/null +++ b/silx/gui/fit/FitConfig.py @@ -0,0 +1,540 @@ +# coding: utf-8 +# /*########################################################################## +# Copyright (C) 2004-2016 V.A. Sole, European Synchrotron Radiation Facility +# +# This file is part of the PyMca X-ray Fluorescence Toolkit developed at +# the ESRF by the Software group. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ######################################################################### */ +"""This module defines widgets used to build a fit configuration dialog. +The resulting dialog widget outputs a dictionary of configuration parameters. +""" +from silx.gui import qt + +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "30/11/2016" + + +class TabsDialog(qt.QDialog): + """Dialog widget containing a QTabWidget :attr:`tabWidget` + and a buttons: + + # - buttonHelp + - buttonDefaults + - buttonOk + - buttonCancel + + This dialog defines a __len__ returning the number of tabs, + and an __iter__ method yielding the tab widgets. + """ + def __init__(self, parent=None): + qt.QDialog.__init__(self, parent) + self.tabWidget = qt.QTabWidget(self) + + layout = qt.QVBoxLayout(self) + layout.addWidget(self.tabWidget) + + layout2 = qt.QHBoxLayout(None) + + # self.buttonHelp = qt.QPushButton(self) + # self.buttonHelp.setText("Help") + # layout2.addWidget(self.buttonHelp) + + self.buttonDefault = qt.QPushButton(self) + self.buttonDefault.setText("Default") + layout2.addWidget(self.buttonDefault) + + spacer = qt.QSpacerItem(20, 20, + qt.QSizePolicy.Expanding, + qt.QSizePolicy.Minimum) + layout2.addItem(spacer) + + self.buttonOk = qt.QPushButton(self) + self.buttonOk.setText("OK") + layout2.addWidget(self.buttonOk) + + self.buttonCancel = qt.QPushButton(self) + self.buttonCancel.setText("Cancel") + layout2.addWidget(self.buttonCancel) + + layout.addLayout(layout2) + + self.buttonOk.clicked.connect(self.accept) + self.buttonCancel.clicked.connect(self.reject) + + def __len__(self): + """Return number of tabs""" + return self.tabWidget.count() + + def __iter__(self): + """Return the next tab widget in :attr:`tabWidget` every + time this method is called. + + :return: Tab widget + :rtype: QWidget + """ + for widget_index in range(len(self)): + yield self.tabWidget.widget(widget_index) + + def addTab(self, page, label): + """Add a new tab + + :param page: Content of new page. Must be a widget with + a get() method returning a dictionary. + :param str label: Tab label + """ + self.tabWidget.addTab(page, label) + + def getTabLabels(self): + """ + Return a list of all tab labels in :attr:`tabWidget` + """ + return [self.tabWidget.tabText(i) for i in range(len(self))] + + +class TabsDialogData(TabsDialog): + """This dialog adds a data attribute to :class:`TabsDialog`. + + Data input in widgets, such as text entries or checkboxes, is stored in an + attribute :attr:`output` when the user clicks the OK button. + + A default dictionary can be supplied when this dialog is initialized, to + be used as default data for :attr:`output`. + """ + def __init__(self, parent=None, modal=True, default=None): + """ + + :param parent: Parent :class:`QWidget` + :param modal: If `True`, dialog is modal, meaning this dialog remains + in front of it's parent window and disables it until the user is + done interacting with the dialog + :param default: Default dictionary, used to initialize and reset + :attr:`output`. + """ + TabsDialog.__init__(self, parent) + self.setModal(modal) + self.setWindowTitle("Fit configuration") + + self.output = {} + + self.default = {} if default is None else default + + self.buttonDefault.clicked.connect(self.setDefault) + # self.keyPressEvent(qt.Qt.Key_Enter). + + def keyPressEvent(self, event): + """Redefining this method to ignore Enter key + (for some reason it activates buttonDefault callback which + resets all widgets) + """ + if event.key() in [qt.Qt.Key_Enter, qt.Qt.Key_Return]: + return + TabsDialog.keyPressEvent(self, event) + + def accept(self): + """When *OK* is clicked, update :attr:`output` with data from + various widgets + """ + self.output.update(self.default) + + # loop over all tab widgets (uses TabsDialog.__iter__) + for tabWidget in self: + self.output.update(tabWidget.get()) + + # avoid pathological None cases + for key in self.output.keys(): + if self.output[key] is None: + if key in self.default: + self.output[key] = self.default[key] + super(TabsDialogData, self).accept() + + def reject(self): + """When the *Cancel* button is clicked, reinitialize :attr:`output` + and quit + """ + self.setDefault() + super(TabsDialogData, self).reject() + + def setDefault(self, newdefault=None): + """Reinitialize :attr:`output` with :attr:`default` or with + new dictionary ``newdefault`` if provided. + Call :meth:`setDefault` for each tab widget, if available. + """ + self.output = {} + if newdefault is None: + newdefault = self.default + else: + self.default = newdefault + self.output.update(newdefault) + + for tabWidget in self: + if hasattr(tabWidget, "setDefault"): + tabWidget.setDefault(self.output) + + +class ConstraintsPage(qt.QGroupBox): + """Checkable QGroupBox widget filled with QCheckBox widgets, + to configure the fit estimation for standard fit theories. + """ + def __init__(self, parent=None, title="Set constraints"): + super(ConstraintsPage, self).__init__(parent) + self.setTitle(title) + self.setToolTip("Disable 'Set constraints' to remove all " + + "constraints on all fit parameters") + self.setCheckable(True) + + layout = qt.QVBoxLayout(self) + self.setLayout(layout) + + self.positiveHeightCB = qt.QCheckBox("Force positive height/area", self) + self.positiveHeightCB.setToolTip("Fit must find positive peaks") + layout.addWidget(self.positiveHeightCB) + + self.positionInIntervalCB = qt.QCheckBox("Force position in interval", self) + self.positionInIntervalCB.setToolTip( + "Fit must position peak within X limits") + layout.addWidget(self.positionInIntervalCB) + + self.positiveFwhmCB = qt.QCheckBox("Force positive FWHM", self) + self.positiveFwhmCB.setToolTip("Fit must find a positive FWHM") + layout.addWidget(self.positiveFwhmCB) + + self.sameFwhmCB = qt.QCheckBox("Force same FWHM for all peaks", self) + self.sameFwhmCB.setToolTip("Fit must find same FWHM for all peaks") + layout.addWidget(self.sameFwhmCB) + + self.quotedEtaCB = qt.QCheckBox("Force Eta between 0 and 1", self) + self.quotedEtaCB.setToolTip( + "Fit must find Eta between 0 and 1 for pseudo-Voigt function") + layout.addWidget(self.quotedEtaCB) + + layout.addStretch() + + self.setDefault() + + def setDefault(self, default_dict=None): + """Set default state for all widgets. + + :param default_dict: If a default config dictionary is provided as + a parameter, its values are used as default state.""" + if default_dict is None: + default_dict = {} + # this one uses reverse logic: if checked, NoConstraintsFlag must be False + self.setChecked( + not default_dict.get('NoConstraintsFlag', False)) + self.positiveHeightCB.setChecked( + default_dict.get('PositiveHeightAreaFlag', True)) + self.positionInIntervalCB.setChecked( + default_dict.get('QuotedPositionFlag', False)) + self.positiveFwhmCB.setChecked( + default_dict.get('PositiveFwhmFlag', True)) + self.sameFwhmCB.setChecked( + default_dict.get('SameFwhmFlag', False)) + self.quotedEtaCB.setChecked( + default_dict.get('QuotedEtaFlag', False)) + + def get(self): + """Return a dictionary of constraint flags, to be processed by the + :meth:`configure` method of the selected fit theory.""" + ddict = { + 'NoConstraintsFlag': not self.isChecked(), + 'PositiveHeightAreaFlag': self.positiveHeightCB.isChecked(), + 'QuotedPositionFlag': self.positionInIntervalCB.isChecked(), + 'PositiveFwhmFlag': self.positiveFwhmCB.isChecked(), + 'SameFwhmFlag': self.sameFwhmCB.isChecked(), + 'QuotedEtaFlag': self.quotedEtaCB.isChecked(), + } + return ddict + + +class SearchPage(qt.QWidget): + def __init__(self, parent=None): + super(SearchPage, self).__init__(parent) + layout = qt.QVBoxLayout(self) + + self.manualFwhmGB = qt.QGroupBox("Define FWHM manually", self) + self.manualFwhmGB.setCheckable(True) + self.manualFwhmGB.setToolTip( + "If disabled, the FWHM parameter used for peak search is " + + "estimated based on the highest peak in the data") + layout.addWidget(self.manualFwhmGB) + # ------------ GroupBox fwhm-------------------------- + layout2 = qt.QHBoxLayout(self.manualFwhmGB) + self.manualFwhmGB.setLayout(layout2) + + label = qt.QLabel("Fwhm Points", self.manualFwhmGB) + layout2.addWidget(label) + + self.fwhmPointsSpin = qt.QSpinBox(self.manualFwhmGB) + self.fwhmPointsSpin.setRange(0, 999999) + self.fwhmPointsSpin.setToolTip("Typical peak fwhm (number of data points)") + layout2.addWidget(self.fwhmPointsSpin) + # ---------------------------------------------------- + + self.manualScalingGB = qt.QGroupBox("Define scaling manually", self) + self.manualScalingGB.setCheckable(True) + self.manualScalingGB.setToolTip( + "If disabled, the Y scaling used for peak search is " + + "estimated automatically") + layout.addWidget(self.manualScalingGB) + # ------------ GroupBox scaling----------------------- + layout3 = qt.QHBoxLayout(self.manualScalingGB) + self.manualScalingGB.setLayout(layout3) + + label = qt.QLabel("Y Scaling", self.manualScalingGB) + layout3.addWidget(label) + + self.yScalingEntry = qt.QLineEdit(self.manualScalingGB) + self.yScalingEntry.setToolTip( + "Data values will be multiplied by this value prior to peak" + + " search") + self.yScalingEntry.setValidator(qt.QDoubleValidator()) + layout3.addWidget(self.yScalingEntry) + # ---------------------------------------------------- + + # ------------------- grid layout -------------------- + containerWidget = qt.QWidget(self) + layout4 = qt.QHBoxLayout(containerWidget) + containerWidget.setLayout(layout4) + + label = qt.QLabel("Sensitivity", containerWidget) + layout4.addWidget(label) + + self.sensitivityEntry = qt.QLineEdit(containerWidget) + self.sensitivityEntry.setToolTip( + "Peak search sensitivity threshold, expressed as a multiple " + + "of the standard deviation of the noise.\nMinimum value is 1 " + + "(to be detected, peak must be higher than the estimated noise)") + sensivalidator = qt.QDoubleValidator() + sensivalidator.setBottom(1.0) + self.sensitivityEntry.setValidator(sensivalidator) + layout4.addWidget(self.sensitivityEntry) + # ---------------------------------------------------- + layout.addWidget(containerWidget) + + self.forcePeakPresenceCB = qt.QCheckBox("Force peak presence", self) + self.forcePeakPresenceCB.setToolTip( + "If peak search algorithm is unsuccessful, place one peak " + + "at the maximum of the curve") + layout.addWidget(self.forcePeakPresenceCB) + + layout.addStretch() + + self.setDefault() + + def setDefault(self, default_dict=None): + """Set default values for all widgets. + + :param default_dict: If a default config dictionary is provided as + a parameter, its values are used as default values.""" + if default_dict is None: + default_dict = {} + self.manualFwhmGB.setChecked( + not default_dict.get('AutoFwhm', True)) + self.fwhmPointsSpin.setValue( + default_dict.get('FwhmPoints', 8)) + self.sensitivityEntry.setText( + str(default_dict.get('Sensitivity', 1.0))) + self.manualScalingGB.setChecked( + not default_dict.get('AutoScaling', False)) + self.yScalingEntry.setText( + str(default_dict.get('Yscaling', 1.0))) + self.forcePeakPresenceCB.setChecked( + default_dict.get('ForcePeakPresence', False)) + + def get(self): + """Return a dictionary of peak search parameters, to be processed by + the :meth:`configure` method of the selected fit theory.""" + ddict = { + 'AutoFwhm': not self.manualFwhmGB.isChecked(), + 'FwhmPoints': self.fwhmPointsSpin.value(), + 'Sensitivity': safe_float(self.sensitivityEntry.text()), + 'AutoScaling': not self.manualScalingGB.isChecked(), + 'Yscaling': safe_float(self.yScalingEntry.text()), + 'ForcePeakPresence': self.forcePeakPresenceCB.isChecked() + } + return ddict + + +class BackgroundPage(qt.QGroupBox): + """Background subtraction configuration, specific to fittheories + estimation functions.""" + def __init__(self, parent=None, + title="Subtract strip background prior to estimation"): + super(BackgroundPage, self).__init__(parent) + self.setTitle(title) + self.setCheckable(True) + self.setToolTip( + "The strip algorithm strips away peaks to compute the " + + "background signal.\nAt each iteration, a sample is compared " + + "to the average of the two samples at a given distance in both" + + " directions,\n and if its value is higher than the average," + "it is replaced by the average.") + + layout = qt.QGridLayout(self) + self.setLayout(layout) + + for i, label_text in enumerate( + ["Strip width (in samples)", + "Number of iterations", + "Strip threshold factor"]): + label = qt.QLabel(label_text) + layout.addWidget(label, i, 0) + + self.stripWidthSpin = qt.QSpinBox(self) + self.stripWidthSpin.setToolTip( + "Width, in number of samples, of the strip operator") + self.stripWidthSpin.setRange(1, 999999) + + layout.addWidget(self.stripWidthSpin, 0, 1) + + self.numIterationsSpin = qt.QSpinBox(self) + self.numIterationsSpin.setToolTip( + "Number of iterations of the strip algorithm") + self.numIterationsSpin.setRange(1, 999999) + layout.addWidget(self.numIterationsSpin, 1, 1) + + self.thresholdFactorEntry = qt.QLineEdit(self) + self.thresholdFactorEntry.setToolTip( + "Factor used by the strip algorithm to decide whether a sample" + + "value should be stripped.\nThe value must be higher than the " + + "average of the 2 samples at +- w times this factor.\n") + self.thresholdFactorEntry.setValidator(qt.QDoubleValidator()) + layout.addWidget(self.thresholdFactorEntry, 2, 1) + + self.smoothStripGB = qt.QGroupBox("Apply smoothing prior to strip", self) + self.smoothStripGB.setCheckable(True) + self.smoothStripGB.setToolTip( + "Apply a smoothing before subtracting strip background" + + " in fit and estimate processes") + smoothlayout = qt.QHBoxLayout(self.smoothStripGB) + label = qt.QLabel("Smoothing width (Savitsky-Golay)") + smoothlayout.addWidget(label) + self.smoothingWidthSpin = qt.QSpinBox(self) + self.smoothingWidthSpin.setToolTip( + "Width parameter for Savitsky-Golay smoothing (number of samples, must be odd)") + self.smoothingWidthSpin.setRange(3, 101) + self.smoothingWidthSpin.setSingleStep(2) + smoothlayout.addWidget(self.smoothingWidthSpin) + + layout.addWidget(self.smoothStripGB, 3, 0, 1, 2) + + layout.setRowStretch(4, 1) + + self.setDefault() + + def setDefault(self, default_dict=None): + """Set default values for all widgets. + + :param default_dict: If a default config dictionary is provided as + a parameter, its values are used as default values.""" + if default_dict is None: + default_dict = {} + + self.setChecked( + default_dict.get('StripBackgroundFlag', True)) + + self.stripWidthSpin.setValue( + default_dict.get('StripWidth', 2)) + self.numIterationsSpin.setValue( + default_dict.get('StripIterations', 5000)) + self.thresholdFactorEntry.setText( + str(default_dict.get('StripThreshold', 1.0))) + self.smoothStripGB.setChecked( + default_dict.get('SmoothingFlag', False)) + self.smoothingWidthSpin.setValue( + default_dict.get('SmoothingWidth', 3)) + + def get(self): + """Return a dictionary of background subtraction parameters, to be + processed by the :meth:`configure` method of the selected fit theory. + """ + ddict = { + 'StripBackgroundFlag': self.isChecked(), + 'StripWidth': self.stripWidthSpin.value(), + 'StripIterations': self.numIterationsSpin.value(), + 'StripThreshold': safe_float(self.thresholdFactorEntry.text()), + 'SmoothingFlag': self.smoothStripGB.isChecked(), + 'SmoothingWidth': self.smoothingWidthSpin.value() + } + return ddict + + +def safe_float(string_, default=1.0): + """Convert a string into a float. + If the conversion fails, return the default value. + """ + try: + ret = float(string_) + except ValueError: + return default + else: + return ret + + +def safe_int(string_, default=1): + """Convert a string into a integer. + If the conversion fails, return the default value. + """ + try: + ret = int(float(string_)) + except ValueError: + return default + else: + return ret + + +def getFitConfigDialog(parent=None, default=None, modal=True): + """Instantiate and return a fit configuration dialog, adapted + for configuring standard fit theories from + :mod:`silx.math.fit.fittheories`. + + :return: Instance of :class:`TabsDialogData` with 3 tabs: + :class:`ConstraintsPage`, :class:`SearchPage` and + :class:`BackgroundPage` + """ + tdd = TabsDialogData(parent=parent, default=default) + tdd.addTab(ConstraintsPage(), label="Constraints") + tdd.addTab(SearchPage(), label="Peak search") + tdd.addTab(BackgroundPage(), label="Background") + # apply default to newly added pages + tdd.setDefault() + + return tdd + + +def main(): + a = qt.QApplication([]) + + mw = qt.QMainWindow() + mw.show() + + tdd = getFitConfigDialog(mw, default={"a": 1}) + tdd.show() + tdd.exec_() + print("TabsDialogData result: ", tdd.result()) + print("TabsDialogData output: ", tdd.output) + + a.exec_() + +if __name__ == "__main__": + main() diff --git a/silx/gui/fit/FitWidget.py b/silx/gui/fit/FitWidget.py new file mode 100644 index 0000000..a5c3cfd --- /dev/null +++ b/silx/gui/fit/FitWidget.py @@ -0,0 +1,727 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# +# This file is part of the PyMca X-ray Fluorescence Toolkit developed at +# the ESRF by the Software group. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ######################################################################### */ +"""This module provides a widget designed to configure and run a fitting +process with constraints on parameters. + +The main class is :class:`FitWidget`. It relies on +:mod:`silx.math.fit.fitmanager`, which relies on :func:`silx.math.fit.leastsq`. + +The user can choose between functions before running the fit. These function can +be user defined, or by default are loaded from +:mod:`silx.math.fit.fittheories`. +""" + +__authors__ = ["V.A. Sole", "P. Knobel"] +__license__ = "MIT" +__date__ = "15/02/2017" + +import logging +import sys +import traceback +import warnings + +from silx.math.fit import fittheories +from silx.math.fit import fitmanager, functions +from silx.gui import qt +from .FitWidgets import (FitActionsButtons, FitStatusLines, + FitConfigWidget, ParametersTab) +from .FitConfig import getFitConfigDialog +from .BackgroundWidget import getBgDialog, BackgroundDialog + +QTVERSION = qt.qVersion() +DEBUG = 0 +_logger = logging.getLogger(__name__) + + +__authors__ = ["V.A. Sole", "P. Knobel"] +__license__ = "MIT" +__date__ = "30/11/2016" + + +class FitWidget(qt.QWidget): + """This widget can be used to configure, run and display results of a + fitting process. + + The standard steps for using this widget is to initialize it, then load + the data to be fitted. + + Optionally, you can also load user defined fit theories. If you skip this + step, a series of default fit functions will be presented (gaussian-like + functions), and you can later load your custom fit theories from an + external file using the GUI. + + A fit theory is a fit function and its associated features: + + - estimation function, + - list of parameter names + - numerical derivative algorithm + - configuration widget + + Once the widget is up and running, the user may select a fit theory and a + background theory, change configuration parameters specific to the theory + run the estimation, set constraints on parameters and run the actual fit. + + The results are displayed in a table. + """ + sigFitWidgetSignal = qt.Signal(object) + """This signal is emitted by the estimation and fit methods. + It carries a dictionary with two items: + + - *event*: one of the following strings + + - *EstimateStarted*, + - *FitStarted* + - *EstimateFinished*, + - *FitFinished* + - *EstimateFailed* + - *FitFailed* + + - *data*: None, or fit/estimate results (see documentation for + :attr:`silx.math.fit.fitmanager.FitManager.fit_results`) + """ + + def __init__(self, parent=None, title=None, fitmngr=None, + enableconfig=True, enablestatus=True, enablebuttons=True): + """ + + :param parent: Parent widget + :param title: Window title + :param fitmngr: User defined instance of + :class:`silx.math.fit.fitmanager.FitManager`, or ``None`` + :param enableconfig: If ``True``, activate widgets to modify the fit + configuration (select between several fit functions or background + functions, apply global constraints, peak search parameters…) + :param enablestatus: If ``True``, add a fit status widget, to display + a message when fit estimation is available and when fit results + are available, as well as a measure of the fit error. + :param enablebuttons: If ``True``, add buttons to run estimation and + fitting. + """ + if title is None: + title = "FitWidget" + qt.QWidget.__init__(self, parent) + + self.setWindowTitle(title) + layout = qt.QVBoxLayout(self) + + self.fitmanager = self._setFitManager(fitmngr) + """Instance of :class:`FitManager`. + This is the underlying data model of this FitWidget. + + If no custom theories are defined, the default ones from + :mod:`silx.math.fit.fittheories` are imported. + """ + + # reference fitmanager.configure method for direct access + self.configure = self.fitmanager.configure + self.fitconfig = self.fitmanager.fitconfig + + self.configdialogs = {} + """This dictionary defines the fit configuration widgets + associated with the fit theories in :attr:`fitmanager.theories` + + Keys must correspond to existing theory names, i.e. existing keys + in :attr:`fitmanager.theories`. + + Values must be instances of QDialog widgets with an additional + *output* attribute, a dictionary storing configuration parameters + interpreted by the corresponding fit theory. + + The dialog can also define a *setDefault* method to initialize the + widget values with values in a dictionary passed as a parameter. + This will be executed first. + + In case the widget does not actually inherit :class:`QDialog`, it + must at least implement the following methods (executed in this + particular order): + + - :meth:`show`: should cause the widget to become visible to the + user) + - :meth:`exec_`: should run while the user is interacting with the + widget, interrupting the rest of the program. It should + typically end (*return*) when the user clicks an *OK* + or a *Cancel* button. + - :meth:`result`: must return ``True`` if the new configuration in + attribute :attr:`output` is to be accepted (user clicked *OK*), + or return ``False`` if :attr:`output` is to be rejected (user + clicked *Cancel*) + + To associate a custom configuration widget with a fit theory, use + :meth:`associateConfigDialog`. E.g.:: + + fw = FitWidget() + my_config_widget = MyGaussianConfigWidget(parent=fw) + fw.associateConfigDialog(theory_name="Gaussians", + config_widget=my_config_widget) + """ + + self.bgconfigdialogs = {} + """Same as :attr:`configdialogs`, except that the widget is associated + with a background theory in :attr:`fitmanager.bgtheories`""" + + self._associateDefaultConfigDialogs() + + self.guiConfig = None + """Configuration widget at the top of FitWidget, to select + fit function, background function, and open an advanced + configuration dialog.""" + + self.guiParameters = ParametersTab(self) + """Table widget for display of fit parameters and constraints""" + + if enableconfig: + self.guiConfig = FitConfigWidget(self) + """Function selector and configuration widget""" + + self.guiConfig.FunConfigureButton.clicked.connect( + self.__funConfigureGuiSlot) + self.guiConfig.BgConfigureButton.clicked.connect( + self.__bgConfigureGuiSlot) + + self.guiConfig.WeightCheckBox.setChecked( + self.fitconfig.get("WeightFlag", False)) + self.guiConfig.WeightCheckBox.stateChanged[int].connect(self.weightEvent) + + self.guiConfig.BkgComBox.activated[str].connect(self.bkgEvent) + self.guiConfig.FunComBox.activated[str].connect(self.funEvent) + self._populateFunctions() + + layout.addWidget(self.guiConfig) + + layout.addWidget(self.guiParameters) + + if enablestatus: + self.guistatus = FitStatusLines(self) + """Status bar""" + layout.addWidget(self.guistatus) + + if enablebuttons: + self.guibuttons = FitActionsButtons(self) + """Widget with estimate, start fit and dismiss buttons""" + self.guibuttons.EstimateButton.clicked.connect(self.estimate) + self.guibuttons.StartFitButton.clicked.connect(self.startFit) + self.guibuttons.DismissButton.clicked.connect(self.dismiss) + layout.addWidget(self.guibuttons) + + def _setFitManager(self, fitinstance): + """Initialize a :class:`FitManager` instance, to be assigned to + :attr:`fitmanager`, or use a custom FitManager instance. + + :param fitinstance: Existing instance of FitManager, possibly + customized by the user, or None to load a default instance.""" + if isinstance(fitinstance, fitmanager.FitManager): + # customized + fitmngr = fitinstance + else: + # initialize default instance + fitmngr = fitmanager.FitManager() + + # initialize the default fitting functions in case + # none is present + if not len(fitmngr.theories): + fitmngr.loadtheories(fittheories) + + return fitmngr + + def _associateDefaultConfigDialogs(self): + """Fill :attr:`bgconfigdialogs` and :attr:`configdialogs` by calling + :meth:`associateConfigDialog` with default config dialog widgets. + """ + # associate silx.gui.fit.FitConfig with all theories + # Users can later associate their own custom dialogs to + # replace the default. + configdialog = getFitConfigDialog(parent=self, + default=self.fitconfig) + for theory in self.fitmanager.theories: + self.associateConfigDialog(theory, configdialog) + for bgtheory in self.fitmanager.bgtheories: + self.associateConfigDialog(bgtheory, configdialog, + theory_is_background=True) + + # associate silx.gui.fit.BackgroundWidget with Strip and Snip + bgdialog = getBgDialog(parent=self, + default=self.fitconfig) + for bgtheory in ["Strip", "Snip"]: + if bgtheory in self.fitmanager.bgtheories: + self.associateConfigDialog(bgtheory, bgdialog, + theory_is_background=True) + + def _populateFunctions(self): + """Fill combo-boxes with fit theories and background theories + loaded by :attr:`fitmanager`. + Run :meth:`fitmanager.configure` to ensure the custom configuration + of the selected theory has been loaded into :attr:`fitconfig`""" + for theory_name in self.fitmanager.bgtheories: + self.guiConfig.BkgComBox.addItem(theory_name) + self.guiConfig.BkgComBox.setItemData( + self.guiConfig.BkgComBox.findText(theory_name), + self.fitmanager.bgtheories[theory_name].description, + qt.Qt.ToolTipRole) + + for theory_name in self.fitmanager.theories: + self.guiConfig.FunComBox.addItem(theory_name) + self.guiConfig.FunComBox.setItemData( + self.guiConfig.FunComBox.findText(theory_name), + self.fitmanager.theories[theory_name].description, + qt.Qt.ToolTipRole) + + # - activate selected fit theory (if any) + # - activate selected bg theory (if any) + configuration = self.fitmanager.configure() + if self.fitmanager.selectedtheory is None: + # take the first one by default + self.guiConfig.FunComBox.setCurrentIndex(1) + self.funEvent(list(self.fitmanager.theories.keys())[0]) + else: + idx = list(self.fitmanager.theories).index(self.fitmanager.selectedtheory) + self.guiConfig.FunComBox.setCurrentIndex(idx + 1) + self.funEvent(self.fitmanager.selectedtheory) + + if self.fitmanager.selectedbg is None: + self.guiConfig.BkgComBox.setCurrentIndex(1) + self.bkgEvent(list(self.fitmanager.bgtheories.keys())[0]) + else: + idx = list(self.fitmanager.bgtheories).index(self.fitmanager.selectedbg) + self.guiConfig.BkgComBox.setCurrentIndex(idx + 1) + self.bkgEvent(self.fitmanager.selectedbg) + + configuration.update(self.configure()) + + def setdata(self, x, y, sigmay=None, xmin=None, xmax=None): + warnings.warn("Method renamed to setData", + DeprecationWarning) + self.setData(x, y, sigmay, xmin, xmax) + + def setData(self, x, y, sigmay=None, xmin=None, xmax=None): + """Set data to be fitted. + + :param x: Abscissa data. If ``None``, :attr:`xdata`` is set to + ``numpy.array([0.0, 1.0, 2.0, ..., len(y)-1])`` + :type x: Sequence or numpy array or None + :param y: The dependant data ``y = f(x)``. ``y`` must have the same + shape as ``x`` if ``x`` is not ``None``. + :type y: Sequence or numpy array or None + :param sigmay: The uncertainties in the ``ydata`` array. These are + used as weights in the least-squares problem. + If ``None``, the uncertainties are assumed to be 1. + :type sigmay: Sequence or numpy array or None + :param xmin: Lower value of x values to use for fitting + :param xmax: Upper value of x values to use for fitting + """ + self.fitmanager.setdata(x=x, y=y, sigmay=sigmay, + xmin=xmin, xmax=xmax) + for config_dialog in self.bgconfigdialogs.values(): + if isinstance(config_dialog, BackgroundDialog): + config_dialog.setData(x, y, xmin=xmin, xmax=xmax) + + def associateConfigDialog(self, theory_name, config_widget, + theory_is_background=False): + """Associate an instance of custom configuration dialog widget to + a fit theory or to a background theory. + + This adds or modifies an item in the correspondence table + :attr:`configdialogs` or :attr:`bgconfigdialogs`. + + :param str theory_name: Name of fit theory. This must be a key of dict + :attr:`fitmanager.theories` + :param config_widget: Custom configuration widget. See documentation + for :attr:`configdialogs` + :param bool theory_is_background: If flag is *True*, add dialog to + :attr:`bgconfigdialogs` rather than :attr:`configdialogs` + (default). + :raise: KeyError if parameter ``theory_name`` does not match an + existing fit theory or background theory in :attr:`fitmanager`. + :raise: AttributeError if the widget does not implement the mandatory + methods (*show*, *exec_*, *result*, *setDefault*) or the mandatory + attribute (*output*). + """ + theories = self.fitmanager.bgtheories if theory_is_background else\ + self.fitmanager.theories + + if theory_name not in theories: + raise KeyError("%s does not match an existing fitmanager theory") + + if config_widget is not None: + for mandatory_attr in ["show", "exec_", "result", "output"]: + if not hasattr(config_widget, mandatory_attr): + raise AttributeError( + "Custom configuration widget must define " + + "attribute or method " + mandatory_attr) + + if theory_is_background: + self.bgconfigdialogs[theory_name] = config_widget + else: + self.configdialogs[theory_name] = config_widget + + def _emitSignal(self, ddict): + """Emit pyqtSignal after estimation completed + (``ddict = {'event': 'EstimateFinished', 'data': fit_results}``) + and after fit completed + (``ddict = {'event': 'FitFinished', 'data': fit_results}``)""" + self.sigFitWidgetSignal.emit(ddict) + + def __funConfigureGuiSlot(self): + """Open an advanced configuration dialog widget""" + self.__configureGui(dialog_type="function") + + def __bgConfigureGuiSlot(self): + """Open an advanced configuration dialog widget""" + self.__configureGui(dialog_type="background") + + def __configureGui(self, newconfiguration=None, dialog_type="function"): + """Open an advanced configuration dialog widget to get a configuration + dictionary, or use a supplied configuration dictionary. Call + :meth:`configure` with this dictionary as a parameter. Update the gui + accordingly. Reinitialize the fit results in the table and in + :attr:`fitmanager`. + + :param newconfiguration: User supplied configuration dictionary. If ``None``, + open a dialog widget that returns a dictionary.""" + configuration = self.configure() + # get new dictionary + if newconfiguration is None: + newconfiguration = self.configureDialog(configuration, dialog_type) + # update configuration + configuration.update(self.configure(**newconfiguration)) + # set fit function theory + try: + i = 1 + \ + list(self.fitmanager.theories.keys()).index( + self.fitmanager.selectedtheory) + self.guiConfig.FunComBox.setCurrentIndex(i) + self.funEvent(self.fitmanager.selectedtheory) + except ValueError: + _logger.error("Function not in list %s", + self.fitmanager.selectedtheory) + self.funEvent(list(self.fitmanager.theories.keys())[0]) + # current background + try: + i = 1 + \ + list(self.fitmanager.bgtheories.keys()).index( + self.fitmanager.selectedbg) + self.guiConfig.BkgComBox.setCurrentIndex(i) + self.bkgEvent(self.fitmanager.selectedbg) + except ValueError: + _logger.error("Background not in list %s", + self.fitmanager.selectedbg) + self.bkgEvent(list(self.fitmanager.bgtheories.keys())[0]) + + # update the Gui + self.__initialParameters() + + def configureDialog(self, oldconfiguration, dialog_type="function"): + """Display a dialog, allowing the user to define fit configuration + parameters. + + By default, a common dialog is used for all fit theories. But if the + defined a custom dialog using :meth:`associateConfigDialog`, it is + used instead. + + :param dict oldconfiguration: Dictionary containing previous configuration + :param str dialog_type: "function" or "background" + :return: User defined parameters in a dictionary + """ + newconfiguration = {} + newconfiguration.update(oldconfiguration) + + if dialog_type == "function": + theory = self.fitmanager.selectedtheory + configdialog = self.configdialogs[theory] + elif dialog_type == "background": + theory = self.fitmanager.selectedbg + configdialog = self.bgconfigdialogs[theory] + + # this should only happen if a user specifically associates None + # with a theory, to have no configuration option + if configdialog is None: + return {} + + # update state of configdialog before showing it + if hasattr(configdialog, "setDefault"): + configdialog.setDefault(newconfiguration) + configdialog.show() + configdialog.exec_() + if configdialog.result(): + newconfiguration.update(configdialog.output) + + return newconfiguration + + def estimate(self): + """Run parameter estimation function then emit + :attr:`sigFitWidgetSignal` with a dictionary containing a status + message and a list of fit parameters estimations + in the format defined in + :attr:`silx.math.fit.fitmanager.FitManager.fit_results` + + The emitted dictionary has an *"event"* key that can have + following values: + + - *'EstimateStarted'* + - *'EstimateFailed'* + - *'EstimateFinished'* + """ + try: + theory_name = self.fitmanager.selectedtheory + estimation_function = self.fitmanager.theories[theory_name].estimate + if estimation_function is not None: + ddict = {'event': 'EstimateStarted', + 'data': None} + self._emitSignal(ddict) + self.fitmanager.estimate(callback=self.fitStatus) + else: + msg = qt.QMessageBox(self) + msg.setIcon(qt.QMessageBox.Information) + text = "Function does not define a way to estimate\n" + text += "the initial parameters. Please, fill them\n" + text += "yourself in the table and press Start Fit\n" + msg.setText(text) + msg.setWindowTitle('FitWidget Message') + msg.exec_() + return + except: # noqa (we want to catch and report all errors) + msg = qt.QMessageBox(self) + msg.setIcon(qt.QMessageBox.Critical) + msg.setText("Error on estimate: %s" % traceback.format_exc()) + msg.exec_() + ddict = { + 'event': 'EstimateFailed', + 'data': None} + self._emitSignal(ddict) + return + + self.guiParameters.fillFromFit( + self.fitmanager.fit_results, view='Fit') + self.guiParameters.removeAllViews(keep='Fit') + ddict = { + 'event': 'EstimateFinished', + 'data': self.fitmanager.fit_results} + self._emitSignal(ddict) + + def startfit(self): + warnings.warn("Method renamed to startFit", + DeprecationWarning) + self.startFit() + + def startFit(self): + """Run fit, then emit :attr:`sigFitWidgetSignal` with a dictionary + containing a status message and a list of fit + parameters results in the format defined in + :attr:`silx.math.fit.fitmanager.FitManager.fit_results` + + The emitted dictionary has an *"event"* key that can have + following values: + + - *'FitStarted'* + - *'FitFailed'* + - *'FitFinished'* + """ + self.fitmanager.fit_results = self.guiParameters.getFitResults() + try: + ddict = {'event': 'FitStarted', + 'data': None} + self._emitSignal(ddict) + self.fitmanager.runfit(callback=self.fitStatus) + except: # noqa (we want to catch and report all errors) + msg = qt.QMessageBox(self) + msg.setIcon(qt.QMessageBox.Critical) + msg.setText("Error on Fit: %s" % traceback.format_exc()) + msg.exec_() + ddict = { + 'event': 'FitFailed', + 'data': None + } + self._emitSignal(ddict) + return + + self.guiParameters.fillFromFit( + self.fitmanager.fit_results, view='Fit') + self.guiParameters.removeAllViews(keep='Fit') + ddict = { + 'event': 'FitFinished', + 'data': self.fitmanager.fit_results + } + self._emitSignal(ddict) + return + + def bkgEvent(self, bgtheory): + """Select background theory, then reinitialize parameters""" + bgtheory = str(bgtheory) + if bgtheory in self.fitmanager.bgtheories: + self.fitmanager.setbackground(bgtheory) + else: + functionsfile = qt.QFileDialog.getOpenFileName( + self, "Select python module with your function(s)", "", + "Python Files (*.py);;All Files (*)") + + if len(functionsfile): + try: + self.fitmanager.loadbgtheories(functionsfile) + except ImportError: + qt.QMessageBox.critical(self, "ERROR", + "Function not imported") + return + else: + # empty the ComboBox + while self.guiConfig.BkgComBox.count() > 1: + self.guiConfig.BkgComBox.removeItem(1) + # and fill it again + for key in self.fitmanager.bgtheories: + self.guiConfig.BkgComBox.addItem(str(key)) + + i = 1 + \ + list(self.fitmanager.bgtheories.keys()).index( + self.fitmanager.selectedbg) + self.guiConfig.BkgComBox.setCurrentIndex(i) + self.__initialParameters() + + def funEvent(self, theoryname): + """Select a fit theory to be used for fitting. If this theory exists + in :attr:`fitmanager`, use it. Then, reinitialize table. + + :param theoryname: Name of the fit theory to use for fitting. If this theory + exists in :attr:`fitmanager`, use it. Else, open a file dialog to open + a custom fit function definition file with + :meth:`fitmanager.loadtheories`. + """ + theoryname = str(theoryname) + if theoryname in self.fitmanager.theories: + self.fitmanager.settheory(theoryname) + else: + # open a load file dialog + functionsfile = qt.QFileDialog.getOpenFileName( + self, "Select python module with your function(s)", "", + "Python Files (*.py);;All Files (*)") + + if len(functionsfile): + try: + self.fitmanager.loadtheories(functionsfile) + except ImportError: + qt.QMessageBox.critical(self, "ERROR", + "Function not imported") + return + else: + # empty the ComboBox + while self.guiConfig.FunComBox.count() > 1: + self.guiConfig.FunComBox.removeItem(1) + # and fill it again + for key in self.fitmanager.theories: + self.guiConfig.FunComBox.addItem(str(key)) + + i = 1 + \ + list(self.fitmanager.theories.keys()).index( + self.fitmanager.selectedtheory) + self.guiConfig.FunComBox.setCurrentIndex(i) + self.__initialParameters() + + def weightEvent(self, flag): + """This is called when WeightCheckBox is clicked, to configure the + *WeightFlag* field in :attr:`fitmanager.fitconfig` and set weights + in the least-square problem.""" + self.configure(WeightFlag=flag) + if flag: + self.fitmanager.enableweight() + else: + # set weights back to 1 + self.fitmanager.disableweight() + + def __initialParameters(self): + """Fill the fit parameters names with names of the parameters of + the selected background theory and the selected fit theory. + Initialize :attr:`fitmanager.fit_results` with these names, and + initialize the table with them. This creates a view called "Fit" + in :attr:`guiParameters`""" + self.fitmanager.parameter_names = [] + self.fitmanager.fit_results = [] + for pname in self.fitmanager.bgtheories[self.fitmanager.selectedbg].parameters: + self.fitmanager.parameter_names.append(pname) + self.fitmanager.fit_results.append({'name': pname, + 'estimation': 0, + 'group': 0, + 'code': 'FREE', + 'cons1': 0, + 'cons2': 0, + 'fitresult': 0.0, + 'sigma': 0.0, + 'xmin': None, + 'xmax': None}) + if self.fitmanager.selectedtheory is not None: + theory = self.fitmanager.selectedtheory + for pname in self.fitmanager.theories[theory].parameters: + self.fitmanager.parameter_names.append(pname + "1") + self.fitmanager.fit_results.append({'name': pname + "1", + 'estimation': 0, + 'group': 1, + 'code': 'FREE', + 'cons1': 0, + 'cons2': 0, + 'fitresult': 0.0, + 'sigma': 0.0, + 'xmin': None, + 'xmax': None}) + + self.guiParameters.fillFromFit( + self.fitmanager.fit_results, view='Fit') + + def fitStatus(self, data): + """Set *status* and *chisq* in status bar""" + if 'chisq' in data: + if data['chisq'] is None: + self.guistatus.ChisqLine.setText(" ") + else: + chisq = data['chisq'] + self.guistatus.ChisqLine.setText("%6.2f" % chisq) + + if 'status' in data: + status = data['status'] + self.guistatus.StatusLine.setText(str(status)) + + def dismiss(self): + """Close FitWidget""" + self.close() + + +if __name__ == "__main__": + import numpy + + x = numpy.arange(1500).astype(numpy.float) + constant_bg = 3.14 + + p = [1000, 100., 30.0, + 500, 300., 25., + 1700, 500., 35., + 750, 700., 30.0, + 1234, 900., 29.5, + 302, 1100., 30.5, + 75, 1300., 21.] + y = functions.sum_gauss(x, *p) + constant_bg + + a = qt.QApplication(sys.argv) + w = FitWidget() + w.setData(x=x, y=y) + w.show() + a.exec_() diff --git a/silx/gui/fit/FitWidgets.py b/silx/gui/fit/FitWidgets.py new file mode 100644 index 0000000..408666b --- /dev/null +++ b/silx/gui/fit/FitWidgets.py @@ -0,0 +1,559 @@ +# coding: utf-8 +# /*########################################################################## +# Copyright (C) 2004-2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ######################################################################### */ +"""Collection of widgets used to build +:class:`silx.gui.fit.FitWidget.FitWidget`""" + +from collections import OrderedDict + +from silx.gui import qt +from silx.gui.fit.Parameters import Parameters + +QTVERSION = qt.qVersion() + +__authors__ = ["V.A. Sole", "P. Knobel"] +__license__ = "MIT" +__date__ = "13/10/2016" + + +class FitActionsButtons(qt.QWidget): + """Widget with 3 ``QPushButton``: + + The buttons can be accessed as public attributes:: + + - ``EstimateButton`` + - ``StartFitButton`` + - ``DismissButton`` + + You will typically need to access these attributes to connect the buttons + to actions. For instance, if you have 3 functions ``estimate``, + ``runfit`` and ``dismiss``, you can connect them like this:: + + >>> fit_actions_buttons = FitActionsButtons() + >>> fit_actions_buttons.EstimateButton.clicked.connect(estimate) + >>> fit_actions_buttons.StartFitButton.clicked.connect(runfit) + >>> fit_actions_buttons.DismissButton.clicked.connect(dismiss) + + """ + + def __init__(self, parent=None): + qt.QWidget.__init__(self, parent) + + self.resize(234, 53) + + grid_layout = qt.QGridLayout(self) + grid_layout.setContentsMargins(11, 11, 11, 11) + grid_layout.setSpacing(6) + layout = qt.QHBoxLayout(None) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(6) + + self.EstimateButton = qt.QPushButton(self) + self.EstimateButton.setText("Estimate") + layout.addWidget(self.EstimateButton) + spacer = qt.QSpacerItem(20, 20, + qt.QSizePolicy.Expanding, + qt.QSizePolicy.Minimum) + layout.addItem(spacer) + + self.StartFitButton = qt.QPushButton(self) + self.StartFitButton.setText("Start Fit") + layout.addWidget(self.StartFitButton) + spacer_2 = qt.QSpacerItem(20, 20, + qt.QSizePolicy.Expanding, + qt.QSizePolicy.Minimum) + layout.addItem(spacer_2) + + self.DismissButton = qt.QPushButton(self) + self.DismissButton.setText("Dismiss") + layout.addWidget(self.DismissButton) + + grid_layout.addLayout(layout, 0, 0) + + +class FitStatusLines(qt.QWidget): + """Widget with 2 greyed out write-only ``QLineEdit``. + + These text widgets can be accessed as public attributes:: + + - ``StatusLine`` + - ``ChisqLine`` + + You will typically need to access these widgets to update the displayed + text:: + + >>> fit_status_lines = FitStatusLines() + >>> fit_status_lines.StatusLine.setText("Ready") + >>> fit_status_lines.ChisqLine.setText("%6.2f" % 0.01) + + """ + + def __init__(self, parent=None): + qt.QWidget.__init__(self, parent) + + self.resize(535, 47) + + layout = qt.QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(6) + + self.StatusLabel = qt.QLabel(self) + self.StatusLabel.setText("Status:") + layout.addWidget(self.StatusLabel) + + self.StatusLine = qt.QLineEdit(self) + self.StatusLine.setText("Ready") + self.StatusLine.setReadOnly(1) + layout.addWidget(self.StatusLine) + + self.ChisqLabel = qt.QLabel(self) + self.ChisqLabel.setText("Reduced chisq:") + layout.addWidget(self.ChisqLabel) + + self.ChisqLine = qt.QLineEdit(self) + self.ChisqLine.setMaximumSize(qt.QSize(16000, 32767)) + self.ChisqLine.setText("") + self.ChisqLine.setReadOnly(1) + layout.addWidget(self.ChisqLine) + + +class FitConfigWidget(qt.QWidget): + """Widget whose purpose is to select a fit theory and a background + theory, load a new fit theory definition file and provide + a "Configure" button to open an advanced configuration dialog. + + This is used in :class:`silx.gui.fit.FitWidget.FitWidget`, to offer + an interface to quickly modify the main parameters prior to running a fit: + + - select a fitting function through :attr:`FunComBox` + - select a background function through :attr:`BkgComBox` + - open a dialog for modifying advanced parameters through + :attr:`FunConfigureButton` + """ + def __init__(self, parent=None): + qt.QWidget.__init__(self, parent) + + self.setWindowTitle("FitConfigGUI") + + layout = qt.QGridLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(6) + + self.FunLabel = qt.QLabel(self) + self.FunLabel.setText("Function") + layout.addWidget(self.FunLabel, 0, 0) + + self.FunComBox = qt.QComboBox(self) + self.FunComBox.addItem("Add Function(s)") + self.FunComBox.setItemData(self.FunComBox.findText("Add Function(s)"), + "Load fit theories from a file", + qt.Qt.ToolTipRole) + layout.addWidget(self.FunComBox, 0, 1) + + self.BkgLabel = qt.QLabel(self) + self.BkgLabel.setText("Background") + layout.addWidget(self.BkgLabel, 1, 0) + + self.BkgComBox = qt.QComboBox(self) + self.BkgComBox.addItem("Add Background(s)") + self.BkgComBox.setItemData(self.BkgComBox.findText("Add Background(s)"), + "Load background theories from a file", + qt.Qt.ToolTipRole) + layout.addWidget(self.BkgComBox, 1, 1) + + self.FunConfigureButton = qt.QPushButton(self) + self.FunConfigureButton.setText("Configure") + self.FunConfigureButton.setToolTip( + "Open a configuration dialog for the selected function") + layout.addWidget(self.FunConfigureButton, 0, 2) + + self.BgConfigureButton = qt.QPushButton(self) + self.BgConfigureButton.setText("Configure") + self.BgConfigureButton.setToolTip( + "Open a configuration dialog for the selected background") + layout.addWidget(self.BgConfigureButton, 1, 2) + + self.WeightCheckBox = qt.QCheckBox(self) + self.WeightCheckBox.setText("Weighted fit") + self.WeightCheckBox.setToolTip( + "Enable usage of weights in the least-square problem.\n Use" + + " the uncertainties (sigma) if provided, else use sqrt(y).") + + layout.addWidget(self.WeightCheckBox, 0, 3, 2, 1) + + layout.setColumnStretch(4, 1) + + +class ParametersTab(qt.QTabWidget): + """This widget provides tabs to display and modify fit parameters. Each + tab contains a table with fit data such as parameter names, estimated + values, fit constraints, and final fit results. + + The usual way to initialize the table is to fill it with the fit + parameters from a :class:`silx.math.fit.fitmanager.FitManager` object, after + the estimation process or after the final fit. + + In the following example we use a :class:`ParametersTab` to display the + results of two separate fits:: + + from silx.math.fit import fittheories + from silx.math.fit import fitmanager + from silx.math.fit import functions + from silx.gui import qt + import numpy + + a = qt.QApplication([]) + + # Create synthetic data + x = numpy.arange(1000) + y1 = functions.sum_gauss(x, 100, 400, 100) + + fit = fitmanager.FitManager(x=x, y=y1) + + fitfuns = fittheories.FitTheories() + fit.addtheory(theory="Gaussian", + function=functions.sum_gauss, + parameters=("height", "peak center", "fwhm"), + estimate=fitfuns.estimate_height_position_fwhm) + fit.settheory('Gaussian') + fit.configure(PositiveFwhmFlag=True, + PositiveHeightAreaFlag=True, + AutoFwhm=True,) + + # Fit + fit.estimate() + fit.runfit() + + # Show first fit result in a tab in our widget + w = ParametersTab() + w.show() + w.fillFromFit(fit.fit_results, view='Gaussians') + + # new synthetic data + y2 = functions.sum_splitgauss(x, + 100, 400, 100, 40, + 10, 600, 50, 500, + 80, 850, 10, 50) + fit.setData(x=x, y=y2) + + # Define new theory + fit.addtheory(theory="Asymetric gaussian", + function=functions.sum_splitgauss, + parameters=("height", "peak center", "left fwhm", "right fwhm"), + estimate=fitfuns.estimate_splitgauss) + fit.settheory('Asymetric gaussian') + + # Fit + fit.estimate() + fit.runfit() + + # Show first fit result in another tab in our widget + w.fillFromFit(fit.fit_results, view='Asymetric gaussians') + a.exec_() + + """ + + def __init__(self, parent=None, name="FitParameters"): + """ + + :param parent: Parent widget + :param name: Widget title + """ + qt.QTabWidget.__init__(self, parent) + self.setWindowTitle(name) + self.setContentsMargins(0, 0, 0, 0) + + self.views = OrderedDict() + """Dictionary of views. Keys are view names, + items are :class:`Parameters` widgets""" + + self.latest_view = None + """Name of latest view""" + + # the widgets/tables themselves + self.tables = {} + """Dictionary of :class:`silx.gui.fit.parameters.Parameters` objects. + These objects store fit results + """ + + self.setContentsMargins(10, 10, 10, 10) + + def setView(self, view=None, fitresults=None): + """Add or update a table. Fill it with data from a fit + + :param view: Tab name to be added or updated. If ``None``, use the + latest view. + :param fitresults: Fit data to be added to the table + :raise: KeyError if no view name specified and no latest view + available. + """ + if view is None: + if self.latest_view is not None: + view = self.latest_view + else: + raise KeyError( + "No view available. You must specify a view" + + " name the first time you call this method." + ) + + if view in self.tables.keys(): + table = self.tables[view] + else: + # create the parameters instance + self.tables[view] = Parameters(self) + table = self.tables[view] + self.views[view] = table + self.addTab(table, str(view)) + + if fitresults is not None: + table.fillFromFit(fitresults) + + self.setCurrentWidget(self.views[view]) + self.latest_view = view + + def renameView(self, oldname=None, newname=None): + """Rename a view (tab) + + :param oldname: Name of the view to be renamed + :param newname: New name of the view""" + error = 1 + if newname is not None: + if newname not in self.views.keys(): + if oldname in self.views.keys(): + parameterlist = self.tables[oldname].getFitResults() + self.setView(view=newname, fitresults=parameterlist) + self.removeView(oldname) + error = 0 + return error + + def fillFromFit(self, fitparameterslist, view=None): + """Update a view with data from a fit (alias for :meth:`setView`) + + :param view: Tab name to be added or updated (default: latest view) + :param fitparameterslist: Fit data to be added to the table + """ + self.setView(view=view, fitresults=fitparameterslist) + + def getFitResults(self, name=None): + """Call :meth:`getFitResults` for the + :class:`silx.gui.fit.parameters.Parameters` corresponding to the + latest table or to the named table (if ``name`` is not + ``None``). This return a list of dictionaries in the format used by + :class:`silx.math.fit.fitmanager.FitManager` to store fit parameter + results. + + :param name: View name. + """ + if name is None: + name = self.latest_view + return self.tables[name].getFitResults() + + def removeView(self, name): + """Remove a view by name. + + :param name: View name. + """ + if name in self.views: + index = self.indexOf(self.tables[name]) + self.removeTab(index) + index = self.indexOf(self.views[name]) + self.removeTab(index) + del self.tables[name] + del self.views[name] + + def removeAllViews(self, keep=None): + """Remove all views, except the one specified (argument + ``keep``) + + :param keep: Name of the view to be kept.""" + for view in self.tables: + if view != keep: + self.removeView(view) + + def getHtmlText(self, name=None): + """Return the table data as HTML + + :param name: View name.""" + if name is None: + name = self.latest_view + table = self.tables[name] + lemon = ("#%x%x%x" % (255, 250, 205)).upper() + hcolor = ("#%x%x%x" % (230, 240, 249)).upper() + text = "" + text += "<nobr>" + text += "<table>" + text += "<tr>" + ncols = table.columnCount() + for l in range(ncols): + text += ('<td align="left" bgcolor="%s"><b>' % hcolor) + if QTVERSION < '4.0.0': + text += (str(table.horizontalHeader().label(l))) + else: + text += (str(table.horizontalHeaderItem(l).text())) + text += "</b></td>" + text += "</tr>" + nrows = table.rowCount() + for r in range(nrows): + text += "<tr>" + item = table.item(r, 0) + newtext = "" + if item is not None: + newtext = str(item.text()) + if len(newtext): + color = "white" + b = "<b>" + else: + b = "" + color = lemon + try: + # MyQTable item has color defined + cc = table.item(r, 0).color + cc = ("#%x%x%x" % (cc.red(), cc.green(), cc.blue())).upper() + color = cc + except: + pass + for c in range(ncols): + item = table.item(r, c) + newtext = "" + if item is not None: + newtext = str(item.text()) + if len(newtext): + finalcolor = color + else: + finalcolor = "white" + if c < 2: + text += ('<td align="left" bgcolor="%s">%s' % + (finalcolor, b)) + else: + text += ('<td align="right" bgcolor="%s">%s' % + (finalcolor, b)) + text += newtext + if len(b): + text += "</td>" + else: + text += "</b></td>" + item = table.item(r, 0) + newtext = "" + if item is not None: + newtext = str(item.text()) + if len(newtext): + text += "</b>" + text += "</tr>" + text += "\n" + text += "</table>" + text += "</nobr>" + return text + + def getText(self, name=None): + """Return the table data as CSV formatted text, using tabulation + characters as separators. + + :param name: View name.""" + if name is None: + name = self.latest_view + table = self.tables[name] + text = "" + ncols = table.columnCount() + for l in range(ncols): + text += (str(table.horizontalHeaderItem(l).text())) + "\t" + text += "\n" + nrows = table.rowCount() + for r in range(nrows): + for c in range(ncols): + newtext = "" + if c != 4: + item = table.item(r, c) + if item is not None: + newtext = str(item.text()) + else: + item = table.cellWidget(r, c) + if item is not None: + newtext = str(item.currentText()) + text += newtext + "\t" + text += "\n" + text += "\n" + return text + + +def test(): + from silx.math.fit import fittheories + from silx.math.fit import fitmanager + from silx.math.fit import functions + from silx.gui.plot.PlotWindow import PlotWindow + import numpy + + a = qt.QApplication([]) + + x = numpy.arange(1000) + y1 = functions.sum_gauss(x, 100, 400, 100) + + fit = fitmanager.FitManager(x=x, y=y1) + + fitfuns = fittheories.FitTheories() + fit.addtheory(name="Gaussian", + function=functions.sum_gauss, + parameters=("height", "peak center", "fwhm"), + estimate=fitfuns.estimate_height_position_fwhm) + fit.settheory('Gaussian') + fit.configure(PositiveFwhmFlag=True, + PositiveHeightAreaFlag=True, + AutoFwhm=True,) + + # Fit + fit.estimate() + fit.runfit() + + w = ParametersTab() + w.show() + w.fillFromFit(fit.fit_results, view='Gaussians') + + y2 = functions.sum_splitgauss(x, + 100, 400, 100, 40, + 10, 600, 50, 500, + 80, 850, 10, 50) + fit.setdata(x=x, y=y2) + + # Define new theory + fit.addtheory(name="Asymetric gaussian", + function=functions.sum_splitgauss, + parameters=("height", "peak center", "left fwhm", "right fwhm"), + estimate=fitfuns.estimate_splitgauss) + fit.settheory('Asymetric gaussian') + + # Fit + fit.estimate() + fit.runfit() + + w.fillFromFit(fit.fit_results, view='Asymetric gaussians') + + # Plot + pw = PlotWindow(control=True) + pw.addCurve(x, y1, "Gaussians") + pw.addCurve(x, y2, "Asymetric gaussians") + pw.show() + + a.exec_() + + +if __name__ == "__main__": + test() diff --git a/silx/gui/fit/Parameters.py b/silx/gui/fit/Parameters.py new file mode 100644 index 0000000..62e3278 --- /dev/null +++ b/silx/gui/fit/Parameters.py @@ -0,0 +1,882 @@ +# coding: utf-8 +# /*########################################################################## +# Copyright (C) 2004-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ######################################################################### */ +"""This module defines a table widget that is specialized in displaying fit +parameter results and associated constraints.""" +__authors__ = ["V.A. Sole", "P. Knobel"] +__license__ = "MIT" +__date__ = "25/11/2016" + +import sys +from collections import OrderedDict + +from silx.gui import qt +from silx.gui.widgets.TableWidget import TableWidget + + +def float_else_zero(sstring): + """Return converted string to float. If conversion fail, return zero. + + :param sstring: String to be converted + :return: ``float(sstrinq)`` if ``sstring`` can be converted to float + (e.g. ``"3.14"``), else ``0`` + """ + try: + return float(sstring) + except ValueError: + return 0 + + +class QComboTableItem(qt.QComboBox): + """:class:`qt.QComboBox` augmented with a ``sigCellChanged`` signal + to emit a tuple of ``(row, column)`` coordinates when the value is + changed. + + This signal can be used to locate the modified combo box in a table. + + :param row: Row number of the table cell containing this widget + :param col: Column number of the table cell containing this widget""" + sigCellChanged = qt.Signal(int, int) + """Signal emitted when this ``QComboBox`` is activated. + A ``(row, column)`` tuple is passed.""" + + def __init__(self, parent=None, row=None, col=None): + self._row = row + self._col = col + qt.QComboBox.__init__(self, parent) + self.activated[int].connect(self._cellChanged) + + def _cellChanged(self, idx): # noqa + self.sigCellChanged.emit(self._row, self._col) + + +class QCheckBoxItem(qt.QCheckBox): + """:class:`qt.QCheckBox` augmented with a ``sigCellChanged`` signal + to emit a tuple of ``(row, column)`` coordinates when the check box has + been clicked on. + + This signal can be used to locate the modified check box in a table. + + :param row: Row number of the table cell containing this widget + :param col: Column number of the table cell containing this widget""" + sigCellChanged = qt.Signal(int, int) + """Signal emitted when this ``QCheckBox`` is clicked. + A ``(row, column)`` tuple is passed.""" + + def __init__(self, parent=None, row=None, col=None): + self._row = row + self._col = col + qt.QCheckBox.__init__(self, parent) + self.clicked.connect(self._cellChanged) + + def _cellChanged(self): + self.sigCellChanged.emit(self._row, self._col) + + +class Parameters(TableWidget): + """:class:`TableWidget` customized to display fit results + and to interact with :class:`FitManager` objects. + + Data and references to cell widgets are kept in a dictionary + attribute :attr:`parameters`. + + :param parent: Parent widget + :param labels: Column headers. If ``None``, default headers will be used. + :type labels: List of strings or None + :param paramlist: List of fit parameters to be displayed for each fitted + peak. + :type paramlist: list[str] or None + """ + def __init__(self, parent=None, paramlist=None): + TableWidget.__init__(self, parent) + self.setContentsMargins(0, 0, 0, 0) + + labels = ['Parameter', 'Estimation', 'Fit Value', 'Sigma', + 'Constraints', 'Min/Parame', 'Max/Factor/Delta'] + tooltips = ["Fit parameter name", + "Estimated value for fit parameter. You can edit this column.", + "Actual value for parameter, after fit", + "Uncertainty (same unit as the parameter)", + "Constraint to be applied to the parameter for fit", + "First parameter for constraint (name of another param or min value)", + "Second parameter for constraint (max value, or factor/delta)"] + + self.columnKeys = ['name', 'estimation', 'fitresult', + 'sigma', 'code', 'val1', 'val2'] + """This list assigns shorter keys to refer to columns than the + displayed labels.""" + + self.__configuring = False + + # column headers and associated tooltips + self.setColumnCount(len(labels)) + + for i, label in enumerate(labels): + item = self.horizontalHeaderItem(i) + if item is None: + item = qt.QTableWidgetItem(label, + qt.QTableWidgetItem.Type) + self.setHorizontalHeaderItem(i, item) + + item.setText(label) + if tooltips is not None: + item.setToolTip(tooltips[i]) + + # resize columns + for col_key in ["name", "estimation", "sigma", "val1", "val2"]: + col_idx = self.columnIndexByField(col_key) + self.resizeColumnToContents(col_idx) + + # Initialize the table with one line per supplied parameter + paramlist = paramlist if paramlist is not None else [] + self.parameters = OrderedDict() + """This attribute stores all the data in an ordered dictionary. + New data can be added using :meth:`newParameterLine`. + Existing data can be modified using :meth:`configureLine` + + Keys of the dictionary are: + + - 'name': parameter name + - 'line': line index for the parameter in the table + - 'estimation' + - 'fitresult' + - 'sigma' + - 'code': constraint code (one of the elements of + :attr:`code_options`) + - 'val1': first parameter related to constraint, formatted + as a string, as typed in the table + - 'val2': second parameter related to constraint, formatted + as a string, as typed in the table + - 'cons1': scalar representation of 'val1' + (e.g. when val1 is the name of a fit parameter, cons1 + will be the line index of this parameter) + - 'cons2': scalar representation of 'val2' + - 'vmin': equal to 'val1' when 'code' is "QUOTED" + - 'vmax': equal to 'val2' when 'code' is "QUOTED" + - 'relatedto': name of related parameter when this parameter + is constrained to another parameter (same as 'val1') + - 'factor': same as 'val2' when 'code' is 'FACTOR' + - 'delta': same as 'val2' when 'code' is 'DELTA' + - 'sum': same as 'val2' when 'code' is 'SUM' + - 'group': group index for the parameter + - 'xmin': data range minimum + - 'xmax': data range maximum + """ + for line, param in enumerate(paramlist): + self.newParameterLine(param, line) + + self.code_options = ["FREE", "POSITIVE", "QUOTED", "FIXED", + "FACTOR", "DELTA", "SUM", "IGNORE", "ADD"] + """Possible values in the combo boxes in the 'Constraints' column. + """ + + # connect signal + self.cellChanged[int, int].connect(self.onCellChanged) + + def newParameterLine(self, param, line): + """Add a line to the :class:`QTableWidget`. + + Each line represents one of the fit parameters for one of + the fitted peaks. + + :param param: Name of the fit parameter + :type param: str + :param line: 0-based line index + :type line: int + """ + # get current number of lines + nlines = self.rowCount() + self.__configuring = True + if line >= nlines: + self.setRowCount(line + 1) + + # default configuration for fit parameters + self.parameters[param] = OrderedDict((('line', line), + ('estimation', '0'), + ('fitresult', ''), + ('sigma', ''), + ('code', 'FREE'), + ('val1', ''), + ('val2', ''), + ('cons1', 0), + ('cons2', 0), + ('vmin', '0'), + ('vmax', '1'), + ('relatedto', ''), + ('factor', '1.0'), + ('delta', '0.0'), + ('sum', '0.0'), + ('group', ''), + ('name', param), + ('xmin', None), + ('xmax', None))) + self.setReadWrite(param, 'estimation') + self.setReadOnly(param, ['name', 'fitresult', 'sigma', 'val1', 'val2']) + + # Constraint codes + a = [] + for option in self.code_options: + a.append(option) + + code_column_index = self.columnIndexByField('code') + cellWidget = self.cellWidget(line, code_column_index) + if cellWidget is None: + cellWidget = QComboTableItem(self, row=line, + col=code_column_index) + cellWidget.addItems(a) + self.setCellWidget(line, code_column_index, cellWidget) + cellWidget.sigCellChanged[int, int].connect(self.onCellChanged) + self.parameters[param]['code_item'] = cellWidget + self.parameters[param]['relatedto_item'] = None + self.__configuring = False + + def columnIndexByField(self, field): + """ + + :param field: Field name (column key) + :return: Index of the column with this field name + """ + return self.columnKeys.index(field) + + def fillFromFit(self, fitresults): + """Fill table with values from a list of dictionaries + (see :attr:`silx.math.fit.fitmanager.FitManager.fit_results`) + + :param fitresults: List of parameters as recorded + in the ``paramlist`` attribute of a :class:`FitManager` object + :type fitresults: list[dict] + """ + self.setRowCount(len(fitresults)) + + # Reinitialize and fill self.parameters + self.parameters = OrderedDict() + for (line, param) in enumerate(fitresults): + self.newParameterLine(param['name'], line) + + for param in fitresults: + name = param['name'] + code = str(param['code']) + if code not in self.code_options: + # convert code from int to descriptive string + code = self.code_options[int(code)] + val1 = param['cons1'] + val2 = param['cons2'] + estimation = param['estimation'] + group = param['group'] + sigma = param['sigma'] + fitresult = param['fitresult'] + + xmin = param.get('xmin') + xmax = param.get('xmax') + + self.configureLine(name=name, + code=code, + val1=val1, val2=val2, + estimation=estimation, + fitresult=fitresult, + sigma=sigma, + group=group, + xmin=xmin, xmax=xmax) + + def getConfiguration(self): + """Return ``FitManager.paramlist`` dictionary + encapsulated in another dictionary""" + return {'parameters': self.getFitResults()} + + def setConfiguration(self, ddict): + """Fill table with values from a ``FitManager.paramlist`` dictionary + encapsulated in another dictionary""" + self.fillFromFit(ddict['parameters']) + + def getFitResults(self): + """Return fit parameters as a list of dictionaries in the format used + by :class:`FitManager` (attribute ``paramlist``). + """ + fitparameterslist = [] + for param in self.parameters: + fitparam = {} + name = param + estimation, [code, cons1, cons2] = self.getEstimationConstraints(name) + buf = str(self.parameters[param]['fitresult']) + xmin = self.parameters[param]['xmin'] + xmax = self.parameters[param]['xmax'] + if len(buf): + fitresult = float(buf) + else: + fitresult = 0.0 + buf = str(self.parameters[param]['sigma']) + if len(buf): + sigma = float(buf) + else: + sigma = 0.0 + buf = str(self.parameters[param]['group']) + if len(buf): + group = float(buf) + else: + group = 0 + fitparam['name'] = name + fitparam['estimation'] = estimation + fitparam['fitresult'] = fitresult + fitparam['sigma'] = sigma + fitparam['group'] = group + fitparam['code'] = code + fitparam['cons1'] = cons1 + fitparam['cons2'] = cons2 + fitparam['xmin'] = xmin + fitparam['xmax'] = xmax + fitparameterslist.append(fitparam) + return fitparameterslist + + def onCellChanged(self, row, col): + """Slot called when ``cellChanged`` signal is emitted. + Checks the validity of the new text in the cell, then calls + :meth:`configureLine` to update the internal ``self.parameters`` + dictionary. + + :param row: Row number of the changed cell (0-based index) + :param col: Column number of the changed cell (0-based index) + """ + if (col != self.columnIndexByField("code")) and (col != -1): + if row != self.currentRow(): + return + if col != self.currentColumn(): + return + if self.__configuring: + return + param = list(self.parameters)[row] + field = self.columnKeys[col] + oldvalue = self.parameters[param][field] + if col != 4: + item = self.item(row, col) + if item is not None: + newvalue = item.text() + else: + newvalue = '' + else: + # this is the combobox + widget = self.cellWidget(row, col) + newvalue = widget.currentText() + if self.validate(param, field, oldvalue, newvalue): + paramdict = {"name": param, field: newvalue} + self.configureLine(**paramdict) + else: + if field == 'code': + # New code not valid, try restoring the old one + index = self.code_options.index(oldvalue) + self.__configuring = True + try: + self.parameters[param]['code_item'].setCurrentIndex(index) + finally: + self.__configuring = False + else: + paramdict = {"name": param, field: oldvalue} + self.configureLine(**paramdict) + + def validate(self, param, field, oldvalue, newvalue): + """Check validity of ``newvalue`` when a cell's value is modified. + + :param param: Fit parameter name + :param field: Column name + :param oldvalue: Cell value before change attempt + :param newvalue: New value to be validated + :return: True if new cell value is valid, else False + """ + if field == 'code': + return self.setCodeValue(param, oldvalue, newvalue) + # FIXME: validate() shouldn't have side effects. Move this bit to configureLine()? + if field == 'val1' and str(self.parameters[param]['code']) in ['DELTA', 'FACTOR', 'SUM']: + _, candidates = self.getRelatedCandidates(param) + # We expect val1 to be a fit parameter name + if str(newvalue) in candidates: + return True + else: + return False + # except for code, val1 and name (which is read-only and does not need + # validation), all fields must always be convertible to float + else: + try: + float(str(newvalue)) + except ValueError: + return False + return True + + def setCodeValue(self, param, oldvalue, newvalue): + """Update 'code' and 'relatedto' fields when code cell is + changed. + + :param param: Fit parameter name + :param oldvalue: Cell value before change attempt + :param newvalue: New value to be validated + :return: ``True`` if code was successfully updated + """ + + if str(newvalue) in ['FREE', 'POSITIVE', 'QUOTED', 'FIXED']: + self.configureLine(name=param, + code=newvalue) + if str(oldvalue) == 'IGNORE': + self.freeRestOfGroup(param) + return True + elif str(newvalue) in ['FACTOR', 'DELTA', 'SUM']: + # I should check here that some parameter is set + best, candidates = self.getRelatedCandidates(param) + if len(candidates) == 0: + return False + self.configureLine(name=param, + code=newvalue, + relatedto=best) + if str(oldvalue) == 'IGNORE': + self.freeRestOfGroup(param) + return True + + elif str(newvalue) == 'IGNORE': + # I should check if the group can be ignored + # for the time being I just fix all of them to ignore + group = int(float(str(self.parameters[param]['group']))) + candidates = [] + for param in self.parameters.keys(): + if group == int(float(str(self.parameters[param]['group']))): + candidates.append(param) + # print candidates + # I should check here if there is any relation to them + for param in candidates: + self.configureLine(name=param, + code=newvalue) + return True + elif str(newvalue) == 'ADD': + group = int(float(str(self.parameters[param]['group']))) + if group == 0: + # One cannot add a background group + return False + i = 0 + for param in self.parameters: + if i <= int(float(str(self.parameters[param]['group']))): + i += 1 + if (group == 0) and (i == 1): # FIXME: why +1? + i += 1 + self.addGroup(i, group) + return False + elif str(newvalue) == 'SHOW': + print(self.getEstimationConstraints(param)) + return False + + def addGroup(self, newg, gtype): + """Add a fit parameter group with the same fit parameters as an + existing group. + + This function is called when the user selects "ADD" in the + "constraints" combobox. + + :param int newg: New group number + :param int gtype: Group number whose parameters we want to copy + + """ + newparam = [] + # loop through parameters until we encounter group number `gtype` + for param in list(self.parameters): + paramgroup = int(float(str(self.parameters[param]['group']))) + # copy parameter names in group number `gtype` + if paramgroup == gtype: + # but replace `gtype` with `newg` + newparam.append(param.rstrip("0123456789") + "%d" % newg) + + xmin = self.parameters[param]['xmin'] + xmax = self.parameters[param]['xmax'] + + # Add new parameters (one table line per parameter) and configureLine each + # one by updating xmin and xmax to the same values as group `gtype` + line = len(list(self.parameters)) + for param in newparam: + self.newParameterLine(param, line) + line += 1 + for param in newparam: + self.configureLine(name=param, group=newg, xmin=xmin, xmax=xmax) + + def freeRestOfGroup(self, workparam): + """Set ``code`` to ``"FREE"`` for all fit parameters belonging to + the same group as ``workparam``. This is done when the entire group + of parameters was previously ignored and one of them has his code + set to something different than ``"IGNORE"``. + + :param workparam: Fit parameter name + """ + if workparam in self.parameters.keys(): + group = int(float(str(self.parameters[workparam]['group']))) + for param in self.parameters: + if param != workparam and\ + group == int(float(str(self.parameters[param]['group']))): + self.configureLine(name=param, + code='FREE', + cons1=0, + cons2=0, + val1='', + val2='') + + def getRelatedCandidates(self, workparam): + """If fit parameter ``workparam`` has a constraint that involves other + fit parameters, find possible candidates and try to guess which one + is the most likely. + + :param workparam: Fit parameter name + :return: (best_candidate, possible_candidates) tuple + :rtype: (str, list[str]) + """ + candidates = [] + for param_name in self.parameters: + if param_name != workparam: + # ignore parameters that are fixed by a constraint + if str(self.parameters[param_name]['code']) not in\ + ['IGNORE', 'FACTOR', 'DELTA', 'SUM']: + candidates.append(param_name) + # take the previous one (before code cell changed) if possible + if str(self.parameters[workparam]['relatedto']) in candidates: + best = str(self.parameters[workparam]['relatedto']) + return best, candidates + # take the first with same base name (after removing numbers) + for param_name in candidates: + basename = param_name.rstrip("0123456789") + try: + pos = workparam.index(basename) + if pos == 0: + best = param_name + return best, candidates + except ValueError: + pass + # take the first + return candidates[0], candidates + + def setReadOnly(self, parameter, fields): + """Make table cells read-only by setting it's flags and omitting + flag ``qt.Qt.ItemIsEditable`` + + :param parameter: Fit parameter names identifying the rows + :type parameter: str or list[str] + :param fields: Field names identifying the columns + :type fields: str or list[str] + """ + editflags = qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled + self.setField(parameter, fields, editflags) + + def setReadWrite(self, parameter, fields): + """Make table cells read-write by setting it's flags including + flag ``qt.Qt.ItemIsEditable`` + + :param parameter: Fit parameter names identifying the rows + :type parameter: str or list[str] + :param fields: Field names identifying the columns + :type fields: str or list[str] + """ + editflags = qt.Qt.ItemIsSelectable |\ + qt.Qt.ItemIsEnabled |\ + qt.Qt.ItemIsEditable + self.setField(parameter, fields, editflags) + + def setField(self, parameter, fields, edit_flags): + """Set text and flags in a table cell. + + :param parameter: Fit parameter names identifying the rows + :type parameter: str or list[str] + :param fields: Field names identifying the columns + :type fields: str or list[str] + :param edit_flags: Flag combination, e.g:: + + qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled | + qt.Qt.ItemIsEditable + """ + if isinstance(parameter, list) or \ + isinstance(parameter, tuple): + paramlist = parameter + else: + paramlist = [parameter] + if isinstance(fields, list) or \ + isinstance(fields, tuple): + fieldlist = fields + else: + fieldlist = [fields] + + # Set _configuring flag to ignore cellChanged signals in + # self.onCellChanged + _oldvalue = self.__configuring + self.__configuring = True + + # 2D loop through parameter list and field list + # to update their cells + for param in paramlist: + row = list(self.parameters.keys()).index(param) + for field in fieldlist: + col = self.columnIndexByField(field) + if field != 'code': + key = field + "_item" + item = self.item(row, col) + if item is None: + item = qt.QTableWidgetItem() + item.setText(self.parameters[param][field]) + self.setItem(row, col, item) + else: + item.setText(self.parameters[param][field]) + self.parameters[param][key] = item + item.setFlags(edit_flags) + + # Restore previous _configuring flag + self.__configuring = _oldvalue + + def configureLine(self, name, code=None, val1=None, val2=None, + sigma=None, estimation=None, fitresult=None, + group=None, xmin=None, xmax=None, relatedto=None, + cons1=None, cons2=None): + """This function updates values in a line of the table + + :param name: Name of the parameter (serves as unique identifier for + a line). + :param code: Constraint code *FREE, FIXED, POSITIVE, DELTA, FACTOR, + SUM, QUOTED, IGNORE* + :param val1: Constraint 1 (can be the index or name of another + parameter for code *DELTA, FACTOR, SUM*, or a min value + for code *QUOTED*) + :param val2: Constraint 2 + :param sigma: Standard deviation for a fit parameter + :param estimation: Estimated initial value for a fit parameter (used + as input to iterative fit) + :param fitresult: Final result of fit + :param group: Group number of a fit parameter (peak number when doing + multi-peak fitting, as each peak corresponds to a group + of several consecutive parameters) + :param xmin: + :param xmax: + :param relatedto: Index or name of another fit parameter + to which this parameter is related to (constraints) + :param cons1: similar meaning to ``val1``, but is always a number + :param cons2: similar meaning to ``val2``, but is always a number + :return: + """ + paramlist = list(self.parameters.keys()) + + if name not in self.parameters: + raise KeyError("'%s' is not in the parameter list" % name) + + # update code first, if specified + if code is not None: + code = str(code) + self.parameters[name]['code'] = code + # update combobox + index = self.parameters[name]['code_item'].findText(code) + self.parameters[name]['code_item'].setCurrentIndex(index) + else: + # set code to previous value, used later for setting val1 val2 + code = self.parameters[name]['code'] + + # val1 and sigma have special formats + if val1 is not None: + fmt = None if self.parameters[name]['code'] in\ + ['DELTA', 'FACTOR', 'SUM'] else "%8g" + self._updateField(name, "val1", val1, fmat=fmt) + + if sigma is not None: + self._updateField(name, "sigma", sigma, fmat="%6.3g") + + # other fields are formatted as "%8g" + keys_params = (("val2", val2), ("estimation", estimation), + ("fitresult", fitresult)) + for key, value in keys_params: + if value is not None: + self._updateField(name, key, value, fmat="%8g") + + # the rest of the parameters are treated as strings and don't need + # validation + keys_params = (("group", group), ("xmin", xmin), + ("xmax", xmax), ("relatedto", relatedto), + ("cons1", cons1), ("cons2", cons2)) + for key, value in keys_params: + if value is not None: + self.parameters[name][key] = str(value) + + # val1 and val2 have different meanings depending on the code + if code == 'QUOTED': + if val1 is not None: + self.parameters[name]['vmin'] = self.parameters[name]['val1'] + else: + self.parameters[name]['val1'] = self.parameters[name]['vmin'] + if val2 is not None: + self.parameters[name]['vmax'] = self.parameters[name]['val2'] + else: + self.parameters[name]['val2'] = self.parameters[name]['vmax'] + + # cons1 and cons2 are scalar representations of val1 and val2 + self.parameters[name]['cons1'] =\ + float_else_zero(self.parameters[name]['val1']) + self.parameters[name]['cons2'] =\ + float_else_zero(self.parameters[name]['val2']) + + # cons1, cons2 = min(val1, val2), max(val1, val2) + if self.parameters[name]['cons1'] > self.parameters[name]['cons2']: + self.parameters[name]['cons1'], self.parameters[name]['cons2'] =\ + self.parameters[name]['cons2'], self.parameters[name]['cons1'] + + elif code in ['DELTA', 'SUM', 'FACTOR']: + # For these codes, val1 is the fit parameter name on which the + # constraint depends + if val1 is not None and val1 in paramlist: + self.parameters[name]['relatedto'] = self.parameters[name]["val1"] + + elif val1 is not None: + # val1 could be the index of the fit parameter + try: + self.parameters[name]['relatedto'] = paramlist[int(val1)] + except ValueError: + self.parameters[name]['relatedto'] = self.parameters[name]["val1"] + + elif relatedto is not None: + # code changed, val1 not specified but relatedto specified: + # set val1 to relatedto (pre-fill best guess) + self.parameters[name]["val1"] = relatedto + + # update fields "delta", "sum" or "factor" + key = code.lower() + self.parameters[name][key] = self.parameters[name]["val2"] + + # FIXME: val1 is sometimes specified as an index rather than a param name + self.parameters[name]['val1'] = self.parameters[name]['relatedto'] + + # cons1 is the index of the fit parameter in the ordered dictionary + if self.parameters[name]['val1'] in paramlist: + self.parameters[name]['cons1'] =\ + paramlist.index(self.parameters[name]['val1']) + + # cons2 is the constraint value (factor, delta or sum) + try: + self.parameters[name]['cons2'] =\ + float(str(self.parameters[name]['val2'])) + except ValueError: + self.parameters[name]['cons2'] = 1.0 if code == "FACTOR" else 0.0 + + elif code in ['FREE', 'POSITIVE', 'IGNORE', 'FIXED']: + self.parameters[name]['val1'] = "" + self.parameters[name]['val2'] = "" + self.parameters[name]['cons1'] = 0 + self.parameters[name]['cons2'] = 0 + + self._updateCellRWFlags(name, code) + + def _updateField(self, name, field, value, fmat=None): + """Update field in ``self.parameters`` dictionary, if the new value + is valid. + + :param name: Fit parameter name + :param field: Field name + :param value: New value to assign + :type value: String + :param fmat: Format string (e.g. "%8g") to be applied if value represents + a scalar. If ``None``, format is not modified. If ``value`` is an + empty string, ``fmat`` is ignored. + """ + if value is not None: + oldvalue = self.parameters[name][field] + if fmat is not None: + newvalue = fmat % float(value) if value != "" else "" + else: + newvalue = value + self.parameters[name][field] = newvalue if\ + self.validate(name, field, oldvalue, newvalue) else\ + oldvalue + + def _updateCellRWFlags(self, name, code=None): + """Set read-only or read-write flags in a row, + depending on the constraint code + + :param name: Fit parameter name identifying the row + :param code: Constraint code, in `'FREE', 'POSITIVE', 'IGNORE',` + `'FIXED', 'FACTOR', 'DELTA', 'SUM', 'ADD'` + :return: + """ + if code in ['FREE', 'POSITIVE', 'IGNORE', 'FIXED']: + self.setReadWrite(name, 'estimation') + self.setReadOnly(name, ['fitresult', 'sigma', 'val1', 'val2']) + else: + self.setReadWrite(name, ['estimation', 'val1', 'val2']) + self.setReadOnly(name, ['fitresult', 'sigma']) + + def getEstimationConstraints(self, param): + """ + Return tuple ``(estimation, constraints)`` where ``estimation`` is the + value in the ``estimate`` field and ``constraints`` are the relevant + constraints according to the active code + """ + estimation = None + constraints = None + if param in self.parameters.keys(): + buf = str(self.parameters[param]['estimation']) + if len(buf): + estimation = float(buf) + else: + estimation = 0 + if str(self.parameters[param]['code']) in self.code_options: + code = self.code_options.index( + str(self.parameters[param]['code'])) + else: + code = str(self.parameters[param]['code']) + cons1 = self.parameters[param]['cons1'] + cons2 = self.parameters[param]['cons2'] + constraints = [code, cons1, cons2] + return estimation, constraints + + +def main(args): + from silx.math.fit import fittheories + from silx.math.fit import fitmanager + try: + from PyMca5 import PyMcaDataDir + except ImportError: + raise ImportError("This demo requires PyMca data. Install PyMca5.") + import numpy + import os + app = qt.QApplication(args) + tab = Parameters(paramlist=['Height', 'Position', 'FWHM']) + tab.showGrid() + tab.configureLine(name='Height', estimation='1234', group=0) + tab.configureLine(name='Position', code='FIXED', group=1) + tab.configureLine(name='FWHM', group=1) + + y = numpy.loadtxt(os.path.join(PyMcaDataDir.PYMCA_DATA_DIR, + "XRFSpectrum.mca")) # FIXME + + x = numpy.arange(len(y)) * 0.0502883 - 0.492773 + fit = fitmanager.FitManager() + fit.setdata(x=x, y=y, xmin=20, xmax=150) + + fit.loadtheories(fittheories) + + fit.settheory('ahypermet') + fit.configure(Yscaling=1., + PositiveFwhmFlag=True, + PositiveHeightAreaFlag=True, + FwhmPoints=16, + QuotedPositionFlag=1, + HypermetTails=1) + fit.setbackground('Linear') + fit.estimate() + fit.runfit() + tab.fillFromFit(fit.fit_results) + tab.show() + app.exec_() + +if __name__ == "__main__": + main(sys.argv) diff --git a/silx/gui/fit/__init__.py b/silx/gui/fit/__init__.py new file mode 100644 index 0000000..e4fd3ab --- /dev/null +++ b/silx/gui/fit/__init__.py @@ -0,0 +1,28 @@ +# coding: utf-8 +# /*########################################################################## +# Copyright (C) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "07/07/2016" + +from .FitWidget import FitWidget diff --git a/silx/gui/fit/setup.py b/silx/gui/fit/setup.py new file mode 100644 index 0000000..6672363 --- /dev/null +++ b/silx/gui/fit/setup.py @@ -0,0 +1,43 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "21/07/2016" + + +from numpy.distutils.misc_util import Configuration + + +def configuration(parent_package='', top_path=None): + config = Configuration('fit', parent_package, top_path) + config.add_subpackage('test') + + return config + + +if __name__ == "__main__": + from numpy.distutils.core import setup + + setup(configuration=configuration) diff --git a/silx/gui/fit/test/__init__.py b/silx/gui/fit/test/__init__.py new file mode 100644 index 0000000..2236d64 --- /dev/null +++ b/silx/gui/fit/test/__init__.py @@ -0,0 +1,43 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +import unittest + +from .testFitWidget import suite as testFitWidgetSuite +from .testFitConfig import suite as testFitConfigSuite +from .testBackgroundWidget import suite as testBackgroundWidgetSuite + + +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "21/07/2016" + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTests( + [testFitWidgetSuite(), + testFitConfigSuite(), + testBackgroundWidgetSuite()]) + return test_suite diff --git a/silx/gui/fit/test/testBackgroundWidget.py b/silx/gui/fit/test/testBackgroundWidget.py new file mode 100644 index 0000000..2e366e4 --- /dev/null +++ b/silx/gui/fit/test/testBackgroundWidget.py @@ -0,0 +1,83 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +import unittest + +from ...test.utils import TestCaseQt + +from .. import BackgroundWidget + +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "05/12/2016" + + +class TestBackgroundWidget(TestCaseQt): + def setUp(self): + super(TestBackgroundWidget, self).setUp() + self.bgdialog = BackgroundWidget.BackgroundDialog() + self.bgdialog.setData(list([0, 1, 2, 3]), + list([0, 1, 4, 8])) + self.qWaitForWindowExposed(self.bgdialog) + + def tearDown(self): + del self.bgdialog + super(TestBackgroundWidget, self).tearDown() + + def testShow(self): + self.bgdialog.show() + self.bgdialog.hide() + + def testAccept(self): + self.bgdialog.accept() + self.assertTrue(self.bgdialog.result()) + + def testReject(self): + self.bgdialog.reject() + self.assertFalse(self.bgdialog.result()) + + def testDefaultOutput(self): + self.bgdialog.accept() + output = self.bgdialog.output + + for key in ["algorithm", "StripThreshold", "SnipWidth", + "StripIterations", "StripWidth", "SmoothingFlag", + "SmoothingWidth", "AnchorsFlag", "AnchorsList"]: + self.assertIn(key, output) + + self.assertFalse(output["AnchorsFlag"]) + self.assertEqual(output["StripWidth"], 1) + self.assertEqual(output["SmoothingFlag"], False) + self.assertEqual(output["SmoothingWidth"], 3) + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestBackgroundWidget)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/fit/test/testFitConfig.py b/silx/gui/fit/test/testFitConfig.py new file mode 100644 index 0000000..eea35cc --- /dev/null +++ b/silx/gui/fit/test/testFitConfig.py @@ -0,0 +1,95 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for :class:`FitConfig`""" + +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "05/12/2016" + +import unittest + +from ...test.utils import TestCaseQt +from .. import FitConfig + + +class TestFitConfig(TestCaseQt): + """Basic test for FitWidget""" + + def setUp(self): + super(TestFitConfig, self).setUp() + self.fit_config = FitConfig.getFitConfigDialog(modal=False) + self.qWaitForWindowExposed(self.fit_config) + + def tearDown(self): + del self.fit_config + super(TestFitConfig, self).tearDown() + + def testShow(self): + self.fit_config.show() + self.fit_config.hide() + + def testAccept(self): + self.fit_config.accept() + self.assertTrue(self.fit_config.result()) + + def testReject(self): + self.fit_config.reject() + self.assertFalse(self.fit_config.result()) + + def testDefaultOutput(self): + self.fit_config.accept() + output = self.fit_config.output + + for key in ["AutoFwhm", + "PositiveHeightAreaFlag", + "QuotedPositionFlag", + "PositiveFwhmFlag", + "SameFwhmFlag", + "QuotedEtaFlag", + "NoConstraintsFlag", + "FwhmPoints", + "Sensitivity", + "Yscaling", + "ForcePeakPresence", + "StripBackgroundFlag", + "StripWidth", + "StripIterations", + "StripThreshold", + "SmoothingFlag"]: + self.assertIn(key, output) + + self.assertTrue(output["AutoFwhm"]) + self.assertEqual(output["StripWidth"], 2) + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestFitConfig)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/fit/test/testFitWidget.py b/silx/gui/fit/test/testFitWidget.py new file mode 100644 index 0000000..d542fd0 --- /dev/null +++ b/silx/gui/fit/test/testFitWidget.py @@ -0,0 +1,135 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for :class:`FitWidget`""" + +import unittest + +from ...test.utils import TestCaseQt + +from ... import qt +from .. import FitWidget + +from ....math.fit.fittheory import FitTheory +from ....math.fit.fitmanager import FitManager + +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "05/12/2016" + + +class TestFitWidget(TestCaseQt): + """Basic test for FitWidget""" + + def setUp(self): + super(TestFitWidget, self).setUp() + self.fit_widget = FitWidget() + self.fit_widget.show() + self.qWaitForWindowExposed(self.fit_widget) + + def tearDown(self): + self.fit_widget.setAttribute(qt.Qt.WA_DeleteOnClose) + self.fit_widget.close() + del self.fit_widget + super(TestFitWidget, self).tearDown() + + def testShow(self): + pass + + def testInteract(self): + self.mouseClick(self.fit_widget, qt.Qt.LeftButton) + self.keyClick(self.fit_widget, qt.Qt.Key_Enter) + self.qapp.processEvents() + + def testCustomConfigWidget(self): + class CustomConfigWidget(qt.QDialog): + def __init__(self): + qt.QDialog.__init__(self) + self.setModal(True) + self.ok = qt.QPushButton("ok", self) + self.ok.clicked.connect(self.accept) + cancel = qt.QPushButton("cancel", self) + cancel.clicked.connect(self.reject) + layout = qt.QVBoxLayout(self) + layout.addWidget(self.ok) + layout.addWidget(cancel) + self.output = {"hello": "world"} + + def fitfun(x, a, b): + return a * x + b + + x = list(range(0, 100)) + y = [fitfun(x_, 2, 3) for x_ in x] + + def conf(**kw): + return {"spam": "eggs", + "hello": "world!"} + + theory = FitTheory( + function=fitfun, + parameters=["a", "b"], + configure=conf) + + fitmngr = FitManager() + fitmngr.setdata(x, y) + fitmngr.addtheory("foo", theory) + fitmngr.addtheory("bar", theory) + fitmngr.addbgtheory("spam", theory) + + fw = FitWidget(fitmngr=fitmngr) + fw.associateConfigDialog("spam", CustomConfigWidget(), + theory_is_background=True) + fw.associateConfigDialog("foo", CustomConfigWidget()) + fw.show() + self.qWaitForWindowExposed(fw) + + fw.bgconfigdialogs["spam"].accept() + self.assertTrue(fw.bgconfigdialogs["spam"].result()) + + self.assertEqual(fw.bgconfigdialogs["spam"].output, + {"hello": "world"}) + + fw.bgconfigdialogs["spam"].reject() + self.assertFalse(fw.bgconfigdialogs["spam"].result()) + + fw.configdialogs["foo"].accept() + self.assertTrue(fw.configdialogs["foo"].result()) + + # todo: figure out how to click fw.configdialog.ok to close dialog + # open dialog + # self.mouseClick(fw.guiConfig.FunConfigureButton, qt.Qt.LeftButton) + # clove dialog + # self.mouseClick(fw.configdialogs["foo"].ok, qt.Qt.LeftButton) + # self.qapp.processEvents() + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestFitWidget)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/hdf5/Hdf5HeaderView.py b/silx/gui/hdf5/Hdf5HeaderView.py new file mode 100644 index 0000000..5912230 --- /dev/null +++ b/silx/gui/hdf5/Hdf5HeaderView.py @@ -0,0 +1,192 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "08/11/2016" + + +from .. import qt + +QTVERSION = qt.qVersion() + + +class Hdf5HeaderView(qt.QHeaderView): + """ + Default HDF5 header + + Manage auto-resize and context menu to display/hide columns + """ + + def __init__(self, orientation, parent=None): + """ + Constructor + + :param orientation qt.Qt.Orientation: Orientation of the header + :param parent qt.QWidget: Parent of the widget + """ + super(Hdf5HeaderView, self).__init__(orientation, parent) + self.setContextMenuPolicy(qt.Qt.CustomContextMenu) + self.customContextMenuRequested.connect(self.__createContextMenu) + + # default initialization done by QTreeView for it's own header + if QTVERSION < "5.0": + self.setClickable(True) + self.setMovable(True) + else: + self.setSectionsClickable(True) + self.setSectionsMovable(True) + self.setDefaultAlignment(qt.Qt.AlignLeft | qt.Qt.AlignVCenter) + self.setStretchLastSection(True) + + self.__auto_resize = True + self.__hide_columns_popup = True + + def setModel(self, model): + """Override model to configure view when a model is expected + + `qt.QHeaderView.setResizeMode` expect already existing columns + to work. + + :param model qt.QAbstractItemModel: A model + """ + super(Hdf5HeaderView, self).setModel(model) + self.__updateAutoResize() + + def __updateAutoResize(self): + """Update the view according to the state of the auto-resize""" + if QTVERSION < "5.0": + setResizeMode = self.setResizeMode + else: + setResizeMode = self.setSectionResizeMode + + if self.__auto_resize: + setResizeMode(0, qt.QHeaderView.ResizeToContents) + setResizeMode(1, qt.QHeaderView.ResizeToContents) + setResizeMode(2, qt.QHeaderView.ResizeToContents) + setResizeMode(3, qt.QHeaderView.Interactive) + setResizeMode(4, qt.QHeaderView.Interactive) + setResizeMode(5, qt.QHeaderView.ResizeToContents) + else: + setResizeMode(0, qt.QHeaderView.Interactive) + setResizeMode(1, qt.QHeaderView.Interactive) + setResizeMode(2, qt.QHeaderView.Interactive) + setResizeMode(3, qt.QHeaderView.Interactive) + setResizeMode(4, qt.QHeaderView.Interactive) + setResizeMode(5, qt.QHeaderView.Interactive) + + def setAutoResizeColumns(self, autoResize): + """Enable/disable auto-resize. When auto-resized, the header take care + of the content of the column to set fixed size of some of them, or to + auto fix the size according to the content. + + :param autoResize bool: Enable/disable auto-resize + """ + if self.__auto_resize == autoResize: + return + self.__auto_resize = autoResize + self.__updateAutoResize() + + def hasAutoResizeColumns(self): + """Is auto-resize enabled. + + :rtype: bool + """ + return self.__auto_resize + + autoResizeColumns = qt.Property(bool, hasAutoResizeColumns, setAutoResizeColumns) + """Property to enable/disable auto-resize.""" + + def setEnableHideColumnsPopup(self, enablePopup): + """Enable/disable a popup to allow to hide/show each column of the + model. + + :param bool enablePopup: Enable/disable popup to hide/show columns + """ + self.__hide_columns_popup = enablePopup + + def hasHideColumnsPopup(self): + """Is popup to hide/show columns is enabled. + + :rtype: bool + """ + return self.__hide_columns_popup + + enableHideColumnsPopup = qt.Property(bool, hasHideColumnsPopup, setAutoResizeColumns) + """Property to enable/disable popup allowing to hide/show columns.""" + + def __genHideSectionEvent(self, column): + """Generate a callback which change the column visibility according to + the event parameter + + :param int column: logical id of the column + :rtype: callable + """ + return lambda checked: self.setSectionHidden(column, not checked) + + def __createContextMenu(self, pos): + """Callback to create and display a context menu + + :param pos qt.QPoint: Requested position for the context menu + """ + if not self.__hide_columns_popup: + return + + model = self.model() + if model.columnCount() > 1: + menu = qt.QMenu(self) + menu.setTitle("Display/hide columns") + + action = qt.QAction("Display/hide column", self) + action.setEnabled(False) + menu.addAction(action) + + for column in range(model.columnCount()): + if column == 0: + # skip the main column + continue + text = model.headerData(column, qt.Qt.Horizontal, qt.Qt.DisplayRole) + action = qt.QAction("%s displayed" % text, self) + action.setCheckable(True) + action.setChecked(not self.isSectionHidden(column)) + action.toggled.connect(self.__genHideSectionEvent(column)) + menu.addAction(action) + + menu.popup(self.viewport().mapToGlobal(pos)) + + def setSections(self, logicalIndexes): + """ + Defines order of visible sections by logical indexes. + + Use `Hdf5TreeModel.NAME_COLUMN` to set the list. + + :param list logicalIndexes: List of logical indexes to display + """ + for pos, column_id in enumerate(logicalIndexes): + current_pos = self.visualIndex(column_id) + self.moveSection(current_pos, pos) + self.setSectionHidden(column_id, False) + for column_id in set(range(self.model().columnCount())) - set(logicalIndexes): + self.setSectionHidden(column_id, True) diff --git a/silx/gui/hdf5/Hdf5Item.py b/silx/gui/hdf5/Hdf5Item.py new file mode 100644 index 0000000..40793a4 --- /dev/null +++ b/silx/gui/hdf5/Hdf5Item.py @@ -0,0 +1,421 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "20/01/2017" + + +import numpy +import logging +import collections +from .. import qt +from .. import icons +from . import _utils +from .Hdf5Node import Hdf5Node +import silx.io.utils +from silx.gui.data.TextFormatter import TextFormatter + +_logger = logging.getLogger(__name__) + +try: + import h5py +except ImportError as e: + _logger.error("Module %s requires h5py", __name__) + raise e + +_formatter = TextFormatter() + + +class Hdf5Item(Hdf5Node): + """Subclass of :class:`qt.QStandardItem` to represent an HDF5-like + item (dataset, file, group or link) as an element of a HDF5-like + tree structure. + """ + + def __init__(self, text, obj, parent, key=None, h5pyClass=None, isBroken=False, populateAll=False): + """ + :param str text: text displayed + :param object obj: Pointer to h5py data. See the `obj` attribute. + """ + self.__obj = obj + self.__key = key + self.__h5pyClass = h5pyClass + self.__isBroken = isBroken + self.__error = None + self.__text = text + Hdf5Node.__init__(self, parent, populateAll=populateAll) + + @property + def obj(self): + if self.__key: + self.__initH5pyObject() + return self.__obj + + @property + def basename(self): + return self.__text + + @property + def h5pyClass(self): + """Returns the class of the stored object. + + When the object is in lazy loading, this method should be able to + return the type of the futrue loaded object. It allows to delay the + real load of the object. + + :rtype: h5py.File or h5py.Dataset or h5py.Group + """ + if self.__h5pyClass is None: + self.__h5pyClass = silx.io.utils.get_h5py_class(self.obj) + return self.__h5pyClass + + def isGroupObj(self): + """Returns true if the stored HDF5 object is a group (contains sub + groups or datasets). + + :rtype: bool + """ + return issubclass(self.h5pyClass, h5py.Group) + + def isBrokenObj(self): + """Returns true if the stored HDF5 object is broken. + + The stored object is then an h5py link (external or not) which point + to nowhere (tbhe external file is not here, the expected dataset is + still not on the file...) + + :rtype: bool + """ + return self.__isBroken + + def _expectedChildCount(self): + if self.isGroupObj(): + return len(self.obj) + return 0 + + def __initH5pyObject(self): + """Lazy load of the HDF5 node. It is reached from the parent node + with the key of the node.""" + parent_obj = self.parent.obj + + try: + obj = parent_obj.get(self.__key) + except Exception as e: + _logger.debug("Internal h5py error", exc_info=True) + try: + self.__obj = parent_obj.get(self.__key, getlink=True) + except Exception: + self.__obj = None + self.__error = e.args[0] + self.__isBroken = True + else: + if obj is None: + # that's a broken link + self.__obj = parent_obj.get(self.__key, getlink=True) + + # TODO monkey-patch file (ask that in h5py for consistency) + if not hasattr(self.__obj, "name"): + parent_name = parent_obj.name + if parent_name == "/": + self.__obj.name = "/" + self.__key + else: + self.__obj.name = parent_name + "/" + self.__key + # TODO monkey-patch file (ask that in h5py for consistency) + if not hasattr(self.__obj, "file"): + self.__obj.file = parent_obj.file + + if isinstance(self.__obj, h5py.ExternalLink): + message = "External link broken. Path %s::%s does not exist" % (self.__obj.filename, self.__obj.path) + elif isinstance(self.__obj, h5py.SoftLink): + message = "Soft link broken. Path %s does not exist" % (self.__obj.path) + else: + name = self.obj.__class__.__name__.split(".")[-1].capitalize() + message = "%s broken" % (name) + self.__error = message + self.__isBroken = True + else: + self.__obj = obj + + self.__key = None + + def _populateChild(self, populateAll=False): + if self.isGroupObj(): + for name in self.obj: + try: + class_ = self.obj.get(name, getclass=True) + has_error = False + except Exception as e: + _logger.error("Internal h5py error", exc_info=True) + try: + class_ = self.obj.get(name, getclass=True, getlink=True) + except Exception as e: + class_ = h5py.HardLink + has_error = True + item = Hdf5Item(text=name, obj=None, parent=self, key=name, h5pyClass=class_, isBroken=has_error) + self.appendChild(item) + + def hasChildren(self): + """Retuens true of this node have chrild. + + :rtype: bool + """ + if not self.isGroupObj(): + return False + return Hdf5Node.hasChildren(self) + + def _getDefaultIcon(self): + """Returns the icon displayed by the main column. + + :rtype: qt.QIcon + """ + style = qt.QApplication.style() + if self.__isBroken: + icon = style.standardIcon(qt.QStyle.SP_MessageBoxCritical) + return icon + class_ = self.h5pyClass + if issubclass(class_, h5py.File): + return style.standardIcon(qt.QStyle.SP_FileIcon) + elif issubclass(class_, h5py.Group): + return style.standardIcon(qt.QStyle.SP_DirIcon) + elif issubclass(class_, h5py.SoftLink): + return style.standardIcon(qt.QStyle.SP_DirLinkIcon) + elif issubclass(class_, h5py.ExternalLink): + return style.standardIcon(qt.QStyle.SP_FileLinkIcon) + elif issubclass(class_, h5py.Dataset): + if len(self.obj.shape) < 4: + name = "item-%ddim" % len(self.obj.shape) + else: + name = "item-ndim" + if str(self.obj.dtype) == "object": + name = "item-object" + icon = icons.getQIcon(name) + return icon + return None + + def _humanReadableShape(self, dataset): + if dataset.shape == tuple(): + return "scalar" + shape = [str(i) for i in dataset.shape] + text = u" \u00D7 ".join(shape) + return text + + def _humanReadableValue(self, dataset): + if dataset.shape == tuple(): + numpy_object = dataset[()] + text = _formatter.toString(numpy_object) + else: + if dataset.size < 5 and dataset.compression is None: + numpy_object = dataset[0:5] + text = _formatter.toString(numpy_object) + else: + dimension = len(dataset.shape) + if dataset.compression is not None: + text = "Compressed %dD data" % dimension + else: + text = "%dD data" % dimension + return text + + def _humanReadableDType(self, dtype, full=False): + if dtype.type == numpy.string_: + text = "string" + elif dtype.type == numpy.unicode_: + text = "string" + elif dtype.type == numpy.object_: + text = "object" + elif dtype.type == numpy.bool_: + text = "bool" + elif dtype.type == numpy.void: + if dtype.fields is None: + text = "raw" + else: + if not full: + text = "compound" + else: + compound = [d[0] for d in dtype.fields.values()] + compound = [self._humanReadableDType(d) for d in compound] + text = "compound(%s)" % ", ".join(compound) + else: + text = str(dtype) + return text + + def _humanReadableType(self, dataset, full=False): + return self._humanReadableDType(dataset.dtype, full) + + def _setTooltipAttributes(self, attributeDict): + """ + Add key/value attributes that will be displayed in the item tooltip + + :param Dict[str,str] attributeDict: Key/value attributes + """ + if issubclass(self.h5pyClass, h5py.Dataset): + attributeDict["Title"] = "HDF5 Dataset" + attributeDict["Name"] = self.basename + attributeDict["Path"] = self.obj.name + attributeDict["Shape"] = self._humanReadableShape(self.obj) + attributeDict["Value"] = self._humanReadableValue(self.obj) + attributeDict["Data type"] = self._humanReadableType(self.obj, full=True) + elif issubclass(self.h5pyClass, h5py.Group): + attributeDict["Title"] = "HDF5 Group" + attributeDict["Name"] = self.basename + attributeDict["Path"] = self.obj.name + elif issubclass(self.h5pyClass, h5py.File): + attributeDict["Title"] = "HDF5 File" + attributeDict["Name"] = self.basename + attributeDict["Path"] = "/" + elif isinstance(self.obj, h5py.ExternalLink): + attributeDict["Title"] = "HDF5 External Link" + attributeDict["Name"] = self.basename + attributeDict["Path"] = self.obj.name + attributeDict["Linked path"] = self.obj.path + attributeDict["Linked file"] = self.obj.filename + elif isinstance(self.obj, h5py.SoftLink): + attributeDict["Title"] = "HDF5 Soft Link" + attributeDict["Name"] = self.basename + attributeDict["Path"] = self.obj.name + attributeDict["Linked path"] = self.obj.path + else: + pass + + def _getDefaultTooltip(self): + """Returns the default tooltip + + :rtype: str + """ + if self.__error is not None: + self.obj # lazy loading of the object + return self.__error + + attrs = collections.OrderedDict() + self._setTooltipAttributes(attrs) + + title = attrs.pop("Title", None) + if len(attrs) > 0: + tooltip = _utils.htmlFromDict(attrs, title=title) + else: + tooltip = "" + + return tooltip + + def dataName(self, role): + """Data for the name column""" + if role == qt.Qt.TextAlignmentRole: + return qt.Qt.AlignTop | qt.Qt.AlignLeft + if role == qt.Qt.DisplayRole: + return self.__text + if role == qt.Qt.DecorationRole: + return self._getDefaultIcon() + if role == qt.Qt.ToolTipRole: + return self._getDefaultTooltip() + return None + + def dataType(self, role): + """Data for the type column""" + if role == qt.Qt.DecorationRole: + return None + if role == qt.Qt.TextAlignmentRole: + return qt.Qt.AlignTop | qt.Qt.AlignLeft + if role == qt.Qt.DisplayRole: + if self.__error is not None: + return "" + class_ = self.h5pyClass + if issubclass(class_, h5py.Dataset): + text = self._humanReadableType(self.obj) + else: + text = "" + return text + + return None + + def dataShape(self, role): + """Data for the shape column""" + if role == qt.Qt.DecorationRole: + return None + if role == qt.Qt.TextAlignmentRole: + return qt.Qt.AlignTop | qt.Qt.AlignLeft + if role == qt.Qt.DisplayRole: + if self.__error is not None: + return "" + class_ = self.h5pyClass + if not issubclass(class_, h5py.Dataset): + return "" + return self._humanReadableShape(self.obj) + return None + + def dataValue(self, role): + """Data for the value column""" + if role == qt.Qt.DecorationRole: + return None + if role == qt.Qt.TextAlignmentRole: + return qt.Qt.AlignTop | qt.Qt.AlignLeft + if role == qt.Qt.DisplayRole: + if self.__error is not None: + return "" + if not issubclass(self.h5pyClass, h5py.Dataset): + return "" + return self._humanReadableValue(self.obj) + return None + + def dataDescription(self, role): + """Data for the description column""" + if role == qt.Qt.DecorationRole: + return None + if role == qt.Qt.TextAlignmentRole: + return qt.Qt.AlignTop | qt.Qt.AlignLeft + if role == qt.Qt.DisplayRole: + if self.__isBroken: + self.obj # lazy loading of the object + return self.__error + if "desc" in self.obj.attrs: + text = self.obj.attrs["desc"] + else: + return "" + return text + if role == qt.Qt.ToolTipRole: + if self.__error is not None: + self.obj # lazy loading of the object + self.__initH5pyObject() + return self.__error + if "desc" in self.obj.attrs: + text = self.obj.attrs["desc"] + else: + return "" + return "Description: %s" % text + return None + + def dataNode(self, role): + """Data for the node column""" + if role == qt.Qt.DecorationRole: + return None + if role == qt.Qt.TextAlignmentRole: + return qt.Qt.AlignTop | qt.Qt.AlignLeft + if role == qt.Qt.DisplayRole: + class_ = self.h5pyClass + text = class_.__name__.split(".")[-1] + return text + if role == qt.Qt.ToolTipRole: + class_ = self.h5pyClass + return "Class name: %s" % self.__class__ + return None diff --git a/silx/gui/hdf5/Hdf5LoadingItem.py b/silx/gui/hdf5/Hdf5LoadingItem.py new file mode 100644 index 0000000..4467366 --- /dev/null +++ b/silx/gui/hdf5/Hdf5LoadingItem.py @@ -0,0 +1,68 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "23/09/2016" + + +from .. import qt +from .Hdf5Node import Hdf5Node + + +class Hdf5LoadingItem(Hdf5Node): + """Item displayed when an Hdf5Node is loading. + + At the end of the loading this item is replaced by the loaded one. + """ + + def __init__(self, text, parent, animatedIcon): + """Constructor""" + Hdf5Node.__init__(self, parent) + self.__text = text + self.__animatedIcon = animatedIcon + self.__animatedIcon.register(self) + + @property + def obj(self): + return None + + def dataName(self, role): + if role == qt.Qt.DecorationRole: + return self.__animatedIcon.currentIcon() + if role == qt.Qt.TextAlignmentRole: + return qt.Qt.AlignTop | qt.Qt.AlignLeft + if role == qt.Qt.DisplayRole: + return self.__text + return None + + def dataDescription(self, role): + if role == qt.Qt.DecorationRole: + return None + if role == qt.Qt.TextAlignmentRole: + return qt.Qt.AlignTop | qt.Qt.AlignLeft + if role == qt.Qt.DisplayRole: + return "Loading..." + return None diff --git a/silx/gui/hdf5/Hdf5Node.py b/silx/gui/hdf5/Hdf5Node.py new file mode 100644 index 0000000..31bb097 --- /dev/null +++ b/silx/gui/hdf5/Hdf5Node.py @@ -0,0 +1,210 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "23/09/2016" + + +class Hdf5Node(object): + """Abstract tree node + + It provides link to the childs and to the parents, and a link to an + external object. + """ + def __init__(self, parent=None, populateAll=False): + """ + Constructor + + :param Hdf5Node parent: Parent of the node, if exists, else None + :param bool populateAll: If true, populate all the tree node. Else + everything is lazy loaded. + """ + self.__child = None + self.__parent = parent + if populateAll: + self.__child = [] + self._populateChild(populateAll=True) + + @property + def parent(self): + """Parent of the node, or None if the node is a root + + :rtype: Hdf5Node + """ + return self.__parent + + def setParent(self, parent): + """Redefine the parent of the node. + + It does not set the node as the children of the new parent. + + :param Hdf5Node parent: The new parent + """ + self.__parent = parent + + def appendChild(self, child): + """Append a child to the node. + + It does not update the parent of the child. + + :param Hdf5Node child: Child to append to the node. + """ + self.__initChild() + self.__child.append(child) + + def removeChildAtIndex(self, index): + """Remove a child at an index of the children list. + + The child is removed and returned. + + :param int index: Index in the child list. + :rtype: Hdf5Node + :raises: IndexError if list is empty or index is out of range. + """ + self.__initChild() + return self.__child.pop(index) + + def insertChild(self, index, child): + """ + Insert a child at a specific index of the child list. + + It does not update the parent of the child. + + :param int index: Index in the child list. + :param Hdf5Node child: Child to insert in the child list. + """ + self.__initChild() + self.__child.insert(index, child) + + def indexOfChild(self, child): + """ + Returns the index of the child in the child list of this node. + + :param Hdf5Node child: Child to find + :raises: ValueError if the value is not present. + """ + self.__initChild() + return self.__child.index(child) + + def hasChildren(self): + """Returns true if the node contains children. + + :rtype: bool + """ + return self.childCount() > 0 + + def childCount(self): + """Returns the number of child in this node. + + :rtype: int + """ + if self.__child is not None: + return len(self.__child) + return self._expectedChildCount() + + def child(self, index): + """Return the child at an expected index. + + :param int index: Index of the child in the child list of the node + :rtype: Hdf5Node + """ + self.__initChild() + return self.__child[index] + + def __initChild(self): + """Init the child of the node in case the list was lazy loaded.""" + if self.__child is None: + self.__child = [] + self._populateChild() + + def _expectedChildCount(self): + """Returns the expected count of children + + :rtype: int + """ + return 0 + + def _populateChild(self, populateAll=False): + """Recurse through an HDF5 structure to append groups an datasets + into the tree model. + + Overwrite it to implement the initialisation of child of the node. + """ + pass + + def dataName(self, role): + """Data for the name column + + Overwrite it to implement the content of the 'name' column. + + :rtype: qt.QVariant + """ + return None + + def dataType(self, role): + """Data for the type column + + Overwrite it to implement the content of the 'type' column. + + :rtype: qt.QVariant + """ + return None + + def dataShape(self, role): + """Data for the shape column + + Overwrite it to implement the content of the 'shape' column. + + :rtype: qt.QVariant + """ + return None + + def dataValue(self, role): + """Data for the value column + + Overwrite it to implement the content of the 'value' column. + + :rtype: qt.QVariant + """ + return None + + def dataDescription(self, role): + """Data for the description column + + Overwrite it to implement the content of the 'description' column. + + :rtype: qt.QVariant + """ + return None + + def dataNode(self, role): + """Data for the node column + + Overwrite it to implement the content of the 'node' column. + + :rtype: qt.QVariant + """ + return None diff --git a/silx/gui/hdf5/Hdf5TreeModel.py b/silx/gui/hdf5/Hdf5TreeModel.py new file mode 100644 index 0000000..fb5de06 --- /dev/null +++ b/silx/gui/hdf5/Hdf5TreeModel.py @@ -0,0 +1,581 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "19/12/2016" + + +import os +import logging +from .. import qt +from .. import icons +from .Hdf5Node import Hdf5Node +from .Hdf5Item import Hdf5Item +from .Hdf5LoadingItem import Hdf5LoadingItem +from . import _utils +from ... import io as silx_io + +_logger = logging.getLogger(__name__) + +"""Helpers to take care of None objects as signal parameters. +PySide crash if a signal with a None parameter is emitted between threads. +""" +if qt.BINDING == 'PySide': + class _NoneWraper(object): + pass + _NoneWraperInstance = _NoneWraper() + + def _wrapNone(x): + """Wrap x if it is a None value, else returns x""" + if x is None: + return _NoneWraperInstance + else: + return x + + def _unwrapNone(x): + """Unwrap x as a None if a None was stored by `wrapNone`, else returns + x""" + if x is _NoneWraperInstance: + return None + else: + return x +else: + # Allow to fix None event params to avoid PySide crashes + def _wrapNone(x): + return x + + def _unwrapNone(x): + return x + + +class LoadingItemRunnable(qt.QRunnable): + """Runner to process item loading from a file""" + + class __Signals(qt.QObject): + """Signal holder""" + itemReady = qt.Signal(object, object, object) + runnerFinished = qt.Signal(object) + + def __init__(self, filename, item): + """Constructor + + :param LoadingItemWorker worker: Object holding data and signals + """ + super(LoadingItemRunnable, self).__init__() + self.filename = filename + self.oldItem = item + self.signals = self.__Signals() + + def setFile(self, filename, item): + self.filenames.append((filename, item)) + + @property + def itemReady(self): + return self.signals.itemReady + + @property + def runnerFinished(self): + return self.signals.runnerFinished + + def __loadItemTree(self, oldItem, h5obj): + """Create an item tree used by the GUI from an h5py object. + + :param Hdf5Node oldItem: The current item displayed the GUI + :param h5py.File h5obj: The h5py object to display in the GUI + :rtpye: Hdf5Node + """ + if silx_io.is_file(h5obj): + text = os.path.basename(h5obj.filename) + else: + filename = os.path.basename(h5obj.file.filename) + path = h5obj.name + text = "%s::%s" % (filename, path) + item = Hdf5Item(text=text, obj=h5obj, parent=oldItem.parent, populateAll=True) + return item + + @qt.Slot() + def run(self): + """Process the file loading. The worker is used as holder + of the data and the signal. The result is sent as a signal. + """ + try: + h5file = silx_io.open(self.filename) + newItem = self.__loadItemTree(self.oldItem, h5file) + error = None + except IOError as e: + # Should be logged + error = e + newItem = None + + # Take care of None value in case of PySide + newItem = _wrapNone(newItem) + error = _wrapNone(error) + self.itemReady.emit(self.oldItem, newItem, error) + self.runnerFinished.emit(self) + + def autoDelete(self): + return True + + +class Hdf5TreeModel(qt.QAbstractItemModel): + """Tree model storing a list of :class:`h5py.File` like objects. + + The main column display the :class:`h5py.File` list and there hierarchy. + Other columns display information on node hierarchy. + """ + + H5PY_ITEM_ROLE = qt.Qt.UserRole + """Role to reach h5py item from an item index""" + + H5PY_OBJECT_ROLE = qt.Qt.UserRole + 1 + """Role to reach h5py object from an item index""" + + USER_ROLE = qt.Qt.UserRole + 2 + """Start of range of available user role for derivative models""" + + NAME_COLUMN = 0 + """Column id containing HDF5 node names""" + + TYPE_COLUMN = 1 + """Column id containing HDF5 dataset types""" + + SHAPE_COLUMN = 2 + """Column id containing HDF5 dataset shapes""" + + VALUE_COLUMN = 3 + """Column id containing HDF5 dataset values""" + + DESCRIPTION_COLUMN = 4 + """Column id containing HDF5 node description/title/message""" + + NODE_COLUMN = 5 + """Column id containing HDF5 node type""" + + COLUMN_IDS = [ + NAME_COLUMN, + TYPE_COLUMN, + SHAPE_COLUMN, + VALUE_COLUMN, + DESCRIPTION_COLUMN, + NODE_COLUMN, + ] + """List of logical columns available""" + + def __init__(self, parent=None): + super(Hdf5TreeModel, self).__init__(parent) + + self.treeView = parent + self.header_labels = [None] * 6 + self.header_labels[self.NAME_COLUMN] = 'Name' + self.header_labels[self.TYPE_COLUMN] = 'Type' + self.header_labels[self.SHAPE_COLUMN] = 'Shape' + self.header_labels[self.VALUE_COLUMN] = 'Value' + self.header_labels[self.DESCRIPTION_COLUMN] = 'Description' + self.header_labels[self.NODE_COLUMN] = 'Node' + + # Create items + self.__root = Hdf5Node() + self.__fileDropEnabled = True + self.__fileMoveEnabled = True + + self.__animatedIcon = icons.getWaitIcon() + self.__animatedIcon.iconChanged.connect(self.__updateLoadingItems) + self.__runnerSet = set([]) + + # store used icons to avoid to avoid the cache to release it + self.__icons = [] + self.__icons.append(icons.getQIcon("item-0dim")) + self.__icons.append(icons.getQIcon("item-1dim")) + self.__icons.append(icons.getQIcon("item-2dim")) + self.__icons.append(icons.getQIcon("item-3dim")) + self.__icons.append(icons.getQIcon("item-ndim")) + self.__icons.append(icons.getQIcon("item-object")) + + def __updateLoadingItems(self, icon): + for i in range(self.__root.childCount()): + item = self.__root.child(i) + if isinstance(item, Hdf5LoadingItem): + index1 = self.index(i, 0, qt.QModelIndex()) + index2 = self.index(i, self.columnCount() - 1, qt.QModelIndex()) + self.dataChanged.emit(index1, index2) + + def __itemReady(self, oldItem, newItem, error): + """Called at the end of a concurent file loading, when the loading + item is ready. AN error is defined if an exception occured when + loading the newItem . + + :param Hdf5Node oldItem: current displayed item + :param Hdf5Node newItem: item loaded, or None if error is defined + :param Exception error: An exception, or None if newItem is defined + """ + # Take care of None value in case of PySide + newItem = _unwrapNone(newItem) + error = _unwrapNone(error) + row = self.__root.indexOfChild(oldItem) + rootIndex = qt.QModelIndex() + self.beginRemoveRows(rootIndex, row, row) + self.__root.removeChildAtIndex(row) + self.endRemoveRows() + if newItem is not None: + self.beginInsertRows(rootIndex, row, row) + self.__root.insertChild(row, newItem) + self.endInsertRows() + # FIXME the error must be displayed + + def isFileDropEnabled(self): + return self.__fileDropEnabled + + def setFileDropEnabled(self, enabled): + self.__fileDropEnabled = enabled + + fileDropEnabled = qt.Property(bool, isFileDropEnabled, setFileDropEnabled) + """Property to enable/disable file dropping in the model.""" + + def isFileMoveEnabled(self): + return self.__fileMoveEnabled + + def setFileMoveEnabled(self, enabled): + self.__fileMoveEnabled = enabled + + fileMoveEnabled = qt.Property(bool, isFileMoveEnabled, setFileMoveEnabled) + """Property to enable/disable drag-and-drop of files to + change the ordering in the model.""" + + def supportedDropActions(self): + if self.__fileMoveEnabled or self.__fileDropEnabled: + return qt.Qt.CopyAction | qt.Qt.MoveAction + else: + return 0 + + def mimeTypes(self): + if self.__fileMoveEnabled: + return [_utils.Hdf5NodeMimeData.MIME_TYPE] + else: + return [] + + def mimeData(self, indexes): + """ + Returns an object that contains serialized items of data corresponding + to the list of indexes specified. + + :param list(qt.QModelIndex) indexes: List of indexes + :rtype: qt.QMimeData + """ + if not self.__fileMoveEnabled or len(indexes) == 0: + return None + + indexes = [i for i in indexes if i.column() == 0] + if len(indexes) > 1: + raise NotImplementedError("Drag of multi rows is not implemented") + if len(indexes) == 0: + raise NotImplementedError("Drag of cell is not implemented") + + node = self.nodeFromIndex(indexes[0]) + mimeData = _utils.Hdf5NodeMimeData(node) + return mimeData + + def flags(self, index): + defaultFlags = qt.QAbstractItemModel.flags(self, index) + + if index.isValid(): + node = self.nodeFromIndex(index) + if self.__fileMoveEnabled and node.parent is self.__root: + # that's a root + return qt.Qt.ItemIsDragEnabled | defaultFlags + return defaultFlags + elif self.__fileDropEnabled or self.__fileMoveEnabled: + return qt.Qt.ItemIsDropEnabled | defaultFlags + else: + return defaultFlags + + def dropMimeData(self, mimedata, action, row, column, parentIndex): + if action == qt.Qt.IgnoreAction: + return True + + if self.__fileMoveEnabled and mimedata.hasFormat(_utils.Hdf5NodeMimeData.MIME_TYPE): + dragNode = mimedata.node() + parentNode = self.nodeFromIndex(parentIndex) + if parentNode is not dragNode.parent: + return False + + if row == -1: + # append to the parent + row = parentNode.childCount() + else: + # insert at row + pass + + dragNodeParent = dragNode.parent + sourceRow = dragNodeParent.indexOfChild(dragNode) + self.moveRow(parentIndex, sourceRow, parentIndex, row) + return True + + if self.__fileDropEnabled and mimedata.hasFormat("text/uri-list"): + + parentNode = self.nodeFromIndex(parentIndex) + if parentNode is not self.__root: + while(parentNode is not self.__root): + node = parentNode + parentNode = node.parent + row = parentNode.indexOfChild(node) + else: + if row == -1: + row = self.__root.childCount() + + messages = [] + for url in mimedata.urls(): + try: + self.insertFileAsync(url.toLocalFile(), row) + row += 1 + except IOError as e: + messages.append(e.args[0]) + if len(messages) > 0: + title = "Error occurred when loading files" + message = "<html>%s:<ul><li>%s</li><ul></html>" % (title, "</li><li>".join(messages)) + qt.QMessageBox.critical(None, title, message) + return True + + return False + + def headerData(self, section, orientation, role=qt.Qt.DisplayRole): + if orientation == qt.Qt.Horizontal: + if role in [qt.Qt.DisplayRole, qt.Qt.EditRole]: + return self.header_labels[section] + return None + + def insertNode(self, row, node): + if row == -1: + row = self.__root.childCount() + self.beginInsertRows(qt.QModelIndex(), row, row) + self.__root.insertChild(row, node) + self.endInsertRows() + + def moveRow(self, sourceParentIndex, sourceRow, destinationParentIndex, destinationRow): + if sourceRow == destinationRow or sourceRow == destinationRow - 1: + # abort move, same place + return + return self.moveRows(sourceParentIndex, sourceRow, 1, destinationParentIndex, destinationRow) + + def moveRows(self, sourceParentIndex, sourceRow, count, destinationParentIndex, destinationRow): + self.beginMoveRows(sourceParentIndex, sourceRow, sourceRow, destinationParentIndex, destinationRow) + sourceNode = self.nodeFromIndex(sourceParentIndex) + destinationNode = self.nodeFromIndex(destinationParentIndex) + + if sourceNode is destinationNode and sourceRow < destinationRow: + item = sourceNode.child(sourceRow) + destinationNode.insertChild(destinationRow, item) + sourceNode.removeChildAtIndex(sourceRow) + else: + item = sourceNode.removeChildAtIndex(sourceRow) + destinationNode.insertChild(destinationRow, item) + + self.endMoveRows() + return True + + def index(self, row, column, parent=qt.QModelIndex()): + try: + node = self.nodeFromIndex(parent) + return self.createIndex(row, column, node.child(row)) + except IndexError: + return qt.QModelIndex() + + def data(self, index, role=qt.Qt.DisplayRole): + node = self.nodeFromIndex(index) + + if role == self.H5PY_ITEM_ROLE: + return node + + if role == self.H5PY_OBJECT_ROLE: + return node.obj + + if index.column() == self.NAME_COLUMN: + return node.dataName(role) + elif index.column() == self.TYPE_COLUMN: + return node.dataType(role) + elif index.column() == self.SHAPE_COLUMN: + return node.dataShape(role) + elif index.column() == self.VALUE_COLUMN: + return node.dataValue(role) + elif index.column() == self.DESCRIPTION_COLUMN: + return node.dataDescription(role) + elif index.column() == self.NODE_COLUMN: + return node.dataNode(role) + else: + return None + + def columnCount(self, parent=qt.QModelIndex()): + return len(self.header_labels) + + def hasChildren(self, parent=qt.QModelIndex()): + node = self.nodeFromIndex(parent) + if node is None: + return 0 + return node.hasChildren() + + def rowCount(self, parent=qt.QModelIndex()): + node = self.nodeFromIndex(parent) + if node is None: + return 0 + return node.childCount() + + def parent(self, child): + if not child.isValid(): + return qt.QModelIndex() + + node = self.nodeFromIndex(child) + + if node is None: + return qt.QModelIndex() + + parent = node.parent + + if parent is None: + return qt.QModelIndex() + + grandparent = parent.parent + if grandparent is None: + return qt.QModelIndex() + row = grandparent.indexOfChild(parent) + + assert row != - 1 + return self.createIndex(row, 0, parent) + + def nodeFromIndex(self, index): + return index.internalPointer() if index.isValid() else self.__root + + def synchronizeIndex(self, index): + """ + Synchronize a file a given its index. + + Basically close it and load it again. + + :param qt.QModelIndex index: Index of the item to update + """ + node = self.nodeFromIndex(index) + if node.parent is not self.__root: + return + + self.removeIndex(index) + filename = node.obj.filename + node.obj.close() + self.insertFileAsync(filename, index.row()) + + def synchronizeH5pyObject(self, h5pyObject): + """ + Synchronize a h5py object in all the tree. + + Basically close it and load it again. + + :param h5py.File h5pyObject: A :class:`h5py.File` object. + """ + index = 0 + while index < self.__root.childCount(): + item = self.__root.child(index) + if item.obj is h5pyObject: + qindex = self.index(index, 0, qt.QModelIndex()) + self.synchronizeIndex(qindex) + else: + index += 1 + + def removeIndex(self, index): + """ + Remove an item from the model using its index. + + :param qt.QModelIndex index: Index of the item to remove + """ + node = self.nodeFromIndex(index) + if node.parent is not self.__root: + return + self.beginRemoveRows(qt.QModelIndex(), index.row(), index.row()) + self.__root.removeChildAtIndex(index.row()) + self.endRemoveRows() + + def removeH5pyObject(self, h5pyObject): + """ + Remove an item from the model using the holding h5py object. + It can remove more than one item. + + :param h5py.File h5pyObject: A :class:`h5py.File` object. + """ + index = 0 + while index < self.__root.childCount(): + item = self.__root.child(index) + if item.obj is h5pyObject: + qindex = self.index(index, 0, qt.QModelIndex()) + self.removeIndex(qindex) + else: + index += 1 + + def insertH5pyObject(self, h5pyObject, text=None, row=-1): + """Append an HDF5 object from h5py to the tree. + + :param h5pyObject: File handle/descriptor for a :class:`h5py.File` + or any other class of h5py file structure. + """ + if text is None: + if silx_io.is_file(h5pyObject): + text = os.path.basename(h5pyObject.filename) + else: + filename = os.path.basename(h5pyObject.file.filename) + path = h5pyObject.name + text = "%s::%s" % (filename, path) + if row == -1: + row = self.__root.childCount() + self.insertNode(row, Hdf5Item(text=text, obj=h5pyObject, parent=self.__root)) + + def insertFileAsync(self, filename, row=-1): + if not os.path.isfile(filename): + raise IOError("Filename '%s' must be a file path" % filename) + + # create temporary item + text = os.path.basename(filename) + item = Hdf5LoadingItem(text=text, parent=self.__root, animatedIcon=self.__animatedIcon) + self.insertNode(row, item) + + # start loading the real one + runnable = LoadingItemRunnable(filename, item) + runnable.itemReady.connect(self.__itemReady) + self.__runnerSet.add(runnable) + runnable.runnerFinished.connect(self.__releaseRunner) + qt.QThreadPool.globalInstance().start(runnable) + + def __releaseRunner(self, runner): + self.__runnerSet.remove(runner) + + def insertFile(self, filename, row=-1): + """Load a HDF5 file into the data model. + + :param filename: file path. + """ + try: + h5file = silx_io.open(filename) + self.insertH5pyObject(h5file, row=row) + except IOError: + _logger.debug("File '%s' can't be read.", filename, exc_info=True) + raise + + def appendFile(self, filename): + self.insertFile(filename, -1) diff --git a/silx/gui/hdf5/Hdf5TreeView.py b/silx/gui/hdf5/Hdf5TreeView.py new file mode 100644 index 0000000..09f6fcf --- /dev/null +++ b/silx/gui/hdf5/Hdf5TreeView.py @@ -0,0 +1,204 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "27/09/2016" + + +import logging +from .. import qt +from ...utils import weakref as silxweakref +from .Hdf5TreeModel import Hdf5TreeModel +from .Hdf5HeaderView import Hdf5HeaderView +from .NexusSortFilterProxyModel import NexusSortFilterProxyModel +from .Hdf5Item import Hdf5Item +from . import _utils + +_logger = logging.getLogger(__name__) + + +class Hdf5TreeView(qt.QTreeView): + """TreeView which allow to browse HDF5 file structure. + + It provides columns width auto-resizing and additional + signals. + + The default model is a :class:`NexusSortFilterProxyModel` sourcing + a :class:`Hdf5TreeModel`. The :class:`Hdf5TreeModel` is reachable using + :meth:`findHdf5TreeModel`. The default header is :class:`Hdf5HeaderView`. + + Context menu is managed by the :meth:`setContextMenuPolicy` with the value + Qt.CustomContextMenu. This policy must not be changed, otherwise context + menus will not work anymore. You can use :meth:`addContextMenuCallback` and + :meth:`removeContextMenuCallback` to add your custum actions according + to the selected objects. + """ + def __init__(self, parent=None): + """ + Constructor + + :param parent qt.QWidget: The parent widget + """ + qt.QTreeView.__init__(self, parent) + + model = Hdf5TreeModel(self) + proxy_model = NexusSortFilterProxyModel(self) + proxy_model.setSourceModel(model) + self.setModel(proxy_model) + + self.setHeader(Hdf5HeaderView(qt.Qt.Horizontal, self)) + self.setSelectionBehavior(qt.QAbstractItemView.SelectRows) + self.sortByColumn(0, qt.Qt.AscendingOrder) + # optimise the rendering + self.setUniformRowHeights(True) + + self.setIconSize(qt.QSize(16, 16)) + self.setAcceptDrops(True) + self.setDragEnabled(True) + self.setDragDropMode(qt.QAbstractItemView.DragDrop) + self.showDropIndicator() + + self.__context_menu_callbacks = silxweakref.WeakList() + self.setContextMenuPolicy(qt.Qt.CustomContextMenu) + self.customContextMenuRequested.connect(self._createContextMenu) + + def __removeContextMenuProxies(self, ref): + """Callback to remove dead proxy from the list""" + self.__context_menu_callbacks.remove(ref) + + def _createContextMenu(self, pos): + """ + Create context menu. + + :param pos qt.QPoint: Position of the context menu + """ + actions = [] + + menu = qt.QMenu(self) + + hovered_index = self.indexAt(pos) + hovered_node = self.model().data(hovered_index, Hdf5TreeModel.H5PY_ITEM_ROLE) + if hovered_node is None or not isinstance(hovered_node, Hdf5Item): + return + + hovered_object = _utils.H5Node(hovered_node) + event = _utils.Hdf5ContextMenuEvent(self, menu, hovered_object) + + for callback in self.__context_menu_callbacks: + try: + callback(event) + except KeyboardInterrupt: + raise + except: + # make sure no user callback crash the application + _logger.error("Error while calling callback", exc_info=True) + pass + + if len(menu.children()) > 0: + for action in actions: + menu.addAction(action) + menu.popup(self.viewport().mapToGlobal(pos)) + + def addContextMenuCallback(self, callback): + """Register a context menu callback. + + The callback will be called when a context menu is requested with the + treeview and the list of selected h5py objects in parameters. The + callback must return a list of :class:`qt.QAction` object. + + Callbacks are stored as saferef. The object must store a reference by + itself. + """ + self.__context_menu_callbacks.append(callback) + + def removeContextMenuCallback(self, callback): + """Unregister a context menu callback""" + self.__context_menu_callbacks.remove(callback) + + def findHdf5TreeModel(self): + """Find the Hdf5TreeModel from the stack of model filters. + + :returns: A Hdf5TreeModel, else None + :rtype: Hdf5TreeModel + """ + model = self.model() + while model is not None: + if isinstance(model, qt.QAbstractProxyModel): + model = model.sourceModel() + else: + break + if model is None: + return None + if isinstance(model, Hdf5TreeModel): + return model + else: + return None + + def dragEnterEvent(self, event): + model = self.findHdf5TreeModel() + if model is not None and model.isFileDropEnabled() and event.mimeData().hasFormat("text/uri-list"): + self.setState(qt.QAbstractItemView.DraggingState) + event.accept() + else: + qt.QTreeView.dragEnterEvent(self, event) + + def dragMoveEvent(self, event): + model = self.findHdf5TreeModel() + if model is not None and model.isFileDropEnabled() and event.mimeData().hasFormat("text/uri-list"): + event.setDropAction(qt.Qt.CopyAction) + event.accept() + else: + qt.QTreeView.dragMoveEvent(self, event) + + def selectedH5Nodes(self, ignoreBrokenLinks=True): + """Returns selected h5py objects like :class:`h5py.File`, + :class:`h5py.Group`, :class:`h5py.Dataset` or mimicked objects. + + :param ignoreBrokenLinks bool: Returns objects which are not not + broken links. + :rtype: iterator(:class:`_utils.H5Node`) + """ + for index in self.selectedIndexes(): + if index.column() != 0: + continue + item = self.model().data(index, Hdf5TreeModel.H5PY_ITEM_ROLE) + if item is None: + continue + if isinstance(item, Hdf5Item): + if ignoreBrokenLinks and item.isBrokenObj(): + continue + yield _utils.H5Node(item) + + def mousePressEvent(self, event): + """Override mousePressEvent to provide a consistante compatible API + between Qt4 and Qt5 + """ + super(Hdf5TreeView, self).mousePressEvent(event) + if event.button() != qt.Qt.LeftButton: + # Qt5 only sends itemClicked on left button mouse click + if qt.qVersion() > "5": + qindex = self.indexAt(event.pos()) + self.clicked.emit(qindex) diff --git a/silx/gui/hdf5/NexusSortFilterProxyModel.py b/silx/gui/hdf5/NexusSortFilterProxyModel.py new file mode 100644 index 0000000..9a4268c --- /dev/null +++ b/silx/gui/hdf5/NexusSortFilterProxyModel.py @@ -0,0 +1,152 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "12/04/2017" + + +import logging +import re +import numpy +from .. import qt +from .Hdf5TreeModel import Hdf5TreeModel + +_logger = logging.getLogger(__name__) + +try: + import h5py +except ImportError as e: + _logger.error("Module %s requires h5py", __name__) + raise e + +_logger = logging.getLogger(__name__) + + +class NexusSortFilterProxyModel(qt.QSortFilterProxyModel): + """Try to sort items according to Nexus structure. Else sort by name.""" + + def __init__(self, parent=None): + qt.QSortFilterProxyModel.__init__(self, parent) + self.__split = re.compile("(\\d+|\\D+)") + + def lessThan(self, sourceLeft, sourceRight): + """Returns True if the value of the item referred to by the given + index `sourceLeft` is less than the value of the item referred to by + the given index `sourceRight`, otherwise returns false. + + :param qt.QModelIndex sourceLeft: + :param qt.QModelIndex sourceRight: + :rtype: bool + """ + if sourceLeft.column() != Hdf5TreeModel.NAME_COLUMN: + return super(NexusSortFilterProxyModel, self).lessThan( + sourceLeft, sourceRight) + + # Do not sort child of root (files) + if sourceLeft.parent() == qt.QModelIndex(): + return sourceLeft.row() < sourceRight.row() + + left = self.sourceModel().data(sourceLeft, Hdf5TreeModel.H5PY_ITEM_ROLE) + right = self.sourceModel().data(sourceRight, Hdf5TreeModel.H5PY_ITEM_ROLE) + + if self.__isNXentry(left) and self.__isNXentry(right): + less = self.childDatasetLessThan(left, right, "start_time") + if less is not None: + return less + less = self.childDatasetLessThan(left, right, "end_time") + if less is not None: + return less + + left = self.sourceModel().data(sourceLeft, qt.Qt.DisplayRole) + right = self.sourceModel().data(sourceRight, qt.Qt.DisplayRole) + return self.nameLessThan(left, right) + + def __isNXentry(self, node): + """Returns true if the node is an NXentry""" + if not issubclass(node.h5pyClass, h5py.Group): + return False + nxClass = node.obj.attrs.get("NX_class", None) + return nxClass == "NXentry" + + def getWordsAndNumbers(self, name): + """ + Returns a list of words and integers composing the name. + + An input `"aaa10bbb50.30"` will return + `["aaa", 10, "bbb", 50, ".", 30]`. + + :param str name: A name + :rtype: list + """ + words = self.__split.findall(name) + result = [] + for i in words: + if i[0].isdigit(): + i = int(i) + result.append(i) + return result + + def nameLessThan(self, left, right): + """Returns True if the left string is less than the right string. + + Number composing the names are compared as integers, as result "name2" + is smaller than "name10". + + :param str left: A string + :param str right: A string + :rtype: bool + """ + leftList = self.getWordsAndNumbers(left) + rightList = self.getWordsAndNumbers(right) + try: + return leftList < rightList + except TypeError: + # Back to string comparison if list are not type consistent + return left < right + + def childDatasetLessThan(self, left, right, childName): + """ + Reach the same children name of two items and compare their values. + + Returns True if the left one is smaller than the right one. + + :param Hdf5Item left: An item + :param Hdf5Item right: An item + :param str childName: Name of the children to search. Returns None if + the children is not found. + :rtype: bool + """ + try: + left_time = left.obj[childName][()] + right_time = right.obj[childName][()] + if isinstance(left_time, numpy.ndarray): + return left_time[0] < right_time[0] + return left_time < right_time + except KeyboardInterrupt: + raise + except Exception as e: + _logger.debug("Exception occurred", exc_info=True) + return None diff --git a/silx/gui/hdf5/__init__.py b/silx/gui/hdf5/__init__.py new file mode 100644 index 0000000..1b5a602 --- /dev/null +++ b/silx/gui/hdf5/__init__.py @@ -0,0 +1,44 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This package provides a set of Qt widgets for displaying content relative to +HDF5 format. + +.. note:: + + This package depends on *h5py*. +""" + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "23/09/2016" + + +from .Hdf5TreeView import Hdf5TreeView # noqa +from ._utils import H5Node +from ._utils import Hdf5ContextMenuEvent # noqa +from .NexusSortFilterProxyModel import NexusSortFilterProxyModel # noqa +from .Hdf5TreeModel import Hdf5TreeModel # noqa + +__all__ = ['Hdf5TreeView', 'H5Node', 'Hdf5ContextMenuEvent', 'NexusSortFilterProxyModel', 'Hdf5TreeModel'] diff --git a/silx/gui/hdf5/_utils.py b/silx/gui/hdf5/_utils.py new file mode 100644 index 0000000..af9c79f --- /dev/null +++ b/silx/gui/hdf5/_utils.py @@ -0,0 +1,247 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This package provides a set of helper class and function used by the +package `silx.gui.hdf5` package. +""" + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "26/04/2017" + + +import logging +import numpy +from .. import qt +import silx.io.utils +from silx.utils.html import escape + +_logger = logging.getLogger(__name__) + +try: + import h5py +except ImportError as e: + _logger.error("Module %s requires h5py", __name__) + raise e + + +class Hdf5ContextMenuEvent(object): + """Hold information provided to context menu callbacks.""" + + def __init__(self, source, menu, hoveredObject): + """ + Constructor + + :param QWidget source: Widget source + :param QMenu menu: Context menu which will be displayed + :param H5Node hoveredObject: Hovered H5 node + """ + self.__source = source + self.__menu = menu + self.__hoveredObject = hoveredObject + + def source(self): + """Source of the event + + :rtype: Hdf5TreeView + """ + return self.__source + + def menu(self): + """Menu which will be displayed + + :rtype: qt.QMenu + """ + return self.__menu + + def hoveredObject(self): + """Item content hovered by the mouse when the context menu was + requested + + :rtype: H5Node + """ + return self.__hoveredObject + + +def htmlFromDict(dictionary, title=None): + """Generate a readable HTML from a dictionary + + :param dict dictionary: A Dictionary + :rtype: str + """ + result = """<html> + <head> + <style type="text/css"> + ul { -qt-list-indent: 0; list-style: none; } + li > b {display: inline-block; min-width: 4em; font-weight: bold; } + </style> + </head> + <body> + """ + if title is not None: + result += "<b>%s</b>" % escape(title) + result += "<ul>" + for key, value in dictionary.items(): + result += "<li><b>%s</b>: %s</li>" % (escape(key), escape(value)) + result += "</ul>" + result += "</body></html>" + return result + + +class Hdf5NodeMimeData(qt.QMimeData): + """Mimedata class to identify an internal drag and drop of a Hdf5Node.""" + + MIME_TYPE = "application/x-internal-h5py-node" + + def __init__(self, node=None): + qt.QMimeData.__init__(self) + self.__node = node + self.setData(self.MIME_TYPE, "".encode(encoding='utf-8')) + + def node(self): + return self.__node + + +class H5Node(object): + """Adapter over an h5py object to provide missing informations from h5py + nodes, like internal node path and filename (which are not provided by + :mod:`h5py` for soft and external links). + + It also provides an abstraction to reach node type for mimicked h5py + objects. + """ + + def __init__(self, h5py_item=None): + """Constructor + + :param Hdf5Item h5py_item: An Hdf5Item + """ + self.__h5py_object = h5py_item.obj + self.__h5py_item = h5py_item + + def __getattr__(self, name): + return object.__getattribute__(self.__h5py_object, name) + + @property + def h5py_object(self): + """Returns the internal h5py node. + + :rtype: h5py.File or h5py.Group or h5py.Dataset + """ + return self.__h5py_object + + @property + def ntype(self): + """Returns the node type, as an h5py class. + + :rtype: + :class:`h5py.File`, :class:`h5py.Group` or :class:`h5py.Dataset` + """ + return silx.io.utils.get_h5py_class(self.__h5py_object) + + @property + def basename(self): + """Returns the basename of this h5py node. It is the last identifier of + the path. + + :rtype: str + """ + return self.__h5py_object.name.split("/")[-1] + + @property + def local_name(self): + """Returns the local path of this h5py node. + + For links, this path is not equal to the h5py one. + + :rtype: str + """ + if self.__h5py_item is None: + raise RuntimeError("h5py_item is not defined") + + result = [] + item = self.__h5py_item + while item is not None: + if issubclass(item.h5pyClass, h5py.File): + break + result.append(item.basename) + item = item.parent + if item is None: + raise RuntimeError("The item does not have parent holding h5py.File") + if result == []: + return "/" + result.append("") + result.reverse() + return "/".join(result) + + def __file_item(self): + """Returns the parent item holding the :class:`h5py.File` object + + :rtype: h5py.File + :raises RuntimeException: If no file are found + """ + item = self.__h5py_item + while item is not None: + if issubclass(item.h5pyClass, h5py.File): + return item + item = item.parent + raise RuntimeError("The item does not have parent holding h5py.File") + + @property + def local_file(self): + """Returns the local :class:`h5py.File` object. + + For path containing external links, this file is not equal to the h5py + one. + + :rtype: h5py.File + :raises RuntimeException: If no file are found + """ + item = self.__file_item() + return item.obj + + @property + def local_filename(self): + """Returns the local filename of the h5py node. + + For path containing external links, this path is not equal to the + filename provided by h5py. + + :rtype: str + :raises RuntimeException: If no file are found + """ + return self.local_file.filename + + @property + def local_basename(self): + """Returns the local filename of the h5py node. + + For path containing links, this basename can be different than the + basename provided by h5py. + + :rtype: str + """ + if issubclass(self.__h5py_item.h5pyClass, h5py.File): + return "" + return self.__h5py_item.basename diff --git a/silx/gui/hdf5/setup.py b/silx/gui/hdf5/setup.py new file mode 100644 index 0000000..786a851 --- /dev/null +++ b/silx/gui/hdf5/setup.py @@ -0,0 +1,41 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "28/09/2016" + + +from numpy.distutils.misc_util import Configuration + + +def configuration(parent_package='', top_path=None): + config = Configuration('hdf5', parent_package, top_path) + config.add_subpackage('test') + return config + + +if __name__ == "__main__": + from numpy.distutils.core import setup + setup(configuration=configuration) diff --git a/silx/gui/hdf5/test/__init__.py b/silx/gui/hdf5/test/__init__.py new file mode 100644 index 0000000..3000d96 --- /dev/null +++ b/silx/gui/hdf5/test/__init__.py @@ -0,0 +1,39 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +import unittest + +from . import test_hdf5 + + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "28/09/2016" + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTests( + [test_hdf5.suite()]) + return test_suite diff --git a/silx/gui/hdf5/test/_mock.py b/silx/gui/hdf5/test/_mock.py new file mode 100644 index 0000000..eada590 --- /dev/null +++ b/silx/gui/hdf5/test/_mock.py @@ -0,0 +1,130 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Mock for silx.gui.hdf5 module""" + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "12/04/2017" + + +import numpy +try: + import h5py +except ImportError: + h5py = None + + +class Node(object): + + def __init__(self, basename, parent, h5py_class): + self.basename = basename + self.h5py_class = h5py_class + self.attrs = {} + self.parent = parent + if parent is not None: + self.parent._add(self) + + @property + def name(self): + if self.parent is None: + return self.basename + if self.parent.name == "": + return self.basename + return self.parent.name + "/" + self.basename + + @property + def file(self): + if self.parent is None: + return self + return self.parent.file + + +class Group(Node): + """Mock an h5py Group""" + + def __init__(self, name, parent, h5py_class=h5py.Group): + super(Group, self).__init__(name, parent, h5py_class) + self.__items = {} + + def _add(self, node): + self.__items[node.basename] = node + + def __getitem__(self, key): + return self.__items[key] + + def __iter__(self): + for k in self.__items: + yield k + + def __len__(self): + return len(self.__items) + + def get(self, name, getclass=False, getlink=False): + result = self.__items[name] + if getclass: + return result.h5py_class + return result + + def create_dataset(self, name, data): + return Dataset(name, self, data) + + def create_group(self, name): + return Group(name, self) + + def create_NXentry(self, name): + group = Group(name, self) + group.attrs["NX_class"] = "NXentry" + return group + + +class File(Group): + """Mock an h5py File""" + + def __init__(self, filename): + super(File, self).__init__("", None, h5py.File) + self.filename = filename + + +class Dataset(Node): + """Mock an h5py Dataset""" + + def __init__(self, name, parent, value): + super(Dataset, self).__init__(name, parent, h5py.Dataset) + self.__value = value + self.shape = self.__value.shape + self.dtype = self.__value.dtype + self.size = self.__value.size + self.compression = None + self.compression_opts = None + + def __getitem__(self, key): + if not isinstance(self.__value, numpy.ndarray): + if key == tuple(): + return self.__value + elif key == Ellipsis: + return numpy.array(self.__value) + else: + raise ValueError("Bad key") + return self.__value[key] diff --git a/silx/gui/hdf5/test/test_hdf5.py b/silx/gui/hdf5/test/test_hdf5.py new file mode 100644 index 0000000..3bf4897 --- /dev/null +++ b/silx/gui/hdf5/test/test_hdf5.py @@ -0,0 +1,480 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Test for silx.gui.hdf5 module""" + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "12/04/2017" + + +import time +import os +import unittest +import tempfile +import numpy +from contextlib import contextmanager +from silx.gui import qt +from silx.gui.test.utils import TestCaseQt +from silx.gui import hdf5 +from . import _mock + +try: + import h5py +except ImportError: + h5py = None + + +_called = 0 + + +class _Holder(object): + def callback(self, *args, **kvargs): + _called += 1 + + +class TestHdf5TreeModel(TestCaseQt): + + def setUp(self): + super(TestHdf5TreeModel, self).setUp() + if h5py is None: + self.skipTest("h5py is not available") + + @contextmanager + def h5TempFile(self): + # create tmp file + fd, tmp_name = tempfile.mkstemp(suffix=".h5") + os.close(fd) + # create h5 data + h5file = h5py.File(tmp_name, "w") + g = h5file.create_group("arrays") + g.create_dataset("scalar", data=10) + h5file.close() + yield tmp_name + # clean up + os.unlink(tmp_name) + + def testCreate(self): + model = hdf5.Hdf5TreeModel() + self.assertIsNotNone(model) + + def testAppendFilename(self): + with self.h5TempFile() as filename: + model = hdf5.Hdf5TreeModel() + self.assertEquals(model.rowCount(qt.QModelIndex()), 0) + model.appendFile(filename) + self.assertEquals(model.rowCount(qt.QModelIndex()), 1) + # clean up + index = model.index(0, 0, qt.QModelIndex()) + h5File = model.data(index, hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE) + h5File.close() + + def testAppendBadFilename(self): + model = hdf5.Hdf5TreeModel() + self.assertRaises(IOError, model.appendFile, "#%$") + + def testInsertFilename(self): + with self.h5TempFile() as filename: + model = hdf5.Hdf5TreeModel() + self.assertEquals(model.rowCount(qt.QModelIndex()), 0) + model.insertFile(filename) + self.assertEquals(model.rowCount(qt.QModelIndex()), 1) + # clean up + index = model.index(0, 0, qt.QModelIndex()) + h5File = model.data(index, hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE) + h5File.close() + + def testInsertFilenameAsync(self): + with self.h5TempFile() as filename: + model = hdf5.Hdf5TreeModel() + self.assertEquals(model.rowCount(qt.QModelIndex()), 0) + model.insertFileAsync(filename) + index = model.index(0, 0, qt.QModelIndex()) + self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5LoadingItem.Hdf5LoadingItem) + time.sleep(0.1) + self.qapp.processEvents() + time.sleep(0.1) + index = model.index(0, 0, qt.QModelIndex()) + self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5Item.Hdf5Item) + self.assertEquals(model.rowCount(qt.QModelIndex()), 1) + # clean up + index = model.index(0, 0, qt.QModelIndex()) + h5File = model.data(index, hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE) + h5File.close() + + def testInsertObject(self): + h5 = _mock.File("/foo/bar/1.mock") + model = hdf5.Hdf5TreeModel() + self.assertEquals(model.rowCount(qt.QModelIndex()), 0) + model.insertH5pyObject(h5) + self.assertEquals(model.rowCount(qt.QModelIndex()), 1) + + def testRemoveObject(self): + h5 = _mock.File("/foo/bar/1.mock") + model = hdf5.Hdf5TreeModel() + self.assertEquals(model.rowCount(qt.QModelIndex()), 0) + model.insertH5pyObject(h5) + self.assertEquals(model.rowCount(qt.QModelIndex()), 1) + model.removeH5pyObject(h5) + self.assertEquals(model.rowCount(qt.QModelIndex()), 0) + + def testSynchronizeObject(self): + with self.h5TempFile() as filename: + h5 = h5py.File(filename) + model = hdf5.Hdf5TreeModel() + model.insertH5pyObject(h5) + self.assertEquals(model.rowCount(qt.QModelIndex()), 1) + index = model.index(0, 0, qt.QModelIndex()) + node1 = model.nodeFromIndex(index) + model.synchronizeH5pyObject(h5) + index = model.index(0, 0, qt.QModelIndex()) + node2 = model.nodeFromIndex(index) + self.assertIsNot(node1, node2) + # after sync + time.sleep(0.1) + self.qapp.processEvents() + time.sleep(0.1) + index = model.index(0, 0, qt.QModelIndex()) + self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5Item.Hdf5Item) + # clean up + index = model.index(0, 0, qt.QModelIndex()) + h5File = model.data(index, hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE) + h5File.close() + + def testFileMoveState(self): + model = hdf5.Hdf5TreeModel() + self.assertEquals(model.isFileMoveEnabled(), True) + model.setFileMoveEnabled(False) + self.assertEquals(model.isFileMoveEnabled(), False) + + def testFileDropState(self): + model = hdf5.Hdf5TreeModel() + self.assertEquals(model.isFileDropEnabled(), True) + model.setFileDropEnabled(False) + self.assertEquals(model.isFileDropEnabled(), False) + + def testSupportedDrop(self): + model = hdf5.Hdf5TreeModel() + self.assertNotEquals(model.supportedDropActions(), 0) + + model.setFileMoveEnabled(False) + model.setFileDropEnabled(False) + self.assertEquals(model.supportedDropActions(), 0) + + model.setFileMoveEnabled(False) + model.setFileDropEnabled(True) + self.assertNotEquals(model.supportedDropActions(), 0) + + model.setFileMoveEnabled(True) + model.setFileDropEnabled(False) + self.assertNotEquals(model.supportedDropActions(), 0) + + def testDropExternalFile(self): + with self.h5TempFile() as filename: + model = hdf5.Hdf5TreeModel() + mimeData = qt.QMimeData() + mimeData.setUrls([qt.QUrl.fromLocalFile(filename)]) + model.dropMimeData(mimeData, qt.Qt.CopyAction, 0, 0, qt.QModelIndex()) + self.assertEquals(model.rowCount(qt.QModelIndex()), 1) + # after sync + time.sleep(0.1) + self.qapp.processEvents() + time.sleep(0.1) + index = model.index(0, 0, qt.QModelIndex()) + self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5Item.Hdf5Item) + # clean up + index = model.index(0, 0, qt.QModelIndex()) + h5File = model.data(index, role=hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE) + h5File.close() + + def getRowDataAsDict(self, model, row): + displayed = {} + roles = [qt.Qt.DisplayRole, qt.Qt.DecorationRole, qt.Qt.ToolTipRole, qt.Qt.TextAlignmentRole] + for column in range(0, model.columnCount(qt.QModelIndex())): + index = model.index(0, column, qt.QModelIndex()) + for role in roles: + datum = model.data(index, role) + displayed[column, role] = datum + return displayed + + def getItemName(self, model, row): + index = model.index(row, hdf5.Hdf5TreeModel.NAME_COLUMN, qt.QModelIndex()) + return model.data(index, qt.Qt.DisplayRole) + + def testFileData(self): + h5 = _mock.File("/foo/bar/1.mock") + model = hdf5.Hdf5TreeModel() + model.insertH5pyObject(h5) + displayed = self.getRowDataAsDict(model, row=0) + self.assertEquals(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DisplayRole], "1.mock") + self.assertIsInstance(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DecorationRole], qt.QIcon) + self.assertEquals(displayed[hdf5.Hdf5TreeModel.TYPE_COLUMN, qt.Qt.DisplayRole], "") + self.assertEquals(displayed[hdf5.Hdf5TreeModel.SHAPE_COLUMN, qt.Qt.DisplayRole], "") + self.assertEquals(displayed[hdf5.Hdf5TreeModel.VALUE_COLUMN, qt.Qt.DisplayRole], "") + self.assertEquals(displayed[hdf5.Hdf5TreeModel.DESCRIPTION_COLUMN, qt.Qt.DisplayRole], "") + self.assertEquals(displayed[hdf5.Hdf5TreeModel.NODE_COLUMN, qt.Qt.DisplayRole], "File") + + def testGroupData(self): + h5 = _mock.File("/foo/bar/1.mock") + d = h5.create_group("foo") + d.attrs["desc"] = "fooo" + + model = hdf5.Hdf5TreeModel() + model.insertH5pyObject(d) + displayed = self.getRowDataAsDict(model, row=0) + self.assertEquals(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DisplayRole], "1.mock::foo") + self.assertIsInstance(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DecorationRole], qt.QIcon) + self.assertEquals(displayed[hdf5.Hdf5TreeModel.TYPE_COLUMN, qt.Qt.DisplayRole], "") + self.assertEquals(displayed[hdf5.Hdf5TreeModel.SHAPE_COLUMN, qt.Qt.DisplayRole], "") + self.assertEquals(displayed[hdf5.Hdf5TreeModel.VALUE_COLUMN, qt.Qt.DisplayRole], "") + self.assertEquals(displayed[hdf5.Hdf5TreeModel.DESCRIPTION_COLUMN, qt.Qt.DisplayRole], "fooo") + self.assertEquals(displayed[hdf5.Hdf5TreeModel.NODE_COLUMN, qt.Qt.DisplayRole], "Group") + + def testDatasetData(self): + h5 = _mock.File("/foo/bar/1.mock") + value = numpy.array([1, 2, 3]) + d = h5.create_dataset("foo", value) + + model = hdf5.Hdf5TreeModel() + model.insertH5pyObject(d) + displayed = self.getRowDataAsDict(model, row=0) + self.assertEquals(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DisplayRole], "1.mock::foo") + self.assertIsInstance(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DecorationRole], qt.QIcon) + self.assertEquals(displayed[hdf5.Hdf5TreeModel.TYPE_COLUMN, qt.Qt.DisplayRole], value.dtype.name) + self.assertEquals(displayed[hdf5.Hdf5TreeModel.SHAPE_COLUMN, qt.Qt.DisplayRole], "3") + self.assertEquals(displayed[hdf5.Hdf5TreeModel.VALUE_COLUMN, qt.Qt.DisplayRole], "[1 2 3]") + self.assertEquals(displayed[hdf5.Hdf5TreeModel.DESCRIPTION_COLUMN, qt.Qt.DisplayRole], "") + self.assertEquals(displayed[hdf5.Hdf5TreeModel.NODE_COLUMN, qt.Qt.DisplayRole], "Dataset") + + def testDropLastAsFirst(self): + model = hdf5.Hdf5TreeModel() + h5_1 = _mock.File("/foo/bar/1.mock") + h5_2 = _mock.File("/foo/bar/2.mock") + model.insertH5pyObject(h5_1) + model.insertH5pyObject(h5_2) + self.assertEquals(self.getItemName(model, 0), "1.mock") + self.assertEquals(self.getItemName(model, 1), "2.mock") + index = model.index(1, 0, qt.QModelIndex()) + mimeData = model.mimeData([index]) + model.dropMimeData(mimeData, qt.Qt.MoveAction, 0, 0, qt.QModelIndex()) + self.assertEquals(self.getItemName(model, 0), "2.mock") + self.assertEquals(self.getItemName(model, 1), "1.mock") + + def testDropFirstAsLast(self): + model = hdf5.Hdf5TreeModel() + h5_1 = _mock.File("/foo/bar/1.mock") + h5_2 = _mock.File("/foo/bar/2.mock") + model.insertH5pyObject(h5_1) + model.insertH5pyObject(h5_2) + self.assertEquals(self.getItemName(model, 0), "1.mock") + self.assertEquals(self.getItemName(model, 1), "2.mock") + index = model.index(0, 0, qt.QModelIndex()) + mimeData = model.mimeData([index]) + model.dropMimeData(mimeData, qt.Qt.MoveAction, 2, 0, qt.QModelIndex()) + self.assertEquals(self.getItemName(model, 0), "2.mock") + self.assertEquals(self.getItemName(model, 1), "1.mock") + + def testRootParent(self): + model = hdf5.Hdf5TreeModel() + h5_1 = _mock.File("/foo/bar/1.mock") + model.insertH5pyObject(h5_1) + index = model.index(0, 0, qt.QModelIndex()) + index = model.parent(index) + self.assertEquals(index, qt.QModelIndex()) + + +class TestNexusSortFilterProxyModel(TestCaseQt): + + def getChildNames(self, model, index): + count = model.rowCount(index) + result = [] + for row in range(0, count): + itemIndex = model.index(row, hdf5.Hdf5TreeModel.NAME_COLUMN, index) + name = model.data(itemIndex, qt.Qt.DisplayRole) + result.append(name) + return result + + def testNXentryStartTime(self): + """Test NXentry with start_time""" + model = hdf5.Hdf5TreeModel() + h5 = _mock.File("/foo/bar/1.mock") + h5.create_NXentry("a").create_dataset("start_time", numpy.string_("2015")) + h5.create_NXentry("b").create_dataset("start_time", numpy.string_("2013")) + h5.create_NXentry("c").create_dataset("start_time", numpy.string_("2014")) + model.insertH5pyObject(h5) + + proxy = hdf5.NexusSortFilterProxyModel() + proxy.setSourceModel(model) + proxy.sort(0, qt.Qt.DescendingOrder) + names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex())) + self.assertListEqual(names, ["a", "c", "b"]) + + def testNXentryStartTimeInArray(self): + """Test NXentry with start_time""" + model = hdf5.Hdf5TreeModel() + h5 = _mock.File("/foo/bar/1.mock") + h5.create_NXentry("a").create_dataset("start_time", numpy.array([numpy.string_("2015")])) + h5.create_NXentry("b").create_dataset("start_time", numpy.array([numpy.string_("2013")])) + h5.create_NXentry("c").create_dataset("start_time", numpy.array([numpy.string_("2014")])) + model.insertH5pyObject(h5) + + proxy = hdf5.NexusSortFilterProxyModel() + proxy.setSourceModel(model) + proxy.sort(0, qt.Qt.DescendingOrder) + names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex())) + self.assertListEqual(names, ["a", "c", "b"]) + + def testNXentryEndTimeInArray(self): + """Test NXentry with end_time""" + model = hdf5.Hdf5TreeModel() + h5 = _mock.File("/foo/bar/1.mock") + h5.create_NXentry("a").create_dataset("end_time", numpy.array([numpy.string_("2015")])) + h5.create_NXentry("b").create_dataset("end_time", numpy.array([numpy.string_("2013")])) + h5.create_NXentry("c").create_dataset("end_time", numpy.array([numpy.string_("2014")])) + model.insertH5pyObject(h5) + + proxy = hdf5.NexusSortFilterProxyModel() + proxy.setSourceModel(model) + proxy.sort(0, qt.Qt.DescendingOrder) + names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex())) + self.assertListEqual(names, ["a", "c", "b"]) + + def testNXentryName(self): + """Test NXentry without start_time or end_time""" + model = hdf5.Hdf5TreeModel() + h5 = _mock.File("/foo/bar/1.mock") + h5.create_NXentry("a") + h5.create_NXentry("c") + h5.create_NXentry("b") + model.insertH5pyObject(h5) + + proxy = hdf5.NexusSortFilterProxyModel() + proxy.setSourceModel(model) + proxy.sort(0, qt.Qt.AscendingOrder) + names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex())) + self.assertListEqual(names, ["a", "b", "c"]) + + def testStartTime(self): + """If it is not NXentry, start_time is not used""" + model = hdf5.Hdf5TreeModel() + h5 = _mock.File("/foo/bar/1.mock") + h5.create_group("a").create_dataset("start_time", numpy.string_("2015")) + h5.create_group("b").create_dataset("start_time", numpy.string_("2013")) + h5.create_group("c").create_dataset("start_time", numpy.string_("2014")) + model.insertH5pyObject(h5) + + proxy = hdf5.NexusSortFilterProxyModel() + proxy.setSourceModel(model) + proxy.sort(0, qt.Qt.AscendingOrder) + names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex())) + self.assertListEqual(names, ["a", "b", "c"]) + + def testName(self): + model = hdf5.Hdf5TreeModel() + h5 = _mock.File("/foo/bar/1.mock") + h5.create_group("a") + h5.create_group("c") + h5.create_group("b") + model.insertH5pyObject(h5) + + proxy = hdf5.NexusSortFilterProxyModel() + proxy.setSourceModel(model) + proxy.sort(0, qt.Qt.AscendingOrder) + names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex())) + self.assertListEqual(names, ["a", "b", "c"]) + + def testNumber(self): + model = hdf5.Hdf5TreeModel() + h5 = _mock.File("/foo/bar/1.mock") + h5.create_group("a1") + h5.create_group("a20") + h5.create_group("a3") + model.insertH5pyObject(h5) + + proxy = hdf5.NexusSortFilterProxyModel() + proxy.setSourceModel(model) + proxy.sort(0, qt.Qt.AscendingOrder) + names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex())) + self.assertListEqual(names, ["a1", "a3", "a20"]) + + def testMultiNumber(self): + model = hdf5.Hdf5TreeModel() + h5 = _mock.File("/foo/bar/1.mock") + h5.create_group("a1-1") + h5.create_group("a20-1") + h5.create_group("a3-1") + h5.create_group("a3-20") + h5.create_group("a3-3") + model.insertH5pyObject(h5) + + proxy = hdf5.NexusSortFilterProxyModel() + proxy.setSourceModel(model) + proxy.sort(0, qt.Qt.AscendingOrder) + names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex())) + self.assertListEqual(names, ["a1-1", "a3-1", "a3-3", "a3-20", "a20-1"]) + + def testUnconsistantTypes(self): + model = hdf5.Hdf5TreeModel() + h5 = _mock.File("/foo/bar/1.mock") + h5.create_group("aaa100") + h5.create_group("100aaa") + model.insertH5pyObject(h5) + + proxy = hdf5.NexusSortFilterProxyModel() + proxy.setSourceModel(model) + proxy.sort(0, qt.Qt.AscendingOrder) + names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex())) + self.assertListEqual(names, ["100aaa", "aaa100"]) + + +class TestHdf5(TestCaseQt): + """Test to check that icons module.""" + + def setUp(self): + super(TestHdf5, self).setUp() + if h5py is None: + self.skipTest("h5py is not available") + + def testCreate(self): + view = hdf5.Hdf5TreeView() + self.assertIsNotNone(view) + + def testContextMenu(self): + view = hdf5.Hdf5TreeView() + view._createContextMenu(qt.QPoint(0, 0)) + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestHdf5TreeModel)) + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestNexusSortFilterProxyModel)) + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestHdf5)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/icons.py b/silx/gui/icons.py new file mode 100644 index 0000000..eaf83b8 --- /dev/null +++ b/silx/gui/icons.py @@ -0,0 +1,360 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Set of icons for buttons. + +Use :func:`getQIcon` to create Qt QIcon from the name identifying an icon. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "25/04/2017" + + +import logging +import weakref +from . import qt +from silx.resources import resource_filename +from silx.utils import weakref as silxweakref +from silx.utils.decorators import deprecated + + +_logger = logging.getLogger(__name__) +"""Module logger""" + + +_cached_icons = weakref.WeakValueDictionary() +"""Cache loaded icons in a weak structure""" + + +_supported_formats = None +"""Order of file format extension to check""" + + +class AbstractAnimatedIcon(qt.QObject): + """Store an animated icon. + + It provides an event containing the new icon everytime it is updated.""" + + def __init__(self, parent=None): + """Constructor + + :param qt.QObject parent: Parent of the QObject + :raises: ValueError when name is not known + """ + qt.QObject.__init__(self, parent) + + self.__targets = silxweakref.WeakList() + self.__currentIcon = None + + iconChanged = qt.Signal(qt.QIcon) + """Signal sent with a QIcon everytime the animation changed.""" + + def register(self, obj): + """Register an object to the AnimatedIcon. + If no object are registred, the animation is paused. + Object are stored in a weaked list. + + :param object obj: An object + """ + if obj not in self.__targets: + self.__targets.append(obj) + self._updateState() + + def unregister(self, obj): + """Remove the object from the registration. + If no object are registred the animation is paused. + + :param object obj: A registered object + """ + if obj in self.__targets: + self.__targets.remove(obj) + self._updateState() + + def hasRegistredObjects(self): + """Returns true if any object is registred. + + :rtype: bool + """ + return len(self.__targets) + + def isRegistered(self, obj): + """Returns true if the object is registred in the AnimatedIcon. + + :param object obj: An object + :rtype: bool + """ + return obj in self.__targets + + def currentIcon(self): + """Returns the icon of the current frame. + + :rtype: qt.QIcon + """ + return self.__currentIcon + + def _updateState(self): + """Update the object according to the connected objects.""" + pass + + def _setCurrentIcon(self, icon): + """Store the current icon and emit a `iconChanged` event. + + :param qt.QIcon icon: The current icon + """ + self.__currentIcon = icon + self.iconChanged.emit(self.__currentIcon) + + +class MovieAnimatedIcon(AbstractAnimatedIcon): + """Store a looping QMovie to provide icons for each frames. + Provides an event with the new icon everytime the movie frame + is updated.""" + + def __init__(self, filename, parent=None): + """Constructor + + :param str filename: An icon name to an animated format + :param qt.QObject parent: Parent of the QObject + :raises: ValueError when name is not known + """ + AbstractAnimatedIcon.__init__(self, parent) + + qfile = getQFile(filename) + self.__movie = qt.QMovie(qfile.fileName(), qt.QByteArray(), parent) + self.__movie.setCacheMode(qt.QMovie.CacheAll) + self.__movie.frameChanged.connect(self.__frameChanged) + self.__cacheIcons = {} + + self.__movie.jumpToFrame(0) + self.__updateIconAtFrame(0) + + def __frameChanged(self, frameId): + """Callback everytime the QMovie frame change + :param int frameId: Current frame id + """ + self.__updateIconAtFrame(frameId) + + def __updateIconAtFrame(self, frameId): + """ + Update the current stored QIcon + + :param int frameId: Current frame id + """ + if frameId in self.__cacheIcons: + icon = self.__cacheIcons[frameId] + else: + icon = qt.QIcon(self.__movie.currentPixmap()) + self.__cacheIcons[frameId] = icon + self._setCurrentIcon(icon) + + def _updateState(self): + """Update the movie play according to internal stat of the + AnimatedIcon.""" + self.__movie.setPaused(not self.hasRegistredObjects()) + + +class MultiImageAnimatedIcon(AbstractAnimatedIcon): + """Store a looping QMovie to provide icons for each frames. + Provides an event with the new icon everytime the movie frame + is updated.""" + + def __init__(self, filename, parent=None): + """Constructor + + :param str filename: An icon name to an animated format + :param qt.QObject parent: Parent of the QObject + :raises: ValueError when name is not known + """ + AbstractAnimatedIcon.__init__(self, parent) + + self.__frames = [] + for i in range(100): + try: + pixmap = getQPixmap("animated/%s-%02d" % (filename, i)) + except ValueError: + break + icon = qt.QIcon(pixmap) + self.__frames.append(icon) + + if len(self.__frames) == 0: + raise ValueError("Animated icon '%s' do not exists" % filename) + + self.__frameId = -1 + self.__timer = qt.QTimer(self) + self.__timer.timeout.connect(self.__increaseFrame) + self.__updateIconAtFrame(0) + + def __increaseFrame(self): + """Callback called every timer timeout to change the current frame of + the animation + """ + frameId = (self.__frameId + 1) % len(self.__frames) + self.__updateIconAtFrame(frameId) + + def __updateIconAtFrame(self, frameId): + """ + Update the current stored QIcon + + :param int frameId: Current frame id + """ + self.__frameId = frameId + icon = self.__frames[frameId] + self._setCurrentIcon(icon) + + def _updateState(self): + """Update the object to wake up or sleep it according to its use.""" + if self.hasRegistredObjects(): + if not self.__timer.isActive(): + self.__timer.start(100) + else: + if self.__timer.isActive(): + self.__timer.stop() + + +class AnimatedIcon(MovieAnimatedIcon): + """Store a looping QMovie to provide icons for each frames. + Provides an event with the new icon everytime the movie frame + is updated. + + It may not be available anymore for the silx release 0.6. + + .. deprecated:: 0.5 + Use :class:`MovieAnimatedIcon` instead. + """ + + @deprecated + def __init__(self, filename, parent=None): + MovieAnimatedIcon.__init__(self, filename, parent=parent) + + +def getWaitIcon(): + """Returns a cached version of the waiting AbstractAnimatedIcon. + + :rtype: AbstractAnimatedIcon + """ + return getAnimatedIcon("process-working") + + +def getAnimatedIcon(name): + """Create an AbstractAnimatedIcon from a name. + + Try to load a mng or a gif file, then try to load a multi-image animated + icon. + + In Qt5 mng or gif are not used. It does not take care very well of the + transparency. + + :param str name: Name of the icon, in one of the defined icons + in this module. + :return: Corresponding AbstractAnimatedIcon + :raises: ValueError when name is not known + """ + key = name + "__anim" + if key not in _cached_icons: + + qtMajorVersion = int(qt.qVersion().split(".")[0]) + icon = None + + # ignore mng and gif in Qt5 + if qtMajorVersion != 5: + try: + icon = MovieAnimatedIcon(name) + except ValueError: + icon = None + + if icon is None: + try: + icon = MultiImageAnimatedIcon(name) + except ValueError: + icon = None + + if icon is None: + raise ValueError("Not an animated icon name: %s", name) + + _cached_icons[key] = icon + else: + icon = _cached_icons[key] + return icon + + +def getQIcon(name): + """Create a QIcon from its name. + + :param str name: Name of the icon, in one of the defined icons + in this module. + :return: Corresponding QIcon + :raises: ValueError when name is not known + """ + if name not in _cached_icons: + qfile = getQFile(name) + icon = qt.QIcon(qfile.fileName()) + _cached_icons[name] = icon + else: + icon = _cached_icons[name] + return icon + + +def getQPixmap(name): + """Create a QPixmap from its name. + + :param str name: Name of the icon, in one of the defined icons + in this module. + :return: Corresponding QPixmap + :raises: ValueError when name is not known + """ + qfile = getQFile(name) + return qt.QPixmap(qfile.fileName()) + + +def getQFile(name): + """Create a QFile from an icon name. Filename is found + according to supported Qt formats. + + :param str name: Name of the icon, in one of the defined icons + in this module. + :return: Corresponding QFile + :rtype: qt.QFile + :raises: ValueError when name is not known + """ + global _supported_formats + if _supported_formats is None: + _supported_formats = [] + supported_formats = qt.supportedImageFormats() + order = ["mng", "gif", "svg", "png", "jpg"] + for format_ in order: + if format_ in supported_formats: + _supported_formats.append(format_) + if len(_supported_formats) == 0: + _logger.error("No format supported for icons") + else: + _logger.debug("Format %s supported", ", ".join(_supported_formats)) + + for format_ in _supported_formats: + format_ = str(format_) + filename = resource_filename('gui/icons/%s.%s' % (name, format_)) + qfile = qt.QFile(filename) + if qfile.exists(): + return qfile + raise ValueError('Not an icon name: %s' % name) diff --git a/silx/gui/plot/AlphaSlider.py b/silx/gui/plot/AlphaSlider.py new file mode 100644 index 0000000..ab2e5aa --- /dev/null +++ b/silx/gui/plot/AlphaSlider.py @@ -0,0 +1,300 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module defines slider widgets interacting with the transparency +of an image on a :class:`PlotWidget` + +Classes: +-------- + +- :class:`BaseAlphaSlider` (abstract class) +- :class:`NamedImageAlphaSlider` +- :class:`ActiveImageAlphaSlider` + +Example: +-------- + +This widget can, for instance, be added to a plot toolbar. + +.. code-block:: python + + import numpy + from silx.gui import qt + from silx.gui.plot import PlotWidget + from silx.gui.plot.ImageAlphaSlider import NamedImageAlphaSlider + + app = qt.QApplication([]) + pw = PlotWidget() + + img0 = numpy.arange(200*150).reshape((200, 150)) + pw.addImage(img0, legend="my background", z=0, origin=(50, 50)) + + x, y = numpy.meshgrid(numpy.linspace(-10, 10, 200), + numpy.linspace(-10, 5, 150), + indexing="ij") + img1 = numpy.asarray(numpy.sin(x * y) / (x * y), + dtype='float32') + + pw.addImage(img1, legend="my data", z=1, + replace=False) + + alpha_slider = NamedImageAlphaSlider(parent=pw, + plot=pw, + legend="my data") + alpha_slider.setOrientation(qt.Qt.Horizontal) + + toolbar = qt.QToolBar("plot", pw) + toolbar.addWidget(alpha_slider) + pw.addToolBar(toolbar) + + pw.show() + app.exec_() + +""" + +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "24/03/2017" + +import logging + +from silx.gui import qt + +_logger = logging.getLogger(__name__) + + +class BaseAlphaSlider(qt.QSlider): + """Slider widget to be used in a plot toolbar to control the + transparency of a plot primitive (image, scatter or curve). + + Internally, the slider stores its state as an integer between + 0 and 255. This is the value emitted by the :attr:`valueChanged` + signal. + + The method :meth:`getAlpha` returns the corresponding opacity/alpha + as a float between 0. and 1. (with a step of :math:`\frac{1}{255}`). + + You must subclass this class and implement :meth:`getItem`. + """ + sigAlphaChanged = qt.Signal(float) + """Emits the alpha value when the slider's value changes, + as a float between 0. and 1.""" + + def __init__(self, parent=None, plot=None): + """ + + :param parent: Parent QWidget + :param plot: Parent plot widget + """ + assert plot is not None + super(BaseAlphaSlider, self).__init__(parent) + + self.plot = plot + + self.setRange(0, 255) + + # if already connected to an item, use its alpha as initial value + if self.getItem() is None: + self.setValue(255) + self.setEnabled(False) + else: + alpha = self.getItem().getAlpha() + self.setValue(round(255*alpha)) + + self.valueChanged.connect(self._valueChanged) + + def getItem(self): + """You must implement this class to define which item + to work on. It must return an item that inherits + :class:`silx.gui.plot.items.core.AlphaMixIn`. + + :return: Item on which to operate, or None + :rtype: :class:`silx.plot.items.Item` + """ + raise NotImplementedError( + "BaseAlphaSlider must be subclassed to " + + "implement getItem()") + + def getAlpha(self): + """Get the opacity, as a float between 0. and 1. + + :return: Alpha value in [0., 1.] + :rtype: float + """ + return self.value() / 255. + + def _valueChanged(self, value): + self._updateItem() + self.sigAlphaChanged.emit(value / 255.) + + def _updateItem(self): + """Update the item's alpha channel. + """ + item = self.getItem() + if item is not None: + item.setAlpha(self.getAlpha()) + + +class ActiveImageAlphaSlider(BaseAlphaSlider): + """Slider widget to be used in a plot toolbar to control the + transparency of the **active image**. + + :param parent: Parent QWidget + :param plot: Plot on which to operate + + See documentation of :class:`BaseAlphaSlider` + """ + def __init__(self, parent=None, plot=None): + """ + + :param parent: Parent QWidget + :param plot: Plot widget on which to operate + """ + super(ActiveImageAlphaSlider, self).__init__(parent, plot) + plot.sigActiveImageChanged.connect(self._activeImageChanged) + + def getItem(self): + return self.plot.getActiveImage() + + def _activeImageChanged(self, previous, new): + """Activate or deactivate slider depending on presence of a new + active image. + Apply transparency value to new active image. + + :param previous: Legend of previous active image, or None + :param new: Legend of new active image, or None + """ + if new is not None and not self.isEnabled(): + self.setEnabled(True) + elif new is None and self.isEnabled(): + self.setEnabled(False) + + self._updateItem() + + +class NamedItemAlphaSlider(BaseAlphaSlider): + """Slider widget to be used in a plot toolbar to control the + transparency of an item (defined by its kind and legend). + + :param parent: Parent QWidget + :param plot: Plot on which to operate + :param str kind: Kind of item whose transparency is to be + controlled: "scatter", "image" or "curve". + :param str legend: Legend of item whose transparency is to be + controlled. + """ + def __init__(self, parent=None, plot=None, + kind=None, legend=None): + self._item_legend = legend + self._item_kind = kind + + super(NamedItemAlphaSlider, self).__init__(parent, plot) + + self._updateState() + plot.sigContentChanged.connect(self._onContentChanged) + + def _onContentChanged(self, action, kind, legend): + if legend == self._item_legend and kind == self._item_kind: + if action == "add": + self.setEnabled(True) + elif action == "remove": + self.setEnabled(False) + + def _updateState(self): + """Enable or disable widget based on item's availability.""" + if self.getItem() is not None: + self.setEnabled(True) + else: + self.setEnabled(False) + + def getItem(self): + """Return plot item currently associated to this widget (can be + a curve, an image, a scatter...) + + :rtype: subclass of :class:`silx.gui.plot.items.Item`""" + if self._item_legend is None or self._item_kind is None: + return None + return self.plot._getItem(kind=self._item_kind, + legend=self._item_legend) + + def setLegend(self, legend): + """Associate a different item (of the same kind) to the slider. + + :param legend: New legend of item whose transparency is to be + controlled. + """ + self._item_legend = legend + self._updateState() + + def getLegend(self): + """Return legend of the item currently controlled by this slider. + + :return: Image legend associated to the slider + """ + return self._item_kind + + def setItemKind(self, legend): + """Associate a different item (of the same kind) to the slider. + + :param legend: New legend of item whose transparency is to be + controlled. + """ + self._item_legend = legend + self._updateState() + + def getItemKind(self): + """Return kind of the item currently controlled by this slider. + + :return: Item kind ("image", "scatter"...) + :rtype: str on None + """ + return self._item_kind + + +class NamedImageAlphaSlider(NamedItemAlphaSlider): + """Slider widget to be used in a plot toolbar to control the + transparency of an image (defined by its legend). + + :param parent: Parent QWidget + :param plot: Plot on which to operate + :param str legend: Legend of image whose transparency is to be + controlled. + """ + def __init__(self, parent=None, plot=None, legend=None): + NamedItemAlphaSlider.__init__(self, parent, plot, + kind="image", legend=legend) + + +class NamedScatterAlphaSlider(NamedItemAlphaSlider): + """Slider widget to be used in a plot toolbar to control the + transparency of a scatter (defined by its legend). + + :param parent: Parent QWidget + :param plot: Plot on which to operate + :param str legend: Legend of scatter whose transparency is to be + controlled. + """ + def __init__(self, parent=None, plot=None, legend=None): + NamedItemAlphaSlider.__init__(self, parent, plot, + kind="scatter", legend=legend) diff --git a/silx/gui/plot/ColorBar.py b/silx/gui/plot/ColorBar.py new file mode 100644 index 0000000..93e3c36 --- /dev/null +++ b/silx/gui/plot/ColorBar.py @@ -0,0 +1,790 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Module containing several widgets associated to a colormap. +""" + +__authors__ = ["H. Payno", "T. Vincent"] +__license__ = "MIT" +__date__ = "11/04/2017" + + +import logging +import numpy +from ._utils import ticklayout +from ._utils import clipColormapLogRange + + +from .. import qt +from silx.gui.plot import Colors + +_logger = logging.getLogger(__name__) + + +class ColorBarWidget(qt.QWidget): + """Colorbar widget displaying a colormap + + It uses a description of colormap as dict compatible with :class:`Plot`. + + .. image:: img/linearColorbar.png + :width: 80px + :align: center + + To run the following sample code, a QApplication must be initialized. + + >>> from silx.gui.plot import Plot2D + >>> from silx.gui.plot.ColorBar import ColorBarWidget + + >>> plot = Plot2D() # Create a plot widget + >>> plot.show() + + >>> colorbar = ColorBarWidget(plot=plot, legend='Colormap') # Associate the colorbar with it + >>> colorbar.show() + + Initializer parameters: + + :param parent: See :class:`QWidget` + :param plot: PlotWidget the colorbar is attached to (optional) + :param str legend: the label to set to the colormap + """ + + def __init__(self, parent=None, plot=None, legend=None): + super(ColorBarWidget, self).__init__(parent) + self._plot = None + + self.__buildGUI() + self.setLegend(legend) + self.setPlot(plot) + + def __buildGUI(self): + self.setLayout(qt.QHBoxLayout()) + + # create color scale widget + self._colorScale = ColorScaleBar(parent=self, + colormap=None) + self.layout().addWidget(self._colorScale) + + # legend (is the right group) + self.legend = _VerticalLegend('', self) + self.layout().addWidget(self.legend) + + self.layout().setSizeConstraint(qt.QLayout.SetMinAndMaxSize) + self.setSizePolicy(qt.QSizePolicy.Minimum, qt.QSizePolicy.Expanding) + self.layout().setContentsMargins(0, 0, 0, 0) + + def getPlot(self): + """Returns the :class:`Plot` associated to this widget or None""" + return self._plot + + def setPlot(self, plot): + """Associate a plot to the ColorBar + + :param plot: the plot to associate with the colorbar. If None will remove + any connection with a previous plot. + """ + # removing previous plot if any + if self._plot is not None: + self._plot.sigActiveImageChanged.disconnect(self._activeImageChanged) + + # setting the new plot + self._plot = plot + if self._plot is not None: + self._plot.sigActiveImageChanged.connect(self._activeImageChanged) + self._activeImageChanged(self._plot.getActiveImage(just_legend=True)) + + def getColormap(self): + """Return the colormap displayed in the colorbar as a dict. + + It returns None if no colormap is set. + See :class:`silx.gui.plot.Plot` documentation for the description of the colormap + dict description. + """ + return self._colormap.copy() + + def setColormap(self, colormap): + """Set the colormap to be displayed. + + :param dict colormap: The colormap to apply on the ColorBarWidget + """ + self._colormap = colormap + if self._colormap is None: + return + + if self._colormap['normalization'] not in ('log', 'linear'): + raise ValueError('Wrong normalization %s' % self._colormap['normalization']) + + if self._colormap['normalization'] is 'log': + if self._colormap['vmin'] < 1. or self._colormap['vmax'] < 1.: + _logger.warning('Log colormap with bound <= 1: changing bounds.') + clipColormapLogRange(colormap) + + self.getColorScaleBar().setColormap(self._colormap) + + def setLegend(self, legend): + """Set the legend displayed along the colorbar + + :param str legend: The label + """ + if legend is None or legend == "": + self.legend.hide() + self.legend.setText("") + else: + assert(type(legend) is str) + self.legend.show() + self.legend.setText(legend) + + def getLegend(self): + """ + Returns the legend displayed along the colorbar + + :return: return the legend displayed along the colorbar + :rtype: str + """ + return self.legend.getText() + + def _activeImageChanged(self, legend): + """Handle plot active curve changed""" + if legend is None: # No active image, display default colormap + self._syncWithDefaultColormap() + return + + # Sync with active image + image = self._plot.getActiveImage().getData(copy=False) + + # RGB(A) image, display default colormap + if image.ndim != 2: + self._syncWithDefaultColormap() + return + + # data image, sync with image colormap + # do we need the copy here : used in the case we are changing + # vmin and vmax but should have already be done by the plot + cmap = self._plot.getActiveImage().getColormap().copy() + if cmap['autoscale']: + if cmap['normalization'] == 'log': + data = image[ + numpy.logical_and(image > 0, numpy.isfinite(image))] + else: + data = image[numpy.isfinite(image)] + cmap['vmin'], cmap['vmax'] = data.min(), data.max() + + self.setColormap(cmap) + + def _defaultColormapChanged(self): + """Handle plot default colormap changed""" + if self._plot.getActiveImage() is None: + # No active image, take default colormap update into account + self._syncWithDefaultColormap() + + def _syncWithDefaultColormap(self): + """Update colorbar according to plot default colormap""" + self.setColormap(self._plot.getDefaultColormap()) + + def getColorScaleBar(self): + """ + + :return: return the :class:`ColorScaleBar` used to display ColorScale + and ticks""" + return self._colorScale + + +class _VerticalLegend(qt.QLabel): + """Display vertically the given text + """ + def __init__(self, text, parent=None): + """ + + :param text: the legend + :param parent: the Qt parent if any + """ + qt.QLabel.__init__(self, text, parent) + self.setLayout(qt.QVBoxLayout()) + self.layout().setContentsMargins(0, 0, 0, 0) + + def paintEvent(self, event): + painter = qt.QPainter(self) + painter.setFont(self.font()) + + painter.translate(0, self.rect().height()) + painter.rotate(270) + newRect = qt.QRect(0, 0, self.rect().height(), self.rect().width()) + + painter.drawText(newRect, qt.Qt.AlignHCenter, self.text()) + + fm = qt.QFontMetrics(self.font()) + preferedHeight = fm.width(self.text()) + preferedWidth = fm.height() + self.setFixedWidth(preferedWidth) + self.setMinimumHeight(preferedHeight) + + +class ColorScaleBar(qt.QWidget): + """This class is making the composition of a :class:`_ColorScale` and a + :class:`_TickBar`. + + It is the simplest widget displaying ticks and colormap gradient. + + .. image:: img/colorScaleBar.png + :width: 150px + :align: center + + To run the following sample code, a QApplication must be initialized. + + >>> colormap={'name':'gray', + ... 'normalization':'log', + ... 'vmin':1, + ... 'vmax':100000, + ... 'autoscale':False + ... } + >>> colorscale = ColorScaleBar(parent=None, + ... colormap=colormap ) + >>> colorscale.show() + + Initializer parameters : + + :param colormap: the colormap to be displayed + :param parent: the Qt parent if any + :param displayTicksValues: display the ticks value or only the '-' + """ + + _TEXT_MARGIN = 5 + """The tick bar need a margin to display all labels at the correct place. + So the ColorScale should have the same margin in order for both to fit""" + + _MIN_LIM_SCI_FORM = -1000 + """Used for the min and max label to know when we should display it under + the scientific form""" + + _MAX_LIM_SCI_FORM = 1000 + """Used for the min and max label to know when we should display it under + the scientific form""" + + def __init__(self, parent=None, colormap=None, displayTicksValues=True): + super(ColorScaleBar, self).__init__(parent) + + self.minVal = None + """Value set to the _minLabel""" + self.maxVal = None + """Value set to the _maxLabel""" + + self.setLayout(qt.QGridLayout()) + + # create the left side group (ColorScale) + self.colorScale = _ColorScale(colormap=colormap, + parent=self, + margin=ColorScaleBar._TEXT_MARGIN) + + self.tickbar = _TickBar(vmin=colormap['vmin'] if colormap else 0.0, + vmax=colormap['vmax'] if colormap else 1.0, + norm=colormap['normalization'] if colormap else 'linear', + parent=self, + displayValues=displayTicksValues, + margin=ColorScaleBar._TEXT_MARGIN) + + self.layout().addWidget(self.tickbar, 1, 0) + self.layout().addWidget(self.colorScale, 1, 1) + + self.layout().setContentsMargins(0, 0, 0, 0) + self.layout().setSpacing(0) + + # max label + self._maxLabel = qt.QLabel(str(1.0), parent=self) + self._maxLabel.setAlignment(qt.Qt.AlignHCenter) + self._maxLabel.setSizePolicy(qt.QSizePolicy.Minimum, qt.QSizePolicy.Minimum) + self.layout().addWidget(self._maxLabel, 0, 1) + + # min label + self._minLabel = qt.QLabel(str(0.0), parent=self) + self._minLabel.setAlignment(qt.Qt.AlignHCenter) + self._minLabel.setSizePolicy(qt.QSizePolicy.Minimum, qt.QSizePolicy.Minimum) + self.layout().addWidget(self._minLabel, 2, 1) + + def getTickBar(self): + """ + + :return: the instanciation of the :class:`_TickBar` + """ + return self.tickbar + + def getColorScale(self): + """ + + :return: the instanciation of the :class:`_ColorScale` + """ + return self.colorScale + + def setColormap(self, colormap): + """Set the new colormap to be displayed + + :param dict colormap: the colormap to set + """ + if colormap is not None: + self.colorScale.setColormap(colormap) + + self.tickbar.update(vmin=colormap['vmin'], + vmax=colormap['vmax'], + norm=colormap['normalization']) + + self._setMinMaxLabels(colormap['vmin'], colormap['vmax']) + + def setMinMaxVisible(self, val=True): + """Change visibility of the min label and the max label + + :param val: if True, set the labels visible, otherwise set it not visible + """ + self._maxLabel.show() if val is True else self._maxLabel.hide() + self._minLabel.show() if val is True else self._minLabel.hide() + + def _updateMinMax(self): + """Update the min and max label if we are in the case of the + configuration 'minMaxValueOnly'""" + if self._minLabel is not None and self._maxLabel is not None: + if self.minVal is not None: + if ColorScaleBar._MIN_LIM_SCI_FORM <= self.minVal <= ColorScaleBar._MAX_LIM_SCI_FORM: + self._minLabel.setText(str(self.minVal)) + else: + self._minLabel.setText("{0:.0e}".format(self.minVal)) + if self.maxVal is not None: + if ColorScaleBar._MIN_LIM_SCI_FORM <= self.maxVal <= ColorScaleBar._MAX_LIM_SCI_FORM: + self._maxLabel.setText(str(self.maxVal)) + else: + self._maxLabel.setText("{0:.0e}".format(self.maxVal)) + + def _setMinMaxLabels(self, minVal, maxVal): + """Change the value of the min and max labels to be displayed. + + :param minVal: the minimal value of the TickBar (not str) + :param maxVal: the maximal value of the TickBar (not str) + """ + # bad hack to try to display has much information as possible + self.minVal = minVal + self.maxVal = maxVal + self._updateMinMax() + + def resizeEvent(self, event): + qt.QWidget.resizeEvent(self, event) + self._updateMinMax() + + +class _ColorScale(qt.QWidget): + """Widget displaying the colormap colorScale. + + Show matching value between the gradient color (from the colormap) at mouse + position and value. + + .. image:: img/colorScale.png + :width: 20px + :align: center + + + To run the following sample code, a QApplication must be initialized. + + >>> colormap={'name':'viridis', + ... 'normalization':'log', + ... 'vmin':1, + ... 'vmax':100000, + ... 'autoscale':False + ... } + >>> colorscale = ColorScale(parent=None, + ... colormap=colormap) + >>> colorscale.show() + + Initializer parameters : + + :param colormap: the colormap to be displayed + :param parent: the Qt parent if any + :param int margin: the top and left margin to apply. + + .. warning:: Value drawing will be + done at the center of ticks. So if no margin is done your values + drawing might not be fully done for extrems values. + """ + + _NB_CONTROL_POINTS = 256 + + def __init__(self, colormap, parent=None, margin=5): + qt.QWidget.__init__(self, parent) + self.colormap = None + self.setColormap(colormap) + + self.setLayout(qt.QVBoxLayout()) + self.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Expanding) + # needed to get the mouse event without waiting for button click + self.setMouseTracking(True) + self.setMargin(margin) + self.setContentsMargins(0, 0, 0, 0) + + def setColormap(self, colormap): + """Set the new colormap to be displayed + + :param dict colormap: the colormap to set + """ + if colormap is None: + return + + if colormap['normalization'] not in ('log', 'linear'): + raise ValueError("Unrecognized normalization, should be 'linear' or 'log'") + + if colormap['normalization'] is 'log': + if not (colormap['vmin'] > 0 and colormap['vmax'] > 0): + raise ValueError('vmin and vmax should be positives') + self.colormap = colormap + self._computeColorPoints() + + def _computeColorPoints(self): + """Compute the color points for the gradient + """ + if self.colormap is None: + return + + vmin = self.colormap['vmin'] + vmax = self.colormap['vmax'] + steps = (vmax - vmin)/float(_ColorScale._NB_CONTROL_POINTS) + self.ctrPoints = numpy.arange(vmin, vmax, steps) + self.colorsCtrPts = Colors.applyColormapToData(self.ctrPoints, + name=self.colormap['name'], + normalization='linear', + autoscale=self.colormap['autoscale'], + vmin=vmin, + vmax=vmax) + + def paintEvent(self, event): + """""" + qt.QWidget.paintEvent(self, event) + if self.colormap is None: + return + + vmin = self.colormap['vmin'] + vmax = self.colormap['vmax'] + + painter = qt.QPainter(self) + gradient = qt.QLinearGradient(0, 0, 0, self.rect().height() - 2*self.margin) + for iPt, pt in enumerate(self.ctrPoints): + colormapPosition = 1 - (pt-vmin) / (vmax-vmin) + assert(colormapPosition >= 0.0) + assert(colormapPosition <= 1.0) + gradient.setColorAt(colormapPosition, qt.QColor(*(self.colorsCtrPts[iPt]))) + + painter.setBrush(gradient) + painter.drawRect( + qt.QRect(0, self.margin, self.width(), self.height() - 2.*self.margin)) + + def mouseMoveEvent(self, event): + """""" + self.setToolTip(str(self.getValueFromRelativePosition(self._getRelativePosition(event.y())))) + super(_ColorScale, self).mouseMoveEvent(event) + + def _getRelativePosition(self, yPixel): + """yPixel : pixel position into _ColorScale widget reference + """ + # widgets are bottom-top referencial but we display in top-bottom referential + return 1 - float(yPixel)/float(self.height() - 2*self.margin) + + def getValueFromRelativePosition(self, value): + """Return the value in the colorMap from a relative position in the + ColorScaleBar (y) + + :param value: float value in [0, 1] + :return: the value in [colormap['vmin'], colormap['vmax']] + """ + value = max(0.0, value) + value = min(value, 1.0) + vmin = self.colormap['vmin'] + vmax = self.colormap['vmax'] + if self.colormap['normalization'] is 'linear': + return vmin + (vmax - vmin) * value + elif self.colormap['normalization'] is 'log': + rpos = (numpy.log10(vmax) - numpy.log10(vmin)) * value + numpy.log10(vmin) + return numpy.power(10., rpos) + else: + err = "normalization type (%s) is not managed by the _ColorScale Widget" % self.colormap['normalization'] + raise ValueError(err) + + def setMargin(self, margin): + """Define the margin to fit with a TickBar object. + This is needed since we can only paint on the viewport of the widget. + Didn't work with a simple setContentsMargins + + :param int margin: the margin to apply on the top and bottom. + """ + self.margin = margin + + +class _TickBar(qt.QWidget): + """Bar grouping the ticks displayed + + To run the following sample code, a QApplication must be initialized. + + >>> bar = TickBar(1, 1000, norm='log', parent=None, displayValues=True) + >>> bar.show() + + .. image:: img/tickbar.png + :width: 40px + :align: center + + :param int vmin: smaller value of the range of values + :param int vmax: higher value of the range of values + :param str norm: normalization type to be displayed. Valid values are + 'linear' and 'log' + :param parent: the Qt parent if any + :param bool displayValues: if True display the values close to the tick, + Otherwise only signal it by '-' + :param int nticks: the number of tick we want to display. Should be an + unsigned int ot None. If None, let the Tick bar find the optimal + number of ticks from the tick density. + :param int margin: margin to set on the top and bottom + """ + _WIDTH_DISP_VAL = 45 + """widget width when displayed with ticks labels""" + _WIDTH_NO_DISP_VAL = 10 + """widget width when displayed without ticks labels""" + _FONT_SIZE = 10 + """font size for ticks labels""" + _LINE_WIDTH = 10 + """width of the line to mark a tick""" + + DEFAULT_TICK_DENSITY = 0.015 + + def __init__(self, vmin, vmax, norm, parent=None, displayValues=True, + nticks=None, margin=5): + super(_TickBar, self).__init__(parent) + self._forcedDisplayType = None + self.ticksDensity = _TickBar.DEFAULT_TICK_DENSITY + + self._vmin = vmin + self._vmax = vmax + # TODO : should be grouped into a global function, called by all + # logScale displayer to make sure we have the same behavior everywhere + if self._vmin < 1. or self._vmax < 1.: + _logger.warning( + 'Log colormap with bound <= 1: changing bounds.') + self._vmin, self._vmax = 1., 10. + + self._norm = norm + self.displayValues = displayValues + self.setTicksNumber(nticks) + self.setMargin(margin) + + self.setLayout(qt.QVBoxLayout()) + self.setMargin(margin) + self.setContentsMargins(0, 0, 0, 0) + + self._resetWidth() + + def setTicksValuesVisible(self, val): + self.displayValues = val + self._resetWidth() + + def _resetWidth(self): + self.width = _TickBar._WIDTH_DISP_VAL if self.displayValues else _TickBar._WIDTH_NO_DISP_VAL + self.setFixedWidth(self.width) + + def update(self, vmin, vmax, norm): + self._vmin = vmin + self._vmax = vmax + self._norm = norm + self.computeTicks() + qt.QWidget.update(self) + + def setMargin(self, margin): + """Define the margin to fit with a _ColorScale object. + This is needed since we can only paint on the viewport of the widget + + :param int margin: the margin to apply on the top and bottom. + """ + self.margin = margin + + def setTicksNumber(self, nticks): + """Set the number of ticks to display. + + :param nticks: the number of tick to be display. Should be an + unsigned int ot None. If None, let the :class:`_TickBar` find the + optimal number of ticks from the tick density. + """ + self._nticks = nticks + self.ticks = None + self.computeTicks() + qt.QWidget.update(self) + + def setTicksDensity(self, density): + """If you let :class:`_TickBar` deal with the number of ticks + (nticks=None) then you can specify a ticks density to be displayed. + """ + if density < 0.0: + raise ValueError('Density should be a positive value') + self.ticksDensity = density + + def computeTicks(self): + """This function compute ticks values labels. It is called at each + update and each resize event. + Deal only with linear and log scale. + """ + nticks = self._nticks + if nticks is None: + nticks = self._getOptimalNbTicks() + + if self._norm == 'log': + self._computeTicksLog(nticks) + elif self._norm == 'linear': + self._computeTicksLin(nticks) + else: + err = 'TickBar - Wrong normalization %s' % self._norm + raise ValueError(err) + # update the form + font = qt.QFont() + font.setPixelSize(_TickBar._FONT_SIZE) + + self.form = self._getFormat(font) + + def _computeTicksLog(self, nticks): + logMin = numpy.log10(self._vmin) + logMax = numpy.log10(self._vmax) + lowBound, highBound, spacing, self._nfrac = ticklayout.niceNumbersForLog10(logMin, + logMax, + nticks) + self.ticks = numpy.power(10., numpy.arange(lowBound, highBound, spacing)) + if spacing == 1: + self.subTicks = ticklayout.computeLogSubTicks(ticks=self.ticks, + lowBound=numpy.power(10., lowBound), + highBound=numpy.power(10., highBound)) + else: + self.subTicks = [] + + def resizeEvent(self, event): + qt.QWidget.resizeEvent(self, event) + self.computeTicks() + + def _computeTicksLin(self, nticks): + _min, _max, _spacing, self._nfrac = ticklayout.niceNumbers(self._vmin, + self._vmax, + nticks) + + self.ticks = numpy.arange(_min, _max, _spacing) + self.subTicks = [] + + def _getOptimalNbTicks(self): + return max(2, int(round(self.ticksDensity * self.rect().height()))) + + def paintEvent(self, event): + painter = qt.QPainter(self) + font = painter.font() + font.setPixelSize(_TickBar._FONT_SIZE) + painter.setFont(font) + + # paint ticks + if self.ticks is not None: + for val in self.ticks: + self._paintTick(val, painter, majorTick=True) + + # paint subticks + for val in self.subTicks: + self._paintTick(val, painter, majorTick=False) + + qt.QWidget.paintEvent(self, event) + + def _getRelativePosition(self, val): + """Return the relative position of val according to min and max value + """ + if self._norm == 'linear': + return 1 - (val - self._vmin) / (self._vmax - self._vmin) + elif self._norm == 'log': + return 1 - (numpy.log10(val) - numpy.log10(self._vmin))/(numpy.log10(self._vmax) - numpy.log(self._vmin)) + else: + raise ValueError('Norm is not recognized') + + def _paintTick(self, val, painter, majorTick=True): + """ + + :param bool majorTick: if False will never draw text and will set a line + with a smaller width + """ + fm = qt.QFontMetrics(painter.font()) + viewportHeight = self.rect().height() - self.margin * 2 + relativePos = self._getRelativePosition(val) + height = viewportHeight * relativePos + height += self.margin + lineWidth = _TickBar._LINE_WIDTH + if majorTick is False: + lineWidth /= 2 + + painter.drawLine(qt.QLine(self.width - lineWidth, + height, + self.width, + height)) + + if self.displayValues and majorTick is True: + painter.drawText(qt.QPoint(0.0, height + (fm.height() / 2)), + self.form.format(val)) + + def setDisplayType(self, disType): + """Set the type of display we want to set for ticks labels + + :param str disType: The type of display we want to set. disType values + can be : + + - 'std' for standard, meaning only a formatting on the number of + digits is done + - 'e' for scientific display + - None to let the _TickBar guess the best display for this kind of data. + """ + if disType not in (None, 'std', 'e'): + raise ValueError("display type not recognized, value should be in (None, 'std', 'e'") + self._forcedDisplayType = disType + + def _getStandardFormat(self): + return "{0:.%sf}" % self._nfrac + + def _getFormat(self, font): + if self._forcedDisplayType is None: + return self._guessType(font) + elif self._forcedDisplayType is 'std': + return self._getStandardFormat() + elif self._forcedDisplayType is 'e': + return self._getScientificForm() + else: + err = 'Forced type for display %s is not recognized' % self._forcedDisplayType + raise ValueError(err) + + def _getScientificForm(self): + return "{0:.0e}" + + def _guessType(self, font): + """Try fo find the better format to display the tick's labels + + :param QFont font: the font we want want to use durint the painting + """ + assert(type(self._vmin) == type(self._vmax)) + form = self._getStandardFormat() + + fm = qt.QFontMetrics(font) + width = 0 + for tick in self.ticks: + width = max(fm.width(form.format(tick)), width) + + # if the length of the string are too long we are mooving to scientific + # display + if width > _TickBar._WIDTH_DISP_VAL - _TickBar._LINE_WIDTH: + return self._getScientificForm() + else: + return form diff --git a/silx/gui/plot/ColormapDialog.py b/silx/gui/plot/ColormapDialog.py new file mode 100644 index 0000000..ad1425c --- /dev/null +++ b/silx/gui/plot/ColormapDialog.py @@ -0,0 +1,506 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""A QDialog widget to set-up the colormap. + +It uses a description of colormaps as dict compatible with :class:`Plot`. + +To run the following sample code, a QApplication must be initialized. + +Create the colormap dialog and set the colormap description and data range: + +>>> from silx.gui.plot.ColormapDialog import ColormapDialog + +>>> dialog = ColormapDialog() + +>>> dialog.setColormap(name='red', normalization='log', +... autoscale=False, vmin=1., vmax=2.) +>>> dialog.setDataRange(1., 100.) # This scale the width of the plot area +>>> dialog.show() + +Get the colormap description (compatible with :class:`Plot`) from the dialog: + +>>> cmap = dialog.getColormap() +>>> cmap['name'] +'red' + +It is also possible to display an histogram of the image in the dialog. +This updates the data range with the range of the bins. + +>>> import numpy +>>> image = numpy.random.normal(size=512 * 512).reshape(512, -1) +>>> hist, bin_edges = numpy.histogram(image, bins=10) +>>> dialog.setHistogram(hist, bin_edges) + +The updates of the colormap description are also available through the signal: +:attr:`ColormapDialog.sigColormapChanged`. +""" # noqa + +from __future__ import division + +__authors__ = ["V.A. Sole", "T. Vincent"] +__license__ = "MIT" +__date__ = "29/03/2016" + + +import logging + +import numpy + +from .. import qt +from . import PlotWidget + + +_logger = logging.getLogger(__name__) + + +class _FloatEdit(qt.QLineEdit): + """Field to edit a float value. + + :param parent: See :class:`QLineEdit` + :param float value: The value to set the QLineEdit to. + """ + def __init__(self, parent=None, value=None): + qt.QLineEdit.__init__(self, parent) + self.setValidator(qt.QDoubleValidator()) + self.setAlignment(qt.Qt.AlignRight) + if value is not None: + self.setValue(value) + + def value(self): + """Return the QLineEdit current value as a float.""" + return float(self.text()) + + def setValue(self, value): + """Set the current value of the LineEdit + + :param float value: The value to set the QLineEdit to. + """ + self.setText('%g' % value) + + +class ColormapDialog(qt.QDialog): + """A QDialog widget to set the colormap. + + :param parent: See :class:`QDialog` + :param str title: The QDialog title + """ + + sigColormapChanged = qt.Signal(dict) + """Signal triggered when the colormap is changed. + + It provides a dict describing the colormap to the slot. + This dict can be used with :class:`Plot`. + """ + + def __init__(self, parent=None, title="Colormap Dialog"): + qt.QDialog.__init__(self, parent) + self.setWindowTitle(title) + + self._histogramData = None + self._dataRange = None + self._minMaxWasEdited = False + + self._colormapList = ( + 'gray', 'reversed gray', + 'temperature', 'red', 'green', 'blue', 'jet', + 'viridis', 'magma', 'inferno', 'plasma') + + # Make the GUI + vLayout = qt.QVBoxLayout(self) + + formWidget = qt.QWidget() + vLayout.addWidget(formWidget) + formLayout = qt.QFormLayout(formWidget) + formLayout.setContentsMargins(10, 10, 10, 10) + formLayout.setSpacing(0) + + # Colormap row + self._comboBoxColormap = qt.QComboBox() + for cmap in self._colormapList: + # Capitalize first letters + cmap = ' '.join(w[0].upper() + w[1:] for w in cmap.split()) + self._comboBoxColormap.addItem(cmap) + self._comboBoxColormap.activated[int].connect(self._notify) + formLayout.addRow('Colormap:', self._comboBoxColormap) + + # Normalization row + self._normButtonLinear = qt.QRadioButton('Linear') + self._normButtonLinear.setChecked(True) + self._normButtonLog = qt.QRadioButton('Log') + + normButtonGroup = qt.QButtonGroup(self) + normButtonGroup.setExclusive(True) + normButtonGroup.addButton(self._normButtonLinear) + normButtonGroup.addButton(self._normButtonLog) + normButtonGroup.buttonClicked[int].connect(self._notify) + + normLayout = qt.QHBoxLayout() + normLayout.setContentsMargins(0, 0, 0, 0) + normLayout.setSpacing(10) + normLayout.addWidget(self._normButtonLinear) + normLayout.addWidget(self._normButtonLog) + + formLayout.addRow('Normalization:', normLayout) + + # Range row + self._rangeAutoscaleButton = qt.QCheckBox('Autoscale') + self._rangeAutoscaleButton.setChecked(True) + self._rangeAutoscaleButton.toggled.connect(self._autoscaleToggled) + self._rangeAutoscaleButton.clicked.connect(self._notify) + formLayout.addRow('Range:', self._rangeAutoscaleButton) + + # Min row + self._minValue = _FloatEdit(value=1.) + self._minValue.setEnabled(False) + self._minValue.textEdited.connect(self._minMaxTextEdited) + self._minValue.editingFinished.connect(self._minEditingFinished) + formLayout.addRow('\tMin:', self._minValue) + + # Max row + self._maxValue = _FloatEdit(value=10.) + self._maxValue.setEnabled(False) + self._maxValue.textEdited.connect(self._minMaxTextEdited) + self._maxValue.editingFinished.connect(self._maxEditingFinished) + formLayout.addRow('\tMax:', self._maxValue) + + # Add plot for histogram + self._plotInit() + vLayout.addWidget(self._plot) + + # Close button + buttonsWidget = qt.QWidget() + vLayout.addWidget(buttonsWidget) + + buttonsLayout = qt.QHBoxLayout(buttonsWidget) + + okButton = qt.QPushButton('OK') + okButton.clicked.connect(self.accept) + buttonsLayout.addWidget(okButton) + + cancelButton = qt.QPushButton('Cancel') + cancelButton.clicked.connect(self.reject) + buttonsLayout.addWidget(cancelButton) + + # colormap window can not be resized + self.setFixedSize(vLayout.minimumSize()) + + # Set the colormap to default values + self.setColormap(name='gray', normalization='linear', + autoscale=True, vmin=1., vmax=10.) + + def _plotInit(self): + """Init the plot to display the range and the values""" + self._plot = PlotWidget() + self._plot.setDataMargins(yMinMargin=0.125, yMaxMargin=0.125) + self._plot.setGraphXLabel("Data Values") + self._plot.setGraphYLabel("") + self._plot.setInteractiveMode('select', zoomOnWheel=False) + self._plot.setActiveCurveHandling(False) + self._plot.setMinimumSize(qt.QSize(250, 200)) + self._plot.sigPlotSignal.connect(self._plotSlot) + self._plot.hide() + + self._plotUpdate() + + def _plotUpdate(self, updateMarkers=True): + """Update the plot content + + :param bool updateMarkers: True to update markers, False otherwith + """ + dataRange = self.getDataRange() + + if dataRange is None: + if self._plot.isVisibleTo(self): + self._plot.setVisible(False) + self.setFixedSize(self.layout().minimumSize()) + return + + if not self._plot.isVisibleTo(self): + self._plot.setVisible(True) + self.setFixedSize(self.layout().minimumSize()) + + dataMin, dataMax = dataRange + marge = (abs(dataMax) + abs(dataMin)) / 6.0 + minmd = dataMin - marge + maxpd = dataMax + marge + + start, end = self._minValue.value(), self._maxValue.value() + + if start <= end: + x = [minmd, start, end, maxpd] + y = [0, 0, 1, 1] + + else: + x = [minmd, end, start, maxpd] + y = [1, 1, 0, 0] + + # Display the colormap on the side + # colormap = {'name': self.getColormap()['name'], + # 'normalization': self.getColormap()['normalization'], + # 'autoscale': True, 'vmin': 1., 'vmax': 256.} + # self._plot.addImage((1 + numpy.arange(256)).reshape(256, -1), + # xScale=(minmd - marge, marge), + # yScale=(1., 2./256.), + # legend='colormap', + # colormap=colormap) + + self._plot.addCurve(x, y, + legend="ConstrainedCurve", + color='black', + symbol='o', + linestyle='-', + resetzoom=False) + + draggable = not self._rangeAutoscaleButton.isChecked() + + if updateMarkers: + self._plot.addXMarker( + self._minValue.value(), + legend='Min', + text='Min', + draggable=draggable, + color='blue', + constraint=self._plotMinMarkerConstraint) + + self._plot.addXMarker( + self._maxValue.value(), + legend='Max', + text='Max', + draggable=draggable, + color='blue', + constraint=self._plotMaxMarkerConstraint) + + self._plot.resetZoom() + + def _plotMinMarkerConstraint(self, x, y): + """Constraint of the min marker""" + return min(x, self._maxValue.value()), y + + def _plotMaxMarkerConstraint(self, x, y): + """Constraint of the max marker""" + return max(x, self._minValue.value()), y + + def _plotSlot(self, event): + """Handle events from the plot""" + if event['event'] in ('markerMoving', 'markerMoved'): + value = float(str(event['xdata'])) + if event['label'] == 'Min': + self._minValue.setValue(value) + elif event['label'] == 'Max': + self._maxValue.setValue(value) + + # This will recreate the markers while interacting... + # It might break if marker interaction is changed + if event['event'] == 'markerMoved': + self._notify() + else: + self._plotUpdate(updateMarkers=False) + + def getHistogram(self): + """Returns the counts and bin edges of the displayed histogram. + + :return: (hist, bin_edges) + :rtype: 2-tuple of numpy arrays""" + if self._histogramData is None: + return None + else: + bins, counts = self._histogramData + return numpy.array(bins, copy=True), numpy.array(counts, copy=True) + + def setHistogram(self, hist=None, bin_edges=None): + """Set the histogram to display. + + This update the data range with the bounds of the bins. + See :meth:`setDataRange`. + + :param hist: array-like of counts or None to hide histogram + :param bin_edges: array-like of bins edges or None to hide histogram + """ + if hist is None or bin_edges is None: + self._histogramData = None + self._plot.remove(legend='Histogram', kind='curve') + self.setDataRange() # Remove data range + + else: + hist = numpy.array(hist, copy=True) + bin_edges = numpy.array(bin_edges, copy=True) + self._histogramData = hist, bin_edges + + # For now, draw the histogram as a curve + # using bin centers and normalised counts + bins_center = 0.5 * (bin_edges[:-1] + bin_edges[1:]) + norm_hist = hist / max(hist) + self._plot.addCurve(bins_center, norm_hist, + legend="Histogram", + color='gray', + symbol='', + linestyle='-', + fill=True) + + # Update the data range + self.setDataRange(bin_edges[0], bin_edges[-1]) + + def getDataRange(self): + """Returns the data range used for the histogram area. + + :return: (dataMin, dataMax) or None if no data range is set + :rtype: 2-tuple of float + """ + return self._dataRange + + def setDataRange(self, min_=None, max_=None): + """Set the range of data to use for the range of the histogram area. + + :param float min_: The min of the data or None to disable range. + :param float max_: The max of the data or None to disable range. + """ + if min_ is None or max_ is None: + self._dataRange = None + self._plotUpdate() + + else: + min_, max_ = float(min_), float(max_) + assert min_ <= max_ + self._dataRange = min_, max_ + if self._rangeAutoscaleButton.isChecked(): + self._minValue.setValue(min_) + self._maxValue.setValue(max_) + self._notify() + else: + self._plotUpdate() + + def getColormap(self): + """Return the colormap description as a dict. + + See :class:`Plot` for documentation on the colormap dict. + """ + isNormLinear = self._normButtonLinear.isChecked() + colormap = { + 'name': str(self._comboBoxColormap.currentText()).lower(), + 'normalization': 'linear' if isNormLinear else 'log', + 'autoscale': self._rangeAutoscaleButton.isChecked(), + 'vmin': self._minValue.value(), + 'vmax': self._maxValue.value()} + return colormap + + def setColormap(self, name=None, normalization=None, + autoscale=None, vmin=None, vmax=None, colors=None): + """Set the colormap description + + If some arguments are not provided, the current values are used. + + :param str name: The name of the colormap + :param str normalization: 'linear' or 'log' + :param bool autoscale: Toggle colormap range autoscale + :param float vmin: The min value, ignored if autoscale is True + :param float vmax: The max value, ignored if autoscale is True + """ + if name is not None: + assert name in self._colormapList + index = self._colormapList.index(name) + self._comboBoxColormap.setCurrentIndex(index) + + if normalization is not None: + assert normalization in ('linear', 'log') + self._normButtonLinear.setChecked(normalization == 'linear') + self._normButtonLog.setChecked(normalization == 'log') + + if vmin is not None: + self._minValue.setValue(vmin) + + if vmax is not None: + self._maxValue.setValue(vmax) + + if autoscale is not None: + self._rangeAutoscaleButton.setChecked(autoscale) + if autoscale: + dataRange = self.getDataRange() + if dataRange is not None: + self._minValue.setValue(dataRange[0]) + self._maxValue.setValue(dataRange[1]) + + # Do it once for all the changes + self._notify() + + def _notify(self, *args, **kwargs): + """Emit the signal for colormap change""" + self._plotUpdate() + self.sigColormapChanged.emit(self.getColormap()) + + def _autoscaleToggled(self, checked): + """Handle autoscale changes by enabling/disabling min/max fields""" + self._minValue.setEnabled(not checked) + self._maxValue.setEnabled(not checked) + if checked: + dataRange = self.getDataRange() + if dataRange is not None: + self._minValue.setValue(dataRange[0]) + self._maxValue.setValue(dataRange[1]) + + def _minMaxTextEdited(self, text): + """Handle _minValue and _maxValue textEdited signal""" + self._minMaxWasEdited = True + + def _minEditingFinished(self): + """Handle _minValue editingFinished signal + + Together with :meth:`_minMaxTextEdited`, this avoids to notify + colormap change when the min and max value where not edited. + """ + if self._minMaxWasEdited: + self._minMaxWasEdited = False + + # Fix start value + if self._minValue.value() > self._maxValue.value(): + self._minValue.setValue(self._maxValue.value()) + self._notify() + + def _maxEditingFinished(self): + """Handle _maxValue editingFinished signal + + Together with :meth:`_minMaxTextEdited`, this avoids to notify + colormap change when the min and max value where not edited. + """ + if self._minMaxWasEdited: + self._minMaxWasEdited = False + + # Fix end value + if self._minValue.value() > self._maxValue.value(): + self._maxValue.setValue(self._minValue.value()) + self._notify() + + def keyPressEvent(self, event): + """Override key handling. + + It disables leaving the dialog when editing a text field. + """ + if event.key() == qt.Qt.Key_Enter and (self._minValue.hasFocus() or + self._maxValue.hasFocus()): + # Bypass QDialog keyPressEvent + # To avoid leaving the dialog when pressing enter on a text field + super(qt.QDialog, self).keyPressEvent(event) + else: + # Use QDialog keyPressEvent + super(ColormapDialog, self).keyPressEvent(event) diff --git a/silx/gui/plot/Colors.py b/silx/gui/plot/Colors.py new file mode 100644 index 0000000..7a3cd97 --- /dev/null +++ b/silx/gui/plot/Colors.py @@ -0,0 +1,359 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Color conversion function, color dictionary and colormap tools.""" + +__authors__ = ["V.A. Sole", "T. VINCENT"] +__license__ = "MIT" +__date__ = "16/01/2017" + + +import logging + +import numpy + +import matplotlib +import matplotlib.colors +import matplotlib.cm + +from . import MPLColormap + + +_logger = logging.getLogger(__name__) + + +COLORDICT = {} +"""Dictionary of common colors.""" + +COLORDICT['b'] = COLORDICT['blue'] = '#0000ff' +COLORDICT['r'] = COLORDICT['red'] = '#ff0000' +COLORDICT['g'] = COLORDICT['green'] = '#00ff00' +COLORDICT['k'] = COLORDICT['black'] = '#000000' +COLORDICT['w'] = COLORDICT['white'] = '#ffffff' +COLORDICT['pink'] = '#ff66ff' +COLORDICT['brown'] = '#a52a2a' +COLORDICT['orange'] = '#ff9900' +COLORDICT['violet'] = '#6600ff' +COLORDICT['gray'] = COLORDICT['grey'] = '#a0a0a4' +# COLORDICT['darkGray'] = COLORDICT['darkGrey'] = '#808080' +# COLORDICT['lightGray'] = COLORDICT['lightGrey'] = '#c0c0c0' +COLORDICT['y'] = COLORDICT['yellow'] = '#ffff00' +COLORDICT['m'] = COLORDICT['magenta'] = '#ff00ff' +COLORDICT['c'] = COLORDICT['cyan'] = '#00ffff' +COLORDICT['darkBlue'] = '#000080' +COLORDICT['darkRed'] = '#800000' +COLORDICT['darkGreen'] = '#008000' +COLORDICT['darkBrown'] = '#660000' +COLORDICT['darkCyan'] = '#008080' +COLORDICT['darkYellow'] = '#808000' +COLORDICT['darkMagenta'] = '#800080' + + +def rgba(color, colorDict=None): + """Convert color code '#RRGGBB' and '#RRGGBBAA' to (R, G, B, A) + + It also convert RGB(A) values from uint8 to float in [0, 1] and + accept a QColor as color argument. + + :param str color: The color to convert + :param dict colorDict: A dictionary of color name conversion to color code + :returns: RGBA colors as floats in [0., 1.] + :rtype: tuple + """ + if colorDict is None: + colorDict = COLORDICT + + if hasattr(color, 'getRgbF'): # QColor support + color = color.getRgbF() + + values = numpy.asarray(color).ravel() + + if values.dtype.kind in 'iuf': # integer or float + # Color is an array + assert len(values) in (3, 4) + + # Convert from integers in [0, 255] to float in [0, 1] + if values.dtype.kind in 'iu': + values = values / 255. + + # Clip to [0, 1] + values[values < 0.] = 0. + values[values > 1.] = 1. + + if len(values) == 3: + return values[0], values[1], values[2], 1. + else: + return tuple(values) + + # We assume color is a string + if not color.startswith('#'): + color = colorDict[color] + + assert len(color) in (7, 9) and color[0] == '#' + r = int(color[1:3], 16) / 255. + g = int(color[3:5], 16) / 255. + b = int(color[5:7], 16) / 255. + a = int(color[7:9], 16) / 255. if len(color) == 9 else 1. + return r, g, b, a + + +_COLORMAP_CURSOR_COLORS = { + 'gray': 'pink', + 'reversed gray': 'pink', + 'temperature': 'pink', + 'red': 'green', + 'green': 'pink', + 'blue': 'yellow', + 'jet': 'pink', + 'viridis': 'pink', + 'magma': 'green', + 'inferno': 'green', + 'plasma': 'green', +} + + +def cursorColorForColormap(colormapName): + """Get a color suitable for overlay over a colormap. + + :param str colormapName: The name of the colormap. + :return: Name of the color. + :rtype: str + """ + return _COLORMAP_CURSOR_COLORS.get(colormapName, 'black') + + +_CMAPS = {} # Store additional colormaps + + +def getMPLColormap(name): + """Returns matplotlib colormap corresponding to given name + + :param str name: The name of the colormap + :return: The corresponding colormap + :rtype: matplolib.colors.Colormap + """ + if not _CMAPS: # Lazy initialization of own colormaps + cdict = {'red': ((0.0, 0.0, 0.0), + (1.0, 1.0, 1.0)), + 'green': ((0.0, 0.0, 0.0), + (1.0, 0.0, 0.0)), + 'blue': ((0.0, 0.0, 0.0), + (1.0, 0.0, 0.0))} + _CMAPS['red'] = matplotlib.colors.LinearSegmentedColormap( + 'red', cdict, 256) + + cdict = {'red': ((0.0, 0.0, 0.0), + (1.0, 0.0, 0.0)), + 'green': ((0.0, 0.0, 0.0), + (1.0, 1.0, 1.0)), + 'blue': ((0.0, 0.0, 0.0), + (1.0, 0.0, 0.0))} + _CMAPS['green'] = matplotlib.colors.LinearSegmentedColormap( + 'green', cdict, 256) + + cdict = {'red': ((0.0, 0.0, 0.0), + (1.0, 0.0, 0.0)), + 'green': ((0.0, 0.0, 0.0), + (1.0, 0.0, 0.0)), + 'blue': ((0.0, 0.0, 0.0), + (1.0, 1.0, 1.0))} + _CMAPS['blue'] = matplotlib.colors.LinearSegmentedColormap( + 'blue', cdict, 256) + + # Temperature as defined in spslut + cdict = {'red': ((0.0, 0.0, 0.0), + (0.5, 0.0, 0.0), + (0.75, 1.0, 1.0), + (1.0, 1.0, 1.0)), + 'green': ((0.0, 0.0, 0.0), + (0.25, 1.0, 1.0), + (0.75, 1.0, 1.0), + (1.0, 0.0, 0.0)), + 'blue': ((0.0, 1.0, 1.0), + (0.25, 1.0, 1.0), + (0.5, 0.0, 0.0), + (1.0, 0.0, 0.0))} + # but limited to 256 colors for a faster display (of the colorbar) + _CMAPS['temperature'] = \ + matplotlib.colors.LinearSegmentedColormap( + 'temperature', cdict, 256) + + # reversed gray + cdict = {'red': ((0.0, 1.0, 1.0), + (1.0, 0.0, 0.0)), + 'green': ((0.0, 1.0, 1.0), + (1.0, 0.0, 0.0)), + 'blue': ((0.0, 1.0, 1.0), + (1.0, 0.0, 0.0))} + + _CMAPS['reversed gray'] = \ + matplotlib.colors.LinearSegmentedColormap( + 'yerg', cdict, 256) + + if name in _CMAPS: + return _CMAPS[name] + elif hasattr(MPLColormap, name): # viridis and sister colormaps + return getattr(MPLColormap, name) + else: + # matplotlib built-in + return matplotlib.cm.get_cmap(name) + + +def getMPLScalarMappable(colormap, data=None): + """Returns matplotlib ScalarMappable corresponding to colormap + + :param dict colormap: The colormap to convert + :param numpy.ndarray data: + The data on which the colormap is applied. + If provided, it is used to compute autoscale. + :return: matplotlib object corresponding to colormap + :rtype: matplotlib.cm.ScalarMappable + """ + assert colormap is not None + + if colormap['name'] is not None: + cmap = getMPLColormap(colormap['name']) + + else: # No name, use custom colors + if 'colors' not in colormap: + raise ValueError( + 'addImage: colormap no name nor list of colors.') + colors = numpy.array(colormap['colors'], copy=True) + assert len(colors.shape) == 2 + assert colors.shape[-1] in (3, 4) + if colors.dtype == numpy.uint8: + # Convert to float in [0., 1.] + colors = colors.astype(numpy.float32) / 255. + cmap = matplotlib.colors.ListedColormap(colors) + + if colormap['normalization'].startswith('log'): + vmin, vmax = None, None + if not colormap['autoscale']: + if colormap['vmin'] > 0.: + vmin = colormap['vmin'] + if colormap['vmax'] > 0.: + vmax = colormap['vmax'] + + if vmin is None or vmax is None: + _logger.warning('Log colormap with negative bounds, ' + + 'changing bounds to positive ones.') + elif vmin > vmax: + _logger.warning('Colormap bounds are inverted.') + vmin, vmax = vmax, vmin + + # Set unset/negative bounds to positive bounds + if (vmin is None or vmax is None) and data is not None: + finiteData = data[numpy.isfinite(data)] + posData = finiteData[finiteData > 0] + if vmax is None: + # 1. as an ultimate fallback + vmax = posData.max() if posData.size > 0 else 1. + if vmin is None: + vmin = posData.min() if posData.size > 0 else vmax + if vmin > vmax: + vmin = vmax + + norm = matplotlib.colors.LogNorm(vmin, vmax) + + else: # Linear normalization + if colormap['autoscale']: + if data is None: + vmin, vmax = None, None + else: + finiteData = data[numpy.isfinite(data)] + vmin = finiteData.min() + vmax = finiteData.max() + else: + vmin = colormap['vmin'] + vmax = colormap['vmax'] + if vmin > vmax: + _logger.warning('Colormap bounds are inverted.') + vmin, vmax = vmax, vmin + + norm = matplotlib.colors.Normalize(vmin, vmax) + + return matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap) + + +def applyColormapToData(data, + name='gray', + normalization='linear', + autoscale=True, + vmin=0., + vmax=1., + colors=None): + """Apply a colormap to the data and returns the RGBA image + + This supports data of any dimensions (not only of dimension 2). + The returned array will have one more dimension (with 4 entries) + than the input data to store the RGBA channels + corresponding to each bin in the array. + + :param numpy.ndarray data: The data to convert. + :param str name: Name of the colormap (default: 'gray'). + :param str normalization: Colormap mapping: 'linear' or 'log'. + :param bool autoscale: Whether to use data min/max (True, default) + or [vmin, vmax] range (False). + :param float vmin: The minimum value of the range to use if + 'autoscale' is False. + :param float vmax: The maximum value of the range to use if + 'autoscale' is False. + :param numpy.ndarray colors: Only used if name is None. + Custom colormap colors as Nx3 or Nx4 RGB or RGBA arrays + :return: The computed RGBA image + :rtype: numpy.ndarray of uint8 + """ + # Debian 7 specific support + # No transparent colormap with matplotlib < 1.2.0 + # Add support for transparent colormap for uint8 data with + # colormap with 256 colors, linear norm, [0, 255] range + if matplotlib.__version__ < '1.2.0': + if name is None and colors is not None: + colors = numpy.array(colors, copy=False) + if (colors.shape[-1] == 4 and + not numpy.all(numpy.equal(colors[3], 255))): + # This is a transparent colormap + if (colors.shape == (256, 4) and + normalization == 'linear' and + not autoscale and + vmin == 0 and vmax == 255 and + data.dtype == numpy.uint8): + # Supported case, convert data to RGBA + return colors[data.reshape(-1)].reshape( + data.shape + (4,)) + else: + _logger.warning( + 'matplotlib %s does not support transparent ' + 'colormap.', matplotlib.__version__) + + colormap = dict(name=name, + normalization=normalization, + autoscale=autoscale, + vmin=vmin, + vmax=vmax, + colors=colors) + scalarMappable = getMPLScalarMappable(colormap, data) + rgbaImage = scalarMappable.to_rgba(data, bytes=True) + + return rgbaImage diff --git a/silx/gui/plot/CurvesROIWidget.py b/silx/gui/plot/CurvesROIWidget.py new file mode 100644 index 0000000..13c3de0 --- /dev/null +++ b/silx/gui/plot/CurvesROIWidget.py @@ -0,0 +1,975 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Widget to handle regions of interest (ROI) on curves displayed in a PlotWindow. + +This widget is meant to work with :class:`PlotWindow`. + +ROI are defined by : + +- A name (`ROI` column) +- A type. The type is the label of the x axis. + This can be used to apply or not some ROI to a curve and do some post processing. +- The x coordinate of the left limit (`from` column) +- The x coordinate of the right limit (`to` column) +- Raw counts: integral of the curve between the + min ROI point and the max ROI point to the y = 0 line + + .. image:: img/rawCounts.png + +- Net counts: the integral of the curve between the + min ROI point and the max ROI point to [ROI min point, ROI max point] segment + + .. image:: img/netCounts.png +""" + +__authors__ = ["V.A. Sole", "T. Vincent"] +__license__ = "MIT" +__date__ = "26/04/2017" + +from collections import OrderedDict + +import logging +import os +import sys + +import numpy + +from silx.io import dictdump +from .. import icons, qt + + +_logger = logging.getLogger(__name__) + + +class CurvesROIWidget(qt.QWidget): + """Widget displaying a table of ROI information. + + :param parent: See :class:`QWidget` + :param str name: The title of this widget + """ + + sigROIWidgetSignal = qt.Signal(object) + """Signal of ROIs modifications. + + Modification information if given as a dict with an 'event' key + providing the type of events. + + Type of events: + + - AddROI, DelROI, LoadROI and ResetROI with keys: 'roilist', 'roidict' + + - selectionChanged with keys: 'row', 'col' 'roi', 'key', 'colheader', + 'rowheader' + """ + + def __init__(self, parent=None, name=None): + super(CurvesROIWidget, self).__init__(parent) + if name is not None: + self.setWindowTitle(name) + layout = qt.QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + ############## + self.headerLabel = qt.QLabel(self) + self.headerLabel.setAlignment(qt.Qt.AlignHCenter) + self.setHeader() + layout.addWidget(self.headerLabel) + ############## + self.roiTable = ROITable(self) + rheight = self.roiTable.horizontalHeader().sizeHint().height() + self.roiTable.setMinimumHeight(4 * rheight) + self.fillFromROIDict = self.roiTable.fillFromROIDict + self.getROIListAndDict = self.roiTable.getROIListAndDict + layout.addWidget(self.roiTable) + self._roiFileDir = qt.QDir.home().absolutePath() + ################# + + hbox = qt.QWidget(self) + hboxlayout = qt.QHBoxLayout(hbox) + hboxlayout.setContentsMargins(0, 0, 0, 0) + hboxlayout.setSpacing(0) + + hboxlayout.addStretch(0) + + self.addButton = qt.QPushButton(hbox) + self.addButton.setText("Add ROI") + self.addButton.setToolTip('Create a new ROI') + self.delButton = qt.QPushButton(hbox) + self.delButton.setText("Delete ROI") + self.addButton.setToolTip('Remove the selected ROI') + self.resetButton = qt.QPushButton(hbox) + self.resetButton.setText("Reset") + self.addButton.setToolTip('Clear all created ROIs. We only let the default ROI') + + hboxlayout.addWidget(self.addButton) + hboxlayout.addWidget(self.delButton) + hboxlayout.addWidget(self.resetButton) + + hboxlayout.addStretch(0) + + self.loadButton = qt.QPushButton(hbox) + self.loadButton.setText("Load") + self.loadButton.setToolTip('Load ROIs from a .ini file') + self.saveButton = qt.QPushButton(hbox) + self.saveButton.setText("Save") + self.loadButton.setToolTip('Save ROIs to a .ini file') + hboxlayout.addWidget(self.loadButton) + hboxlayout.addWidget(self.saveButton) + layout.setStretchFactor(self.headerLabel, 0) + layout.setStretchFactor(self.roiTable, 1) + layout.setStretchFactor(hbox, 0) + + layout.addWidget(hbox) + + self.addButton.clicked.connect(self._add) + self.delButton.clicked.connect(self._del) + self.resetButton.clicked.connect(self._reset) + + self.loadButton.clicked.connect(self._load) + self.saveButton.clicked.connect(self._save) + self.roiTable.sigROITableSignal.connect(self._forward) + + @property + def roiFileDir(self): + """The directory from which to load/save ROI from/to files.""" + if not os.path.isdir(self._roiFileDir): + self._roiFileDir = qt.QDir.home().absolutePath() + return self._roiFileDir + + @roiFileDir.setter + def roiFileDir(self, roiFileDir): + self._roiFileDir = str(roiFileDir) + + def setRois(self, roidict, order=None): + """Set the ROIs by providing a dictionary of ROI information. + + The dictionary keys are the ROI names. + Each value is a sub-dictionary of ROI info with the following fields: + + - ``"from"``: x coordinate of the left limit, as a float + - ``"to"``: x coordinate of the right limit, as a float + - ``"type"``: type of ROI, as a string (e.g "channels", "energy") + + + :param roidict: Dictionary of ROIs + :param str order: Field used for ordering the ROIs. + One of "from", "to", "type". + None (default) for no ordering, or same order as specified + in parameter ``roidict`` if provided as an OrderedDict. + """ + if order is None or order.lower() == "none": + roilist = list(roidict.keys()) + else: + assert order in ["from", "to", "type"] + roilist = sorted(roidict.keys(), + key=lambda roi_name: roidict[roi_name].get(order)) + + return self.roiTable.fillFromROIDict(roilist, roidict) + + def getRois(self, order=None): + """Return the currently defined ROIs, as an ordered dict. + + The dictionary keys are the ROI names. + Each value is a sub-dictionary of ROI info with the following fields: + + - ``"from"``: x coordinate of the left limit, as a float + - ``"to"``: x coordinate of the right limit, as a float + - ``"type"``: type of ROI, as a string (e.g "channels", "energy") + :param order: Field used for ordering the ROIs. + One of "from", "to", "type", "netcounts", "rawcounts". + None (default) to get the same order as displayed in the widget. + :return: Ordered dictionary of ROI information + """ + roilist, roidict = self.roiTable.getROIListAndDict() + if order is None or order.lower() == "none": + ordered_roilist = roilist + else: + assert order in ["from", "to", "type", "netcounts", "rawcounts"] + ordered_roilist = sorted(roidict.keys(), + key=lambda roi_name: roidict[roi_name].get(order)) + + return OrderedDict([(name, roidict[name]) for name in ordered_roilist]) + + def _add(self): + """Add button clicked handler""" + ddict = {} + ddict['event'] = "AddROI" + roilist, roidict = self.roiTable.getROIListAndDict() + ddict['roilist'] = roilist + ddict['roidict'] = roidict + self.sigROIWidgetSignal.emit(ddict) + + def _del(self): + """Delete button clicked handler""" + row = self.roiTable.currentRow() + if row >= 0: + index = self.roiTable.labels.index('Type') + text = str(self.roiTable.item(row, index).text()) + if text.upper() != 'DEFAULT': + index = self.roiTable.labels.index('ROI') + key = str(self.roiTable.item(row, index).text()) + else: + # This is to prevent deleting ICR ROI, that is + # usually initialized as "Default" type. + return + roilist, roidict = self.roiTable.getROIListAndDict() + row = roilist.index(key) + del roilist[row] + del roidict[key] + if len(roilist) > 0: + currentroi = roilist[0] + else: + currentroi = None + + self.roiTable.fillFromROIDict(roilist=roilist, + roidict=roidict, + currentroi=currentroi) + ddict = {} + ddict['event'] = "DelROI" + ddict['roilist'] = roilist + ddict['roidict'] = roidict + self.sigROIWidgetSignal.emit(ddict) + + def _forward(self, ddict): + """Broadcast events from ROITable signal""" + self.sigROIWidgetSignal.emit(ddict) + + def _reset(self): + """Reset button clicked handler""" + ddict = {} + ddict['event'] = "ResetROI" + roilist0, roidict0 = self.roiTable.getROIListAndDict() + index = 0 + for key in roilist0: + if roidict0[key]['type'].upper() == 'DEFAULT': + index = roilist0.index(key) + break + roilist = [] + roidict = {} + if len(roilist0): + roilist.append(roilist0[index]) + roidict[roilist[0]] = {} + roidict[roilist[0]].update(roidict0[roilist[0]]) + self.roiTable.fillFromROIDict(roilist=roilist, roidict=roidict) + ddict['roilist'] = roilist + ddict['roidict'] = roidict + self.sigROIWidgetSignal.emit(ddict) + + def _load(self): + """Load button clicked handler""" + dialog = qt.QFileDialog(self) + dialog.setNameFilters( + ['INI File *.ini', 'JSON File *.json', 'All *.*']) + dialog.setFileMode(qt.QFileDialog.ExistingFile) + dialog.setDirectory(self.roiFileDir) + if not dialog.exec_(): + dialog.close() + return + + # pyflakes bug http://bugs.debian.org/cgi-bin/bugreport.cgi?bug=666494 + outputFile = dialog.selectedFiles()[0] + dialog.close() + + self.roiFileDir = os.path.dirname(outputFile) + self.load(outputFile) + + def load(self, filename): + """Load ROI widget information from a file storing a dict of ROI. + + :param str filename: The file from which to load ROI + """ + rois = dictdump.load(filename) + currentROI = None + if self.roiTable.rowCount(): + item = self.roiTable.item(self.roiTable.currentRow(), 0) + if item is not None: + currentROI = str(item.text()) + + # Remove rawcounts and netcounts from ROIs + for roi in rois['ROI']['roidict'].values(): + roi.pop('rawcounts', None) + roi.pop('netcounts', None) + + self.roiTable.fillFromROIDict(roilist=rois['ROI']['roilist'], + roidict=rois['ROI']['roidict'], + currentroi=currentROI) + + roilist, roidict = self.roiTable.getROIListAndDict() + event = {'event': 'LoadROI', 'roilist': roilist, 'roidict': roidict} + self.sigROIWidgetSignal.emit(event) + + def _save(self): + """Save button clicked handler""" + dialog = qt.QFileDialog(self) + dialog.setNameFilters(['INI File *.ini', 'JSON File *.json']) + dialog.setFileMode(qt.QFileDialog.AnyFile) + dialog.setAcceptMode(qt.QFileDialog.AcceptSave) + dialog.setDirectory(self.roiFileDir) + if not dialog.exec_(): + dialog.close() + return + + outputFile = dialog.selectedFiles()[0] + extension = '.' + dialog.selectedNameFilter().split('.')[-1] + dialog.close() + + if not outputFile.endswith(extension): + outputFile += extension + + if os.path.exists(outputFile): + try: + os.remove(outputFile) + except IOError: + msg = qt.QMessageBox(self) + msg.setIcon(qt.QMessageBox.Critical) + msg.setText("Input Output Error: %s" % (sys.exc_info()[1])) + msg.exec_() + return + self.roiFileDir = os.path.dirname(outputFile) + self.save(outputFile) + + def save(self, filename): + """Save current ROIs of the widget as a dict of ROI to a file. + + :param str filename: The file to which to save the ROIs + """ + roilist, roidict = self.roiTable.getROIListAndDict() + datadict = {'ROI': {'roilist': roilist, 'roidict': roidict}} + dictdump.dump(datadict, filename) + + def setHeader(self, text='ROIs'): + """Set the header text of this widget""" + self.headerLabel.setText("<b>%s<\b>" % text) + + +class ROITable(qt.QTableWidget): + """Table widget displaying ROI information. + + See :class:`QTableWidget` for constructor arguments. + """ + + sigROITableSignal = qt.Signal(object) + """Signal of ROI table modifications. + """ + + def __init__(self, *args, **kwargs): + super(ROITable, self).__init__(*args, **kwargs) + self.setRowCount(1) + self.labels = 'ROI', 'Type', 'From', 'To', 'Raw Counts', 'Net Counts' + self.setColumnCount(len(self.labels)) + self.setSortingEnabled(False) + + for index, label in enumerate(self.labels): + item = self.horizontalHeaderItem(index) + if item is None: + item = qt.QTableWidgetItem(label, + qt.QTableWidgetItem.Type) + item.setText(label) + self.setHorizontalHeaderItem(index, item) + + self.roidict = {} + self.roilist = [] + + self.building = False + self.fillFromROIDict(roilist=self.roilist, roidict=self.roidict) + + self.cellClicked[(int, int)].connect(self._cellClickedSlot) + self.cellChanged[(int, int)].connect(self._cellChangedSlot) + verticalHeader = self.verticalHeader() + verticalHeader.sectionClicked[int].connect(self._rowChangedSlot) + + self.__setTooltip() + + def __setTooltip(self): + assert(self.labels[0] == 'ROI') + self.horizontalHeaderItem(0).setToolTip('Region of interest identifier') + assert(self.labels[1] == 'Type') + self.horizontalHeaderItem(1).setToolTip('Type of the ROI') + assert(self.labels[2] == 'From') + self.horizontalHeaderItem(2).setToolTip('X-value of the min point') + assert(self.labels[3] == 'To') + self.horizontalHeaderItem(3).setToolTip('X-value of the max point') + assert(self.labels[4] == 'Raw Counts') + self.horizontalHeaderItem(4).setToolTip('Estimation of the integral \ + between y=0 and the selected curve') + assert(self.labels[5] == 'Net Counts') + self.horizontalHeaderItem(5).setToolTip('Estimation of the integral \ + between the segment [maxPt, minPt] and the selected curve') + + def fillFromROIDict(self, roilist=(), roidict=None, currentroi=None): + """Set the ROIs by providing a list of ROI names and a dictionary + of ROI information for each ROI. + + The ROI names must match an existing dictionary key. + The name list is used to provide an order for the ROIs. + + The dictionary's values are sub-dictionaries containing 3 + mandatory fields: + + - ``"from"``: x coordinate of the left limit, as a float + - ``"to"``: x coordinate of the right limit, as a float + - ``"type"``: type of ROI, as a string (e.g "channels", "energy") + + :param roilist: List of ROI names (keys of roidict) + :type roilist: List + :param dict roidict: Dict of ROI information + :param currentroi: Name of the selected ROI or None (no selection) + """ + if roidict is None: + roidict = {} + + self.building = True + line0 = 0 + self.roilist = [] + self.roidict = {} + for key in roilist: + if key in roidict.keys(): + roi = roidict[key] + self.roilist.append(key) + self.roidict[key] = {} + self.roidict[key].update(roi) + line0 = line0 + 1 + nlines = self.rowCount() + if (line0 > nlines): + self.setRowCount(line0) + line = line0 - 1 + self.roidict[key]['line'] = line + ROI = key + roitype = "%s" % roi['type'] + fromdata = "%6g" % (roi['from']) + todata = "%6g" % (roi['to']) + if 'rawcounts' in roi: + rawcounts = "%6g" % (roi['rawcounts']) + else: + rawcounts = " ?????? " + if 'netcounts' in roi: + netcounts = "%6g" % (roi['netcounts']) + else: + netcounts = " ?????? " + fields = [ROI, roitype, fromdata, todata, rawcounts, netcounts] + col = 0 + for field in fields: + key2 = self.item(line, col) + if key2 is None: + key2 = qt.QTableWidgetItem(field, + qt.QTableWidgetItem.Type) + self.setItem(line, col, key2) + else: + key2.setText(field) + if (ROI.upper() == 'ICR') or (ROI.upper() == 'DEFAULT'): + key2.setFlags(qt.Qt.ItemIsSelectable | + qt.Qt.ItemIsEnabled) + else: + if col in [0, 2, 3]: + key2.setFlags(qt.Qt.ItemIsSelectable | + qt.Qt.ItemIsEnabled | + qt.Qt.ItemIsEditable) + else: + key2.setFlags(qt.Qt.ItemIsSelectable | + qt.Qt.ItemIsEnabled) + col = col + 1 + self.setRowCount(line0) + i = 0 + for _label in self.labels: + self.resizeColumnToContents(i) + i = i + 1 + self.sortByColumn(2, qt.Qt.AscendingOrder) + for i in range(len(self.roilist)): + key = str(self.item(i, 0).text()) + self.roilist[i] = key + self.roidict[key]['line'] = i + if len(self.roilist) == 1: + self.selectRow(0) + else: + if currentroi in self.roidict.keys(): + self.selectRow(self.roidict[currentroi]['line']) + _logger.debug("Qt4 ensureCellVisible to be implemented") + self.building = False + + def getROIListAndDict(self): + """Return the currently defined ROIs, as a 2-tuple + ``(roiList, roiDict)`` + + ``roiList`` is a list of ROI names. + ``roiDict`` is a dictionary of ROI info. + + The ROI names must match an existing dictionary key. + The name list is used to provide an order for the ROIs. + + The dictionary's values are sub-dictionaries containing 3 + fields: + + - ``"from"``: x coordinate of the left limit, as a float + - ``"to"``: x coordinate of the right limit, as a float + - ``"type"``: type of ROI, as a string (e.g "channels", "energy") + + + :return: ordered dict as a tuple of (list of ROI names, dict of info) + """ + return self.roilist, self.roidict + + def _cellClickedSlot(self, *var, **kw): + # selection changed event, get the current selection + row = self.currentRow() + col = self.currentColumn() + if row >= 0 and row < len(self.roilist): + item = self.item(row, 0) + text = '' if item is None else str(item.text()) + self.roilist[row] = text + self._emitSelectionChangedSignal(row, col) + + def _rowChangedSlot(self, row): + self._emitSelectionChangedSignal(row, 0) + + def _cellChangedSlot(self, row, col): + _logger.debug("_cellChangedSlot(%d, %d)", row, col) + if self.building: + return + if col == 0: + self.nameSlot(row, col) + else: + self._valueChanged(row, col) + + def _valueChanged(self, row, col): + if col not in [2, 3]: + return + item = self.item(row, col) + if item is None: + return + text = str(item.text()) + try: + value = float(text) + except: + return + if row >= len(self.roilist): + _logger.debug("deleting???") + return + item = self.item(row, 0) + if item is None: + text = "" + else: + text = str(item.text()) + if not len(text): + return + if col == 2: + self.roidict[text]['from'] = value + elif col == 3: + self.roidict[text]['to'] = value + self._emitSelectionChangedSignal(row, col) + + def nameSlot(self, row, col): + if col != 0: + return + if row >= len(self.roilist): + _logger.debug("deleting???") + return + item = self.item(row, col) + if item is None: + text = "" + else: + text = str(item.text()) + if len(text) and (text not in self.roilist): + old = self.roilist[row] + self.roilist[row] = text + self.roidict[text] = {} + self.roidict[text].update(self.roidict[old]) + del self.roidict[old] + self._emitSelectionChangedSignal(row, col) + + def _emitSelectionChangedSignal(self, row, col): + ddict = {} + ddict['event'] = "selectionChanged" + ddict['row'] = row + ddict['col'] = col + ddict['roi'] = self.roidict[self.roilist[row]] + ddict['key'] = self.roilist[row] + ddict['colheader'] = self.labels[col] + ddict['rowheader'] = "%d" % row + self.sigROITableSignal.emit(ddict) + + +class CurvesROIDockWidget(qt.QDockWidget): + """QDockWidget with a :class:`CurvesROIWidget` connected to a PlotWindow. + + It makes the link between the :class:`CurvesROIWidget` and the PlotWindow. + + :param parent: See :class:`QDockWidget` + :param plot: :class:`.PlotWindow` instance on which to operate + :param name: See :class:`QDockWidget` + """ + sigROISignal = qt.Signal(object) + + def __init__(self, parent=None, plot=None, name=None): + super(CurvesROIDockWidget, self).__init__(name, parent) + + assert plot is not None + self.plot = plot + + self.currentROI = None + self._middleROIMarkerFlag = False + + self._isConnected = False # True if connected to plot signals + self._isInit = False + + self.roiWidget = CurvesROIWidget(self, name) + """Main widget of type :class:`CurvesROIWidget`""" + + # convenience methods to offer a simpler API allowing to ignore + # the details of the underlying implementation + self.calculateROIs = self.calculateRois + self.setRois = self.roiWidget.setRois + self.getRois = self.roiWidget.getRois + + self.layout().setContentsMargins(0, 0, 0, 0) + self.setWidget(self.roiWidget) + + self.visibilityChanged.connect(self._visibilityChangedHandler) + + def toggleViewAction(self): + """Returns a checkable action that shows or closes this widget. + + See :class:`QMainWindow`. + """ + action = super(CurvesROIDockWidget, self).toggleViewAction() + action.setIcon(icons.getQIcon('plot-roi')) + return action + + def _visibilityChangedHandler(self, visible): + """Handle widget's visibilty updates. + + It is connected to plot signals only when visible. + """ + if visible: + if not self._isInit: + # Deferred ROI widget init finalization + self._isInit = True + self.roiWidget.sigROIWidgetSignal.connect(self._roiSignal) + # initialize with the ICR + self._roiSignal({'event': "AddROI"}) + + if not self._isConnected: + self.plot.sigPlotSignal.connect(self._handleROIMarkerEvent) + self.plot.sigActiveCurveChanged.connect( + self._activeCurveChanged) + self._isConnected = True + + self.calculateROIs() + else: + if self._isConnected: + self.plot.sigPlotSignal.disconnect(self._handleROIMarkerEvent) + self.plot.sigActiveCurveChanged.disconnect( + self._activeCurveChanged) + self._isConnected = False + + def _handleROIMarkerEvent(self, ddict): + """Handle plot signals related to marker events.""" + if ddict['event'] == 'markerMoved': + + label = ddict['label'] + if label not in ['ROI min', 'ROI max', 'ROI middle']: + return + + roiList, roiDict = self.roiWidget.getROIListAndDict() + if self.currentROI is None: + return + if self.currentROI not in roiDict: + return + x = ddict['x'] + + if label == 'ROI min': + roiDict[self.currentROI]['from'] = x + if self._middleROIMarkerFlag: + pos = 0.5 * (roiDict[self.currentROI]['to'] + + roiDict[self.currentROI]['from']) + self.plot.addXMarker(pos, + legend='ROI middle', + text='', + color='yellow', + draggable=True) + elif label == 'ROI max': + roiDict[self.currentROI]['to'] = x + if self._middleROIMarkerFlag: + pos = 0.5 * (roiDict[self.currentROI]['to'] + + roiDict[self.currentROI]['from']) + self.plot.addXMarker(pos, + legend='ROI middle', + text='', + color='yellow', + draggable=True) + elif label == 'ROI middle': + delta = x - 0.5 * (roiDict[self.currentROI]['from'] + + roiDict[self.currentROI]['to']) + roiDict[self.currentROI]['from'] += delta + roiDict[self.currentROI]['to'] += delta + self.plot.addXMarker(roiDict[self.currentROI]['from'], + legend='ROI min', + text='ROI min', + color='blue', + draggable=True) + self.plot.addXMarker(roiDict[self.currentROI]['to'], + legend='ROI max', + text='ROI max', + color='blue', + draggable=True) + else: + return + self.calculateROIs(roiList, roiDict) + self._emitCurrentROISignal() + + def _roiSignal(self, ddict): + """Handle ROI widget signal""" + _logger.debug("PlotWindow._roiSignal %s", str(ddict)) + if ddict['event'] == "AddROI": + xmin, xmax = self.plot.getGraphXLimits() + fromdata = xmin + 0.25 * (xmax - xmin) + todata = xmin + 0.75 * (xmax - xmin) + self.plot.remove('ROI min', kind='marker') + self.plot.remove('ROI max', kind='marker') + if self._middleROIMarkerFlag: + self.remove('ROI middle', kind='marker') + roiList, roiDict = self.roiWidget.getROIListAndDict() + nrois = len(roiList) + if nrois == 0: + newroi = "ICR" + fromdata, dummy0, todata, dummy1 = self._getAllLimits() + draggable = False + color = 'black' + else: + for i in range(nrois): + i += 1 + newroi = "newroi %d" % i + if newroi not in roiList: + break + color = 'blue' + draggable = True + self.plot.addXMarker(fromdata, + legend='ROI min', + text='ROI min', + color=color, + draggable=draggable) + self.plot.addXMarker(todata, + legend='ROI max', + text='ROI max', + color=color, + draggable=draggable) + if draggable and self._middleROIMarkerFlag: + pos = 0.5 * (fromdata + todata) + self.plot.addXMarker(pos, + legend='ROI middle', + text="", + color='yellow', + draggable=draggable) + roiList.append(newroi) + roiDict[newroi] = {} + if newroi == "ICR": + roiDict[newroi]['type'] = "Default" + else: + roiDict[newroi]['type'] = self.plot.getGraphXLabel() + roiDict[newroi]['from'] = fromdata + roiDict[newroi]['to'] = todata + self.roiWidget.fillFromROIDict(roilist=roiList, + roidict=roiDict, + currentroi=newroi) + self.currentROI = newroi + self.calculateROIs() + elif ddict['event'] in ['DelROI', "ResetROI"]: + self.plot.remove('ROI min', kind='marker') + self.plot.remove('ROI max', kind='marker') + if self._middleROIMarkerFlag: + self.plot.remove('ROI middle', kind='marker') + roiList, roiDict = self.roiWidget.getROIListAndDict() + roiDictKeys = list(roiDict.keys()) + if len(roiDictKeys): + currentroi = roiDictKeys[0] + else: + # create again the ICR + ddict = {"event": "AddROI"} + return self._roiSignal(ddict) + + self.roiWidget.fillFromROIDict(roilist=roiList, + roidict=roiDict, + currentroi=currentroi) + self.currentROI = currentroi + + elif ddict['event'] == 'LoadROI': + self.calculateROIs() + + elif ddict['event'] == 'selectionChanged': + _logger.debug("Selection changed") + self.roilist, self.roidict = self.roiWidget.getROIListAndDict() + fromdata = ddict['roi']['from'] + todata = ddict['roi']['to'] + self.plot.remove('ROI min', kind='marker') + self.plot.remove('ROI max', kind='marker') + if ddict['key'] == 'ICR': + draggable = False + color = 'black' + else: + draggable = True + color = 'blue' + self.plot.addXMarker(fromdata, + legend='ROI min', + text='ROI min', + color=color, + draggable=draggable) + self.plot.addXMarker(todata, + legend='ROI max', + text='ROI max', + color=color, + draggable=draggable) + if draggable and self._middleROIMarkerFlag: + pos = 0.5 * (fromdata + todata) + self.plot.addXMarker(pos, + legend='ROI middle', + text="", + color='yellow', + draggable=True) + self.currentROI = ddict['key'] + if ddict['colheader'] in ['From', 'To']: + dict0 = {} + dict0['event'] = "SetActiveCurveEvent" + dict0['legend'] = self.plot.getActiveCurve(just_legend=1) + self.plot.setActiveCurve(dict0['legend']) + elif ddict['colheader'] == 'Raw Counts': + pass + elif ddict['colheader'] == 'Net Counts': + pass + else: + self._emitCurrentROISignal() + + else: + _logger.debug("Unknown or ignored event %s", ddict['event']) + + def _activeCurveChanged(self, *args): + """Recompute ROIs when active curve changed.""" + self.calculateROIs() + + def calculateRois(self, roiList=None, roiDict=None): + """Compute ROI information""" + if roiList is None or roiDict is None: + roiList, roiDict = self.roiWidget.getROIListAndDict() + + activeCurve = self.plot.getActiveCurve(just_legend=False) + if activeCurve is None: + xproc = None + yproc = None + self.roiWidget.setHeader() + else: + x = activeCurve.getXData(copy=False) + y = activeCurve.getYData(copy=False) + legend = activeCurve.getLegend() + idx = numpy.argsort(x, kind='mergesort') + xproc = numpy.take(x, idx) + yproc = numpy.take(y, idx) + self.roiWidget.setHeader('ROIs of %s' % legend) + + for key in roiList: + if key == 'ICR': + if xproc is not None: + roiDict[key]['from'] = xproc.min() + roiDict[key]['to'] = xproc.max() + else: + roiDict[key]['from'] = 0 + roiDict[key]['to'] = -1 + fromData = roiDict[key]['from'] + toData = roiDict[key]['to'] + if xproc is not None: + idx = numpy.nonzero((fromData <= xproc) & + (xproc <= toData))[0] + if len(idx): + xw = xproc[idx] + yw = yproc[idx] + rawCounts = yw.sum(dtype=numpy.float) + deltaX = xw[-1] - xw[0] + deltaY = yw[-1] - yw[0] + if deltaX > 0.0: + slope = (deltaY / deltaX) + background = yw[0] + slope * (xw - xw[0]) + netCounts = (rawCounts - + background.sum(dtype=numpy.float)) + else: + netCounts = 0.0 + else: + rawCounts = 0.0 + netCounts = 0.0 + roiDict[key]['rawcounts'] = rawCounts + roiDict[key]['netcounts'] = netCounts + else: + roiDict[key].pop('rawcounts', None) + roiDict[key].pop('netcounts', None) + + self.roiWidget.fillFromROIDict( + roilist=roiList, + roidict=roiDict, + currentroi=self.currentROI if self.currentROI in roiList else None) + + def _emitCurrentROISignal(self): + ddict = {} + ddict['event'] = "currentROISignal" + _roiList, roiDict = self.roiWidget.getROIListAndDict() + if self.currentROI in roiDict: + ddict['ROI'] = roiDict[self.currentROI] + else: + self.currentROI = None + ddict['current'] = self.currentROI + self.sigROISignal.emit(ddict) + + def _getAllLimits(self): + """Retrieve the limits based on the curves.""" + curves = self.plot.getAllCurves() + if not curves: + return 1.0, 1.0, 100., 100. + + xmin, ymin = None, None + xmax, ymax = None, None + + for curve in curves: + x = curve.getXData(copy=False) + y = curve.getYData(copy=False) + if xmin is None: + xmin = x.min() + else: + xmin = min(xmin, x.min()) + if xmax is None: + xmax = x.max() + else: + xmax = max(xmax, x.max()) + if ymin is None: + ymin = y.min() + else: + ymin = min(ymin, y.min()) + if ymax is None: + ymax = y.max() + else: + ymax = max(ymax, y.max()) + + return xmin, ymin, xmax, ymax + + def showEvent(self, event): + """Make sure this widget is raised when it is shown + (when it is first created as a tab in PlotWindow or when it is shown + again after hiding). + """ + self.raise_() diff --git a/silx/gui/plot/ImageView.py b/silx/gui/plot/ImageView.py new file mode 100644 index 0000000..780215e --- /dev/null +++ b/silx/gui/plot/ImageView.py @@ -0,0 +1,860 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""QWidget displaying a 2D image with histograms on its sides. + +The :class:`ImageView` implements this widget, and +:class:`ImageViewMainWindow` provides a main window with additional toolbar +and status bar. + +Basic usage of :class:`ImageView` is through the following methods: + +- :meth:`ImageView.getColormap`, :meth:`ImageView.setColormap` to update the + default colormap to use and update the currently displayed image. +- :meth:`ImageView.setImage` to update the displayed image. + +The :class:`ImageView` uses :class:`PlotWindow` and also +exposes :class:`silx.gui.plot.Plot` API for further control +(plot title, axes labels, adding other images, ...). + +For an example of use, see the implementation of :class:`ImageViewMainWindow`, +and `example/imageview.py`. +""" + +from __future__ import division + + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "13/10/2016" + + +import logging +import numpy + +from .. import qt + +from . import items, PlotWindow, PlotWidget, PlotActions +from .Colors import cursorColorForColormap +from .PlotTools import LimitsToolBar +from .Profile import ProfileToolBar + + +_logger = logging.getLogger(__name__) + + +# RadarView ################################################################### + +class RadarView(qt.QGraphicsView): + """Widget presenting a synthetic view of a 2D area and + the current visible area. + + Coordinates are as in QGraphicsView: + x goes from left to right and y goes from top to bottom. + This widget preserves the aspect ratio of the areas. + + The 2D area and the visible area can be set with :meth:`setDataRect` + and :meth:`setVisibleRect`. + When the visible area has been dragged by the user, its new position + is signaled by the *visibleRectDragged* signal. + + It is possible to invert the direction of the axes by using the + :meth:`scale` method of QGraphicsView. + """ + + visibleRectDragged = qt.Signal(float, float, float, float) + """Signals that the visible rectangle has been dragged. + + It provides: left, top, width, height in data coordinates. + """ + + _DATA_PEN = qt.QPen(qt.QColor('white')) + _DATA_BRUSH = qt.QBrush(qt.QColor('light gray')) + _VISIBLE_PEN = qt.QPen(qt.QColor('red')) + _VISIBLE_PEN.setWidth(2) + _VISIBLE_PEN.setCosmetic(True) + _VISIBLE_BRUSH = qt.QBrush(qt.QColor(0, 0, 0, 0)) + _TOOLTIP = 'Radar View:\nRed contour: Visible area\nGray area: The image' + + _PIXMAP_SIZE = 256 + + class _DraggableRectItem(qt.QGraphicsRectItem): + """RectItem which signals its change through visibleRectDragged.""" + def __init__(self, *args, **kwargs): + super(RadarView._DraggableRectItem, self).__init__( + *args, **kwargs) + + self._previousCursor = None + self.setFlag(qt.QGraphicsItem.ItemIsMovable) + self.setFlag(qt.QGraphicsItem.ItemSendsGeometryChanges) + self.setAcceptHoverEvents(True) + self._ignoreChange = False + self._constraint = 0, 0, 0, 0 + + def setConstraintRect(self, left, top, width, height): + """Set the constraint rectangle for dragging. + + The coordinates are in the _DraggableRectItem coordinate system. + + This constraint only applies to modification through interaction + (i.e., this constraint is not applied to change through API). + + If the _DraggableRectItem is smaller than the constraint rectangle, + the _DraggableRectItem remains within the constraint rectangle. + If the _DraggableRectItem is wider than the constraint rectangle, + the constraint rectangle remains within the _DraggableRectItem. + """ + self._constraint = left, left + width, top, top + height + + def setPos(self, *args, **kwargs): + """Overridden to ignore changes from API in itemChange.""" + self._ignoreChange = True + super(RadarView._DraggableRectItem, self).setPos(*args, **kwargs) + self._ignoreChange = False + + def moveBy(self, *args, **kwargs): + """Overridden to ignore changes from API in itemChange.""" + self._ignoreChange = True + super(RadarView._DraggableRectItem, self).moveBy(*args, **kwargs) + self._ignoreChange = False + + def itemChange(self, change, value): + """Callback called before applying changes to the item.""" + if (change == qt.QGraphicsItem.ItemPositionChange and + not self._ignoreChange): + # Makes sure that the visible area is in the data + # or that data is in the visible area if area is too wide + x, y = value.x(), value.y() + xMin, xMax, yMin, yMax = self._constraint + + if self.rect().width() <= (xMax - xMin): + if x < xMin: + value.setX(xMin) + elif x > xMax - self.rect().width(): + value.setX(xMax - self.rect().width()) + else: + if x > xMin: + value.setX(xMin) + elif x < xMax - self.rect().width(): + value.setX(xMax - self.rect().width()) + + if self.rect().height() <= (yMax - yMin): + if y < yMin: + value.setY(yMin) + elif y > yMax - self.rect().height(): + value.setY(yMax - self.rect().height()) + else: + if y > yMin: + value.setY(yMin) + elif y < yMax - self.rect().height(): + value.setY(yMax - self.rect().height()) + + if self.pos() != value: + # Notify change through signal + views = self.scene().views() + assert len(views) == 1 + views[0].visibleRectDragged.emit( + value.x() + self.rect().left(), + value.y() + self.rect().top(), + self.rect().width(), + self.rect().height()) + + return value + + return super(RadarView._DraggableRectItem, self).itemChange( + change, value) + + def hoverEnterEvent(self, event): + """Called when the mouse enters the rectangle area""" + self._previousCursor = self.cursor() + self.setCursor(qt.Qt.OpenHandCursor) + + def hoverLeaveEvent(self, event): + """Called when the mouse leaves the rectangle area""" + if self._previousCursor is not None: + self.setCursor(self._previousCursor) + self._previousCursor = None + + def __init__(self, parent=None): + self._scene = qt.QGraphicsScene() + self._dataRect = self._scene.addRect(0, 0, 1, 1, + self._DATA_PEN, + self._DATA_BRUSH) + self._visibleRect = self._DraggableRectItem(0, 0, 1, 1) + self._visibleRect.setPen(self._VISIBLE_PEN) + self._visibleRect.setBrush(self._VISIBLE_BRUSH) + self._scene.addItem(self._visibleRect) + + super(RadarView, self).__init__(self._scene, parent) + self.setHorizontalScrollBarPolicy(qt.Qt.ScrollBarAlwaysOff) + self.setVerticalScrollBarPolicy(qt.Qt.ScrollBarAlwaysOff) + self.setFocusPolicy(qt.Qt.NoFocus) + self.setStyleSheet('border: 0px') + self.setToolTip(self._TOOLTIP) + + def sizeHint(self): + # """Overridden to avoid sizeHint to depend on content size.""" + return self.minimumSizeHint() + + def wheelEvent(self, event): + # """Overridden to disable vertical scrolling with wheel.""" + event.ignore() + + def resizeEvent(self, event): + # """Overridden to fit current content to new size.""" + self.fitInView(self._scene.itemsBoundingRect(), qt.Qt.KeepAspectRatio) + super(RadarView, self).resizeEvent(event) + + def setDataRect(self, left, top, width, height): + """Set the bounds of the data rectangular area. + + This sets the coordinate system. + """ + self._dataRect.setRect(left, top, width, height) + self._visibleRect.setConstraintRect(left, top, width, height) + self.fitInView(self._scene.itemsBoundingRect(), qt.Qt.KeepAspectRatio) + + def setVisibleRect(self, left, top, width, height): + """Set the visible rectangular area. + + The coordinates are relative to the data rect. + """ + self._visibleRect.setRect(0, 0, width, height) + self._visibleRect.setPos(left, top) + self.fitInView(self._scene.itemsBoundingRect(), qt.Qt.KeepAspectRatio) + + +# ImageView ################################################################### + +class ImageView(PlotWindow): + """Display a single image with horizontal and vertical histograms. + + Use :meth:`setImage` to control the displayed image. + This class also provides the :class:`silx.gui.plot.Plot` API. + + :param parent: The parent of this widget or None. + :param backend: The backend to use for the plot (default: matplotlib). + See :class:`.Plot` for the list of supported backend. + :type backend: str or :class:`BackendBase.BackendBase` + """ + + HISTOGRAMS_COLOR = 'blue' + """Color to use for the side histograms.""" + + HISTOGRAMS_HEIGHT = 200 + """Height in pixels of the side histograms.""" + + IMAGE_MIN_SIZE = 200 + """Minimum size in pixels of the image area.""" + + # Qt signals + valueChanged = qt.Signal(float, float, float) + """Signals that the data value under the cursor has changed. + + It provides: row, column, data value. + + When the cursor is over an histogram, either row or column is Nan + and the provided data value is the histogram value + (i.e., the sum along the corresponding row/column). + Row and columns are either Nan or integer values. + """ + + def __init__(self, parent=None, backend=None): + self._imageLegend = '__ImageView__image' + str(id(self)) + self._cache = None # Store currently visible data information + self._updatingLimits = False + + super(ImageView, self).__init__(parent=parent, backend=backend, + resetzoom=True, autoScale=False, + logScale=False, grid=False, + curveStyle=False, colormap=True, + aspectRatio=True, yInverted=True, + copy=True, save=True, print_=True, + control=False, position=False, + roi=False, mask=True) + if parent is None: + self.setWindowTitle('ImageView') + + self._initWidgets(backend) + + self.profile = ProfileToolBar(plot=self) + """"Profile tools attached to this plot. + + See :class:`silx.gui.plot.PlotTools.ProfileToolBar` + """ + + self.addToolBar(self.profile) + + # Sync PlotBackend and ImageView + self._updateYAxisInverted() + + def _initWidgets(self, backend): + """Set-up layout and plots.""" + # Monkey-patch for histogram size + # alternative: create a layout that does not use widget size hints + def sizeHint(): + return qt.QSize(self.HISTOGRAMS_HEIGHT, self.HISTOGRAMS_HEIGHT) + + self._histoHPlot = PlotWidget(backend=backend) + self._histoHPlot.setInteractiveMode('zoom') + self._histoHPlot.setCallback(self._histoHPlotCB) + self._histoHPlot.getWidgetHandle().sizeHint = sizeHint + self._histoHPlot.getWidgetHandle().minimumSizeHint = sizeHint + + self.setPanWithArrowKeys(True) + + self.setInteractiveMode('zoom') # Color set in setColormap + self.sigPlotSignal.connect(self._imagePlotCB) + self.sigSetYAxisInverted.connect(self._updateYAxisInverted) + self.sigActiveImageChanged.connect(self._activeImageChangedSlot) + + self._histoVPlot = PlotWidget(backend=backend) + self._histoVPlot.setInteractiveMode('zoom') + self._histoVPlot.setCallback(self._histoVPlotCB) + self._histoVPlot.getWidgetHandle().sizeHint = sizeHint + self._histoVPlot.getWidgetHandle().minimumSizeHint = sizeHint + + self._radarView = RadarView() + self._radarView.visibleRectDragged.connect(self._radarViewCB) + + self._layout = qt.QGridLayout() + self._layout.addWidget(self.getWidgetHandle(), 0, 0) + self._layout.addWidget(self._histoVPlot.getWidgetHandle(), 0, 1) + self._layout.addWidget(self._histoHPlot.getWidgetHandle(), 1, 0) + self._layout.addWidget(self._radarView, 1, 1) + + self._layout.setColumnMinimumWidth(0, self.IMAGE_MIN_SIZE) + self._layout.setColumnStretch(0, 1) + self._layout.setColumnMinimumWidth(1, self.HISTOGRAMS_HEIGHT) + self._layout.setColumnStretch(1, 0) + + self._layout.setRowMinimumHeight(0, self.IMAGE_MIN_SIZE) + self._layout.setRowStretch(0, 1) + self._layout.setRowMinimumHeight(1, self.HISTOGRAMS_HEIGHT) + self._layout.setRowStretch(1, 0) + + self._layout.setSpacing(0) + self._layout.setContentsMargins(0, 0, 0, 0) + + centralWidget = qt.QWidget() + centralWidget.setLayout(self._layout) + self.setCentralWidget(centralWidget) + + def _dirtyCache(self): + self._cache = None + + def _updateHistograms(self): + """Update histograms content using current active image.""" + activeImage = self.getActiveImage() + if activeImage is not None: + wasUpdatingLimits = self._updatingLimits + self._updatingLimits = True + + data = activeImage.getData(copy=False) + origin = activeImage.getOrigin() + scale = activeImage.getScale() + height, width = data.shape + + xMin, xMax = self.getGraphXLimits() + yMin, yMax = self.getGraphYLimits() + + # Convert plot area limits to image coordinates + # and work in image coordinates (i.e., in pixels) + xMin = int((xMin - origin[0]) / scale[0]) + xMax = int((xMax - origin[0]) / scale[0]) + yMin = int((yMin - origin[1]) / scale[1]) + yMax = int((yMax - origin[1]) / scale[1]) + + if (xMin < width and xMax >= 0 and + yMin < height and yMax >= 0): + # The image is at least partly in the plot area + # Get the visible bounds in image coords (i.e., in pixels) + subsetXMin = 0 if xMin < 0 else xMin + subsetXMax = (width if xMax >= width else xMax) + 1 + subsetYMin = 0 if yMin < 0 else yMin + subsetYMax = (height if yMax >= height else yMax) + 1 + + if (self._cache is None or + subsetXMin != self._cache['dataXMin'] or + subsetXMax != self._cache['dataXMax'] or + subsetYMin != self._cache['dataYMin'] or + subsetYMax != self._cache['dataYMax']): + # The visible area of data has changed, update histograms + + # Rebuild histograms for visible area + visibleData = data[subsetYMin:subsetYMax, + subsetXMin:subsetXMax] + histoHVisibleData = numpy.sum(visibleData, axis=0) + histoVVisibleData = numpy.sum(visibleData, axis=1) + + self._cache = { + 'dataXMin': subsetXMin, + 'dataXMax': subsetXMax, + 'dataYMin': subsetYMin, + 'dataYMax': subsetYMax, + + 'histoH': histoHVisibleData, + 'histoHMin': numpy.min(histoHVisibleData), + 'histoHMax': numpy.max(histoHVisibleData), + + 'histoV': histoVVisibleData, + 'histoVMin': numpy.min(histoVVisibleData), + 'histoVMax': numpy.max(histoVVisibleData) + } + + # Convert to histogram curve and update plots + # Taking into account origin and scale + coords = numpy.arange(2 * histoHVisibleData.size) + xCoords = (coords + 1) // 2 + subsetXMin + xCoords = origin[0] + scale[0] * xCoords + xData = numpy.take(histoHVisibleData, coords // 2) + self._histoHPlot.addCurve(xCoords, xData, + xlabel='', ylabel='', + replace=False, + color=self.HISTOGRAMS_COLOR, + linestyle='-', + selectable=False) + vMin = self._cache['histoHMin'] + vMax = self._cache['histoHMax'] + vOffset = 0.1 * (vMax - vMin) + if vOffset == 0.: + vOffset = 1. + self._histoHPlot.setGraphYLimits(vMin - vOffset, + vMax + vOffset) + + coords = numpy.arange(2 * histoVVisibleData.size) + yCoords = (coords + 1) // 2 + subsetYMin + yCoords = origin[1] + scale[1] * yCoords + yData = numpy.take(histoVVisibleData, coords // 2) + self._histoVPlot.addCurve(yData, yCoords, + xlabel='', ylabel='', + replace=False, + color=self.HISTOGRAMS_COLOR, + linestyle='-', + selectable=False) + vMin = self._cache['histoVMin'] + vMax = self._cache['histoVMax'] + vOffset = 0.1 * (vMax - vMin) + if vOffset == 0.: + vOffset = 1. + self._histoVPlot.setGraphXLimits(vMin - vOffset, + vMax + vOffset) + else: + self._dirtyCache() + self._histoHPlot.remove(kind='curve') + self._histoVPlot.remove(kind='curve') + + self._updatingLimits = wasUpdatingLimits + + def _updateRadarView(self): + """Update radar view visible area. + + Takes care of y coordinate conversion. + """ + xMin, xMax = self.getGraphXLimits() + yMin, yMax = self.getGraphYLimits() + self._radarView.setVisibleRect(xMin, yMin, xMax - xMin, yMax - yMin) + + # Plots event listeners + + def _imagePlotCB(self, eventDict): + """Callback for imageView plot events.""" + if eventDict['event'] == 'mouseMoved': + activeImage = self.getActiveImage() + if activeImage is not None: + data = activeImage.getData(copy=False) + height, width = data.shape + + # Get corresponding coordinate in image + origin = activeImage.getOrigin() + scale = activeImage.getScale() + if (eventDict['x'] >= origin[0] and + eventDict['y'] >= origin[1]): + x = int((eventDict['x'] - origin[0]) / scale[0]) + y = int((eventDict['y'] - origin[1]) / scale[1]) + + if x >= 0 and x < width and y >= 0 and y < height: + self.valueChanged.emit(float(x), float(y), + data[y][x]) + + elif eventDict['event'] == 'limitsChanged': + # Do not handle histograms limitsChanged while + # updating their limits from here. + self._updatingLimits = True + + # Refresh histograms + self._updateHistograms() + + # could use eventDict['xdata'], eventDict['ydata'] instead + xMin, xMax = self.getGraphXLimits() + yMin, yMax = self.getGraphYLimits() + + # Set horizontal histo limits + self._histoHPlot.setGraphXLimits(xMin, xMax) + + # Set vertical histo limits + self._histoVPlot.setGraphYLimits(yMin, yMax) + + self._updateRadarView() + + self._updatingLimits = False + + def _histoHPlotCB(self, eventDict): + """Callback for horizontal histogram plot events.""" + if eventDict['event'] == 'mouseMoved': + if self._cache is not None: + activeImage = self.getActiveImage() + if activeImage is not None: + xOrigin = activeImage.getOrigin()[0] + xScale = activeImage.getScale()[0] + + minValue = xOrigin + xScale * self._cache['dataXMin'] + + if eventDict['x'] >= minValue: + data = self._cache['histoH'] + column = int((eventDict['x'] - minValue) / xScale) + if column >= 0 and column < data.shape[0]: + self.valueChanged.emit( + float('nan'), + float(column + self._cache['dataXMin']), + data[column]) + + elif eventDict['event'] == 'limitsChanged': + if (not self._updatingLimits and + eventDict['xdata'] != self.getGraphXLimits()): + xMin, xMax = eventDict['xdata'] + self.setGraphXLimits(xMin, xMax) + + def _histoVPlotCB(self, eventDict): + """Callback for vertical histogram plot events.""" + if eventDict['event'] == 'mouseMoved': + if self._cache is not None: + activeImage = self.getActiveImage() + if activeImage is not None: + yOrigin = activeImage.getOrigin()[1] + yScale = activeImage.getScale()[1] + + minValue = yOrigin + yScale * self._cache['dataYMin'] + + if eventDict['y'] >= minValue: + data = self._cache['histoV'] + row = int((eventDict['y'] - minValue) / yScale) + if row >= 0 and row < data.shape[0]: + self.valueChanged.emit( + float(row + self._cache['dataYMin']), + float('nan'), + data[row]) + + elif eventDict['event'] == 'limitsChanged': + if (not self._updatingLimits and + eventDict['ydata'] != self.getGraphYLimits()): + yMin, yMax = eventDict['ydata'] + self.setGraphYLimits(yMin, yMax) + + def _radarViewCB(self, left, top, width, height): + """Slot for radar view visible rectangle changes.""" + if not self._updatingLimits: + # Takes care of Y axis conversion + self.setLimits(left, left + width, top, top + height) + + def _updateYAxisInverted(self, inverted=None): + """Sync image, vertical histogram and radar view axis orientation.""" + if inverted is None: + # Do not perform this when called from plot signal + inverted = self.isYAxisInverted() + + self._histoVPlot.setYAxisInverted(inverted) + + # Use scale to invert radarView + # RadarView default Y direction is from top to bottom + # As opposed to Plot. So invert RadarView when Plot is NOT inverted. + self._radarView.resetTransform() + if not inverted: + self._radarView.scale(1., -1.) + self._updateRadarView() + + self._radarView.update() + + def _activeImageChangedSlot(self, previous, legend): + """Handle Plot active image change. + + Resets side histograms cache + """ + self._dirtyCache() + self._updateHistograms() + + def getHistogram(self, axis): + """Return the histogram and corresponding row or column extent. + + The returned value when an histogram is available is a dict with keys: + + - 'data': numpy array of the histogram values. + - 'extent': (start, end) row or column index. + end index is not included in the histogram. + + :param str axis: 'x' for horizontal, 'y' for vertical + :return: The histogram and its extent as a dict or None. + :rtype: dict + """ + assert axis in ('x', 'y') + if self._cache is None: + return None + else: + if axis == 'x': + return dict( + data=numpy.array(self._cache['histoH'], copy=True), + extent=(self._cache['dataXMin'], self._cache['dataXMax'])) + else: + return dict( + data=numpy.array(self._cache['histoV'], copy=True), + extent=(self._cache['dataYMin'], self._cache['dataYMax'])) + + def radarView(self): + """Get the lower right radarView widget.""" + return self._radarView + + def setRadarView(self, radarView): + """Change the lower right radarView widget. + + :param RadarView radarView: Widget subclassing RadarView to replace + the lower right corner widget. + """ + self._radarView.visibleRectDragged.disconnect(self._radarViewCB) + self._radarView = radarView + self._radarView.visibleRectDragged.connect(self._radarViewCB) + self._layout.addWidget(self._radarView, 1, 1) + + self._updateYAxisInverted() + + # High-level API + + def getColormap(self): + """Get the default colormap description. + + :return: A description of the current colormap. + See :meth:`setColormap` for details. + :rtype: dict + """ + return self.getDefaultColormap() + + def setColormap(self, colormap=None, normalization=None, + autoscale=None, vmin=None, vmax=None, colors=None): + """Set the default colormap and update active image. + + Parameters that are not provided are taken from the current colormap. + + The colormap parameter can also be a dict with the following keys: + + - *name*: string. The colormap to use: + 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'. + - *normalization*: string. The mapping to use for the colormap: + either 'linear' or 'log'. + - *autoscale*: bool. Whether to use autoscale (True) + or range provided by keys 'vmin' and 'vmax' (False). + - *vmin*: float. The minimum value of the range to use if 'autoscale' + is False. + - *vmax*: float. The maximum value of the range to use if 'autoscale' + is False. + - *colors*: optional. Nx3 or Nx4 array of float in [0, 1] or uint8. + List of RGB or RGBA colors to use (only if name is None) + + :param colormap: Name of the colormap in + 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'. + Or the description of the colormap as a dict. + :type colormap: dict or str. + :param str normalization: Colormap mapping: 'linear' or 'log'. + :param bool autoscale: Whether to use autoscale (True) + or [vmin, vmax] range (False). + :param float vmin: The minimum value of the range to use if + 'autoscale' is False. + :param float vmax: The maximum value of the range to use if + 'autoscale' is False. + :param numpy.ndarray colors: Only used if name is None. + Custom colormap colors as Nx3 or Nx4 RGB or RGBA arrays + """ + cmapDict = self.getDefaultColormap() + + if isinstance(colormap, dict): + # Support colormap parameter as a dict + assert normalization is None + assert autoscale is None + assert vmin is None + assert vmax is None + assert colors is None + for key, value in colormap.items(): + cmapDict[key] = value + + else: + if colormap is not None: + cmapDict['name'] = colormap + if normalization is not None: + cmapDict['normalization'] = normalization + if autoscale is not None: + cmapDict['autoscale'] = autoscale + if vmin is not None: + cmapDict['vmin'] = vmin + if vmax is not None: + cmapDict['vmax'] = vmax + if colors is not None: + cmapDict['colors'] = colors + + cursorColor = cursorColorForColormap(cmapDict['name']) + self.setInteractiveMode('zoom', color=cursorColor) + + self.setDefaultColormap(cmapDict) + + # Update active image colormap + activeImage = self.getActiveImage() + if isinstance(activeImage, items.ColormapMixIn): + activeImage.setColormap(self.getColormap()) + + def setImage(self, image, origin=(0, 0), scale=(1., 1.), + copy=True, reset=True): + """Set the image to display. + + :param image: A 2D array representing the image or None to empty plot. + :type image: numpy.ndarray-like with 2 dimensions or None. + :param origin: The (x, y) position of the origin of the image. + Default: (0, 0). + The origin is the lower left corner of the image when + the Y axis is not inverted. + :type origin: Tuple of 2 floats: (origin x, origin y). + :param scale: The scale factor to apply to the image on X and Y axes. + Default: (1, 1). + It is the size of a pixel in the coordinates of the axes. + Scales must be positive numbers. + :type scale: Tuple of 2 floats: (scale x, scale y). + :param bool copy: Whether to copy image data (default) or not. + :param bool reset: Whether to reset zoom and ROI (default) or not. + """ + self._dirtyCache() + + assert len(origin) == 2 + assert len(scale) == 2 + assert scale[0] > 0 + assert scale[1] > 0 + + if image is None: + self.remove(self._imageLegend, kind='image') + return + + data = numpy.array(image, order='C', copy=copy) + assert data.size != 0 + assert len(data.shape) == 2 + height, width = data.shape + + self.addImage(data, + legend=self._imageLegend, + origin=origin, scale=scale, + colormap=self.getColormap(), + replace=False) + self.setActiveImage(self._imageLegend) + self._updateHistograms() + + self._radarView.setDataRect(origin[0], + origin[1], + width * scale[0], + height * scale[1]) + + if reset: + self.resetZoom() + + +# ImageViewMainWindow ######################################################### + +class ImageViewMainWindow(ImageView): + """:class:`ImageView` with additional toolbars + + Adds extra toolbar and a status bar to :class:`ImageView`. + """ + def __init__(self, parent=None, backend=None): + self._dataInfo = None + super(ImageViewMainWindow, self).__init__(parent, backend) + self.setWindowFlags(qt.Qt.Window) + + self.setGraphXLabel('X') + self.setGraphYLabel('Y') + self.setGraphTitle('Image') + + # Add toolbars and status bar + self.addToolBar(qt.Qt.BottomToolBarArea, LimitsToolBar(plot=self)) + + self.statusBar() + + menu = self.menuBar().addMenu('File') + menu.addAction(self.saveAction) + menu.addAction(self.printAction) + menu.addSeparator() + action = menu.addAction('Quit') + action.triggered[bool].connect(qt.QApplication.instance().quit) + + menu = self.menuBar().addMenu('Edit') + menu.addAction(self.copyAction) + menu.addSeparator() + menu.addAction(self.resetZoomAction) + menu.addAction(self.colormapAction) + menu.addAction(PlotActions.KeepAspectRatioAction(self, self)) + menu.addAction(PlotActions.YAxisInvertedAction(self, self)) + + menu = self.menuBar().addMenu('Profile') + menu.addAction(self.profile.browseAction) + menu.addAction(self.profile.hLineAction) + menu.addAction(self.profile.vLineAction) + menu.addAction(self.profile.lineAction) + menu.addAction(self.profile.clearAction) + + # Connect to ImageView's signal + self.valueChanged.connect(self._statusBarSlot) + + def _statusBarSlot(self, row, column, value): + """Update status bar with coordinates/value from plots.""" + if numpy.isnan(row): + msg = 'Column: %d, Sum: %g' % (int(column), value) + elif numpy.isnan(column): + msg = 'Row: %d, Sum: %g' % (int(row), value) + else: + msg = 'Position: (%d, %d), Value: %g' % (int(row), int(column), + value) + if self._dataInfo is not None: + msg = self._dataInfo + ', ' + msg + + self.statusBar().showMessage(msg) + + def setImage(self, image, *args, **kwargs): + """Set the displayed image. + + See :meth:`ImageView.setImage` for details. + """ + if hasattr(image, 'dtype') and hasattr(image, 'shape'): + assert len(image.shape) == 2 + height, width = image.shape + self._dataInfo = 'Data: %dx%d (%s)' % (width, height, + str(image.dtype)) + self.statusBar().showMessage(self._dataInfo) + else: + self._dataInfo = None + + # Set the new image in ImageView widget + super(ImageViewMainWindow, self).setImage(image, *args, **kwargs) + self.setStatusBar(None) diff --git a/silx/gui/plot/Interaction.py b/silx/gui/plot/Interaction.py new file mode 100644 index 0000000..f09b9bc --- /dev/null +++ b/silx/gui/plot/Interaction.py @@ -0,0 +1,300 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides an implementation of state machines for interaction. + +Sample code of a state machine with two states ('idle' and 'active') +with transitions on left button press/release: + +.. code-block:: python + + from silx.gui.plot.Interaction import * + + class SampleStateMachine(StateMachine): + + class Idle(State): + def onPress(self, x, y, btn): + if btn == LEFT_BTN: + self.goto('active') + + class Active(State): + def enterState(self): + print('Enabled') # Handle enter active state here + + def leaveState(self): + print('Disabled') # Handle leave active state here + + def onRelease(self, x, y, btn): + if btn == LEFT_BTN: + self.goto('idle') + + def __init__(self): + # State machine has 2 states + states = { + 'idle': SampleStateMachine.Idle, + 'active': SampleStateMachine.Active + } + super(TwoStates, self).__init__(states, 'idle') + # idle is the initial state + + stateMachine = SampleStateMachine() + + # Triggers a transition to the Active state: + stateMachine.handleEvent('press', 0, 0, LEFT_BTN) + + # Triggers a transition to the Idle state: + stateMachine.handleEvent('release', 0, 0, LEFT_BTN) + +See :class:`ClickOrDrag` for another example of a state machine. + +See `Renaud Blanch, Michel Beaudouin-Lafon. +Programming Rich Interactions using the Hierarchical State Machine Toolkit. +In Proceedings of AVI 2006. p 51-58. +<http://iihm.imag.fr/en/publication/BB06a/>`_ +for a discussion of using (hierarchical) state machines for interaction. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "18/02/2016" + + +import weakref + + +# state machine ############################################################### + +class State(object): + """Base class for the states of a state machine. + + This class is meant to be subclassed. + """ + + def __init__(self, machine): + """State instances should be created by the :class:`StateMachine`. + + They are not intended to be used outside this context. + + :param machine: The state machine instance this state belongs to. + :type machine: StateMachine + """ + self._machineRef = weakref.ref(machine) # Prevent cyclic reference + + @property + def machine(self): + """The state machine this state belongs to. + + Useful to access data or methods that are shared across states. + """ + machine = self._machineRef() + if machine is not None: + return machine + else: + raise RuntimeError("Associated StateMachine is not valid") + + def goto(self, state, *args, **kwargs): + """Performs a transition to a new state. + + Extra arguments are passed to the :meth:`enterState` method of the + new state. + + :param str state: The name of the state to go to. + """ + self.machine._goto(state, *args, **kwargs) + + def enterState(self, *args, **kwargs): + """Called when the state machine enters this state. + + Arguments are those provided to the :meth:`goto` method that + triggered the transition to this state. + """ + pass + + def leaveState(self): + """Called when the state machine leaves this state + (i.e., when :meth:`goto` is called). + """ + pass + + +class StateMachine(object): + """State machine controller. + + This is the entry point of a state machine. + It is in charge of dispatching received event and handling the + current active state. + """ + + def __init__(self, states, initState, *args, **kwargs): + """Create a state machine controller with an initial state. + + Extra arguments are passed to the :meth:`enterState` method + of the initState. + + :param states: All states of the state machine + :type states: dict of: {str name: State subclass} + :param str initState: Key of the initial state in states + """ + self.states = states + + self.state = self.states[initState](self) + self.state.enterState(*args, **kwargs) + + def _goto(self, state, *args, **kwargs): + self.state.leaveState() + self.state = self.states[state](self) + self.state.enterState(*args, **kwargs) + + def handleEvent(self, eventName, *args, **kwargs): + """Process an event with the state machine. + + This method looks up for an event handler in the current state + and then in the :class:`StateMachine` instance. + Handler are looked up as 'onEventName' method. + If a handler is found, it is called with the provided extra + arguments, and this method returns the return value of the + handler. + If no handler is found, this method returns None. + + :param str eventName: Name of the event to handle + :returns: The return value of the handler or None + """ + handlerName = 'on' + eventName[0].upper() + eventName[1:] + try: + handler = getattr(self.state, handlerName) + except AttributeError: + try: + handler = getattr(self, handlerName) + except AttributeError: + handler = None + if handler is not None: + return handler(*args, **kwargs) + + +# clickOrDrag ################################################################# + +LEFT_BTN = 'left' +"""Left mouse button.""" + +RIGHT_BTN = 'right' +"""Right mouse button.""" + +MIDDLE_BTN = 'middle' +"""Middle mouse button.""" + + +class ClickOrDrag(StateMachine): + """State machine for left and right click and left drag interaction. + + It is intended to be used through subclassing by overriding + :meth:`click`, :meth:`beginDrag`, :meth:`drag` and :meth:`endDrag`. + """ + + DRAG_THRESHOLD_SQUARE_DIST = 5 ** 2 + + class Idle(State): + def onPress(self, x, y, btn): + if btn == LEFT_BTN: + self.goto('clickOrDrag', x, y) + return True + elif btn == RIGHT_BTN: + self.goto('rightClick', x, y) + return True + + class RightClick(State): + def onMove(self, x, y): + self.goto('idle') + + def onRelease(self, x, y, btn): + if btn == RIGHT_BTN: + self.machine.click(x, y, btn) + self.goto('idle') + + class ClickOrDrag(State): + def enterState(self, x, y): + self.initPos = x, y + + def onMove(self, x, y): + dx2 = (x - self.initPos[0]) ** 2 + dy2 = (y - self.initPos[1]) ** 2 + if (dx2 + dy2) >= self.machine.DRAG_THRESHOLD_SQUARE_DIST: + self.goto('drag', self.initPos, (x, y)) + + def onRelease(self, x, y, btn): + if btn == LEFT_BTN: + self.machine.click(x, y, btn) + self.goto('idle') + + class Drag(State): + def enterState(self, initPos, curPos): + self.initPos = initPos + self.machine.beginDrag(*initPos) + self.machine.drag(*curPos) + + def onMove(self, x, y): + self.machine.drag(x, y) + + def onRelease(self, x, y, btn): + if btn == LEFT_BTN: + self.machine.endDrag(self.initPos, (x, y)) + self.goto('idle') + + def __init__(self): + states = { + 'idle': ClickOrDrag.Idle, + 'rightClick': ClickOrDrag.RightClick, + 'clickOrDrag': ClickOrDrag.ClickOrDrag, + 'drag': ClickOrDrag.Drag + } + super(ClickOrDrag, self).__init__(states, 'idle') + + def click(self, x, y, btn): + """Called upon a left or right button click. + + To override in a subclass. + """ + pass + + def beginDrag(self, x, y): + """Called at the beginning of a drag gesture with left button + pressed. + + To override in a subclass. + """ + pass + + def drag(self, x, y): + """Called on mouse moved during a drag gesture. + + To override in a subclass. + """ + pass + + def endDrag(self, startPoint, endPoint): + """Called at the end of a drag gesture when the left button is + released. + + To override in a subclass. + """ + pass diff --git a/silx/gui/plot/LegendSelector.py b/silx/gui/plot/LegendSelector.py new file mode 100644 index 0000000..3af9050 --- /dev/null +++ b/silx/gui/plot/LegendSelector.py @@ -0,0 +1,1087 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Widget displaying curves legends and allowing to operate on curves. + +This widget is meant to work with :class:`PlotWindow`. +""" + +__authors__ = ["V.A. Sole", "T. Rueter", "T. Vincent"] +__license__ = "MIT" +__data__ = "28/04/2016" + + +import logging +import weakref + +from .. import qt + + +_logger = logging.getLogger(__name__) + +# Build all symbols +# Courtesy of the pyqtgraph project +Symbols = dict([(name, qt.QPainterPath()) + for name in ['o', 's', 't', 'd', '+', 'x', '.', ',']]) +Symbols['o'].addEllipse(qt.QRectF(.1, .1, .8, .8)) +Symbols['.'].addEllipse(qt.QRectF(.3, .3, .4, .4)) +Symbols[','].addEllipse(qt.QRectF(.4, .4, .2, .2)) +Symbols['s'].addRect(qt.QRectF(.1, .1, .8, .8)) + +coords = { + 't': [(0.5, 0.), (.1, .8), (.9, .8)], + 'd': [(0.1, 0.5), (0.5, 0.), (0.9, 0.5), (0.5, 1.)], + '+': [(0.0, 0.40), (0.40, 0.40), (0.40, 0.), (0.60, 0.), + (0.60, 0.40), (1., 0.40), (1., 0.60), (0.60, 0.60), + (0.60, 1.), (0.40, 1.), (0.40, 0.60), (0., 0.60)], + 'x': [(0.0, 0.40), (0.40, 0.40), (0.40, 0.), (0.60, 0.), + (0.60, 0.40), (1., 0.40), (1., 0.60), (0.60, 0.60), + (0.60, 1.), (0.40, 1.), (0.40, 0.60), (0., 0.60)] +} +for s, c in coords.items(): + Symbols[s].moveTo(*c[0]) + for x, y in c[1:]: + Symbols[s].lineTo(x, y) + Symbols[s].closeSubpath() +tr = qt.QTransform() +tr.rotate(45) +Symbols['x'].translate(qt.QPointF(-0.5, -0.5)) +Symbols['x'] = tr.map(Symbols['x']) +Symbols['x'].translate(qt.QPointF(0.5, 0.5)) + +NoSymbols = (None, 'None', 'none', '', ' ') +"""List of values resulting in no symbol being displayed for a curve""" + + +LineStyles = { + None: qt.Qt.NoPen, + 'None': qt.Qt.NoPen, + 'none': qt.Qt.NoPen, + '': qt.Qt.NoPen, + ' ': qt.Qt.NoPen, + '-': qt.Qt.SolidLine, + '--': qt.Qt.DashLine, + ':': qt.Qt.DotLine, + '-.': qt.Qt.DashDotLine +} +"""Conversion from matplotlib-like linestyle to Qt""" + +NoLineStyle = (None, 'None', 'none', '', ' ') +"""List of style values resulting in no line being displayed for a curve""" + + +class LegendIcon(qt.QWidget): + """Object displaying a curve linestyle and symbol.""" + + def __init__(self, parent=None): + super(LegendIcon, self).__init__(parent) + + # Visibilities + self.showLine = True + self.showSymbol = True + + # Line attributes + self.lineStyle = qt.Qt.NoPen + self.lineWidth = 1. + self.lineColor = qt.Qt.green + + self.symbol = '' + # Symbol attributes + self.symbolStyle = qt.Qt.SolidPattern + self.symbolColor = qt.Qt.green + self.symbolOutlineBrush = qt.QBrush(qt.Qt.white) + + # Control widget size: sizeHint "is the only acceptable + # alternative, so the widget can never grow or shrink" + # (c.f. Qt Doc, enum QSizePolicy::Policy) + self.setSizePolicy(qt.QSizePolicy.Fixed, + qt.QSizePolicy.Fixed) + + def sizeHint(self): + return qt.QSize(50, 15) + + # Modify Symbol + def setSymbol(self, symbol): + symbol = str(symbol) + if symbol not in NoSymbols: + if symbol not in Symbols: + raise ValueError("Unknown symbol: <%s>" % symbol) + self.symbol = symbol + # self.update() after set...? + # Does not seem necessary + + def setSymbolColor(self, color): + """ + :param color: determines the symbol color + :type style: qt.QColor + """ + self.symbolColor = qt.QColor(color) + + # Modify Line + + def setLineColor(self, color): + self.lineColor = qt.QColor(color) + + def setLineWidth(self, width): + self.lineWidth = float(width) + + def setLineStyle(self, style): + """Set the linestyle. + + Possible line styles: + + - '', ' ', 'None': No line + - '-': solid + - '--': dashed + - ':': dotted + - '-.': dash and dot + + :param str style: The linestyle to use + """ + if style not in LineStyles: + raise ValueError('Unknown style: %s', style) + self.lineStyle = LineStyles[style] + + # Paint + + def paintEvent(self, event): + """ + :param event: event + :type event: QPaintEvent + """ + painter = qt.QPainter(self) + self.paint(painter, event.rect(), self.palette()) + + def paint(self, painter, rect, palette): + painter.save() + painter.setRenderHint(qt.QPainter.Antialiasing) + # Scale painter to the icon height + # current -> width = 2.5, height = 1.0 + scale = float(self.height()) + ratio = float(self.width()) / scale + painter.scale(scale, + scale) + symbolOffset = qt.QPointF(.5 * (ratio - 1.), 0.) + # Determine and scale offset + offset = qt.QPointF(float(rect.left()) / scale, float(rect.top()) / scale) + # Draw BG rectangle (for debugging) + # bottomRight = qt.QPointF( + # float(rect.right())/scale, + # float(rect.bottom())/scale) + # painter.fillRect(qt.QRectF(offset, bottomRight), + # qt.QBrush(qt.Qt.green)) + llist = [] + if self.showLine: + linePath = qt.QPainterPath() + linePath.moveTo(0., 0.5) + linePath.lineTo(ratio, 0.5) + # linePath.lineTo(2.5, 0.5) + linePen = qt.QPen( + qt.QBrush(self.lineColor), + (self.lineWidth / self.height()), + self.lineStyle, + qt.Qt.FlatCap + ) + llist.append((linePath, + linePen, + qt.QBrush(self.lineColor))) + if (self.showSymbol and len(self.symbol) and + self.symbol not in NoSymbols): + # PITFALL ahead: Let this be a warning to others + # symbolPath = Symbols[self.symbol] + # Copy before translate! Dict is a mutable type + symbolPath = qt.QPainterPath(Symbols[self.symbol]) + symbolPath.translate(symbolOffset) + symbolBrush = qt.QBrush( + self.symbolColor, + self.symbolStyle + ) + symbolPen = qt.QPen( + self.symbolOutlineBrush, # Brush + 1. / self.height(), # Width + qt.Qt.SolidLine # Style + ) + llist.append((symbolPath, + symbolPen, + symbolBrush)) + # Draw + for path, pen, brush in llist: + path.translate(offset) + painter.setPen(pen) + painter.setBrush(brush) + painter.drawPath(path) + painter.restore() + + +class LegendModel(qt.QAbstractListModel): + """Data model of curve legends. + + It holds the information of the curve: + + - color + - line width + - line style + - visibility of the lines + - symbol + - visibility of the symbols + """ + iconColorRole = qt.Qt.UserRole + 0 + iconLineWidthRole = qt.Qt.UserRole + 1 + iconLineStyleRole = qt.Qt.UserRole + 2 + showLineRole = qt.Qt.UserRole + 3 + iconSymbolRole = qt.Qt.UserRole + 4 + showSymbolRole = qt.Qt.UserRole + 5 + + def __init__(self, legendList=None, parent=None): + super(LegendModel, self).__init__(parent) + if legendList is None: + legendList = [] + self.legendList = [] + self.insertLegendList(0, legendList) + + def __getitem__(self, idx): + if idx >= len(self.legendList): + raise IndexError('list index out of range') + return self.legendList[idx] + + def rowCount(self, modelIndex=None): + return len(self.legendList) + + def flags(self, index): + return (qt.Qt.ItemIsEditable | + qt.Qt.ItemIsEnabled | + qt.Qt.ItemIsSelectable) + + def data(self, modelIndex, role): + if modelIndex.isValid: + idx = modelIndex.row() + else: + return None + if idx >= len(self.legendList): + raise IndexError('list index out of range') + + item = self.legendList[idx] + if role == qt.Qt.DisplayRole: + # Data to be rendered in the form of text + legend = str(item[0]) + return legend + elif role == qt.Qt.SizeHintRole: + # size = qt.QSize(200,50) + _logger.warning('LegendModel -- size hint role not implemented') + return qt.QSize() + elif role == qt.Qt.TextAlignmentRole: + alignment = qt.Qt.AlignVCenter | qt.Qt.AlignLeft + return alignment + elif role == qt.Qt.BackgroundRole: + # Background color, must be QBrush + if idx % 2: + brush = qt.QBrush(qt.QColor(240, 240, 240)) + else: + brush = qt.QBrush(qt.Qt.white) + return brush + elif role == qt.Qt.ForegroundRole: + # ForegroundRole color, must be QBrush + brush = qt.QBrush(qt.Qt.blue) + return brush + elif role == qt.Qt.CheckStateRole: + return bool(item[2]) # item[2] == True + elif role == qt.Qt.ToolTipRole or role == qt.Qt.StatusTipRole: + return '' + elif role == self.iconColorRole: + return item[1]['color'] + elif role == self.iconLineWidthRole: + return item[1]['linewidth'] + elif role == self.iconLineStyleRole: + return item[1]['linestyle'] + elif role == self.iconSymbolRole: + return item[1]['symbol'] + elif role == self.showLineRole: + return item[3] + elif role == self.showSymbolRole: + return item[4] + else: + _logger.info('Unkown role requested: %s', str(role)) + return None + + def setData(self, modelIndex, value, role): + if modelIndex.isValid: + idx = modelIndex.row() + else: + return None + if idx >= len(self.legendList): + # raise IndexError('list index out of range') + _logger.warning( + 'setData -- List index out of range, idx: %d', idx) + return None + + item = self.legendList[idx] + try: + if role == qt.Qt.DisplayRole: + # Set legend + item[0] = str(value) + elif role == self.iconColorRole: + item[1]['color'] = qt.QColor(value) + elif role == self.iconLineWidthRole: + item[1]['linewidth'] = int(value) + elif role == self.iconLineStyleRole: + item[1]['linestyle'] = str(value) + elif role == self.iconSymbolRole: + item[1]['symbol'] = str(value) + elif role == qt.Qt.CheckStateRole: + item[2] = value + elif role == self.showLineRole: + item[3] = value + elif role == self.showSymbolRole: + item[4] = value + except ValueError: + _logger.warning('Conversion failed:\n\tvalue: %s\n\trole: %s', + str(value), str(role)) + # Can that be right? Read docs again.. + self.dataChanged.emit(modelIndex, modelIndex) + return True + + def insertLegendList(self, row, llist): + """ + :param int row: Determines after which row the items are inserted + :param llist: Carries the new legend information + :type llist: List + """ + modelIndex = self.createIndex(row, 0) + count = len(llist) + super(LegendModel, self).beginInsertRows(modelIndex, + row, + row + count) + head = self.legendList[0:row] + tail = self.legendList[row:] + new = [] + for (legend, icon) in llist: + linestyle = icon.get('linestyle', None) + if linestyle in NoLineStyle: + # Curve had no line, give it one and hide it + # So when toggle line, it will display a solid line + showLine = False + icon['linestyle'] = '-' + else: + showLine = True + + symbol = icon.get('symbol', None) + if symbol in NoSymbols: + # Curve had no symbol, give it one and hide it + # So when toggle symbol, it will display 'o' + showSymbol = False + icon['symbol'] = 'o' + else: + showSymbol = True + + selected = icon.get('selected', True) + item = [legend, + icon, + selected, + showLine, + showSymbol] + new.append(item) + self.legendList = head + new + tail + super(LegendModel, self).endInsertRows() + return True + + def insertRows(self, row, count, modelIndex=qt.QModelIndex()): + raise NotImplementedError('Use LegendModel.insertLegendList instead') + + def removeRow(self, row): + return self.removeRows(row, 1) + + def removeRows(self, row, count, modelIndex=qt.QModelIndex()): + length = len(self.legendList) + if length == 0: + # Nothing to do.. + return True + if row < 0 or row >= length: + raise IndexError('Index out of range -- ' + + 'idx: %d, len: %d' % (row, length)) + if count == 0: + return False + super(LegendModel, self).beginRemoveRows(modelIndex, + row, + row + count) + del(self.legendList[row:row + count]) + super(LegendModel, self).endRemoveRows() + return True + + def setEditor(self, event, editor): + """ + :param str event: String that identifies the editor + :param editor: Widget used to change data in the underlying model + :type editor: QWidget + """ + if event not in self.eventList: + raise ValueError('setEditor -- Event must be in %s' % + str(self.eventList)) + self.editorDict[event] = editor + + +class LegendListItemWidget(qt.QItemDelegate): + """Object displaying a single item (i.e., a row) in the list.""" + + # Notice: LegendListItem does NOT inherit + # from QObject, it cannot emit signals! + + def __init__(self, parent=None, itemType=0): + super(LegendListItemWidget, self).__init__(parent) + + # Dictionary to render checkboxes + self.cbDict = {} + self.labelDict = {} + self.iconDict = {} + + # Keep checkbox and legend to get sizeHint + self.checkbox = qt.QCheckBox() + self.legend = qt.QLabel() + self.icon = LegendIcon() + + # Context Menu and Editors + self.contextMenu = None + + def paint(self, painter, option, modelIndex): + """ + Here be docs.. + + :param QPainter painter: + :param QStyleOptionViewItem option: + :param QModelIndex modelIndex: + """ + painter.save() + rect = option.rect + + # Calculate the icon rectangle + iconSize = self.icon.sizeHint() + # Calculate icon position + x = rect.left() + 2 + y = rect.top() + int(.5 * (rect.height() - iconSize.height())) + iconRect = qt.QRect(qt.QPoint(x, y), iconSize) + + # Calculate label rectangle + legendSize = qt.QSize(rect.width() - iconSize.width() - 30, + rect.height()) + # Calculate label position + x = rect.left() + iconRect.width() + y = rect.top() + labelRect = qt.QRect(qt.QPoint(x, y), legendSize) + labelRect.translate(qt.QPoint(10, 0)) + + # Calculate the checkbox rectangle + x = rect.right() - 30 + y = rect.top() + chBoxRect = qt.QRect(qt.QPoint(x, y), rect.bottomRight()) + + # Remember the rectangles + idx = modelIndex.row() + self.cbDict[idx] = chBoxRect + self.iconDict[idx] = iconRect + self.labelDict[idx] = labelRect + + # Draw background first! + if option.state & qt.QStyle.State_MouseOver: + backgroundBrush = option.palette.highlight() + else: + backgroundBrush = modelIndex.data(qt.Qt.BackgroundRole) + painter.fillRect(rect, backgroundBrush) + + # Draw label + legendText = modelIndex.data(qt.Qt.DisplayRole) + textBrush = modelIndex.data(qt.Qt.ForegroundRole) + textAlign = modelIndex.data(qt.Qt.TextAlignmentRole) + painter.setBrush(textBrush) + painter.setFont(self.legend.font()) + painter.drawText(labelRect, textAlign, legendText) + + # Draw icon + iconColor = modelIndex.data(LegendModel.iconColorRole) + iconLineWidth = modelIndex.data(LegendModel.iconLineWidthRole) + iconLineStyle = modelIndex.data(LegendModel.iconLineStyleRole) + iconSymbol = modelIndex.data(LegendModel.iconSymbolRole) + icon = LegendIcon() + icon.resize(iconRect.size()) + icon.move(iconRect.topRight()) + icon.showSymbol = modelIndex.data(LegendModel.showSymbolRole) + icon.showLine = modelIndex.data(LegendModel.showLineRole) + icon.setSymbolColor(iconColor) + icon.setLineColor(iconColor) + icon.setLineWidth(iconLineWidth) + icon.setLineStyle(iconLineStyle) + icon.setSymbol(iconSymbol) + icon.symbolOutlineBrush = backgroundBrush + icon.paint(painter, iconRect, option.palette) + + # Draw the checkbox + if modelIndex.data(qt.Qt.CheckStateRole): + checkState = qt.Qt.Checked + else: + checkState = qt.Qt.Unchecked + + self.drawCheck( + painter, qt.QStyleOptionViewItem(), chBoxRect, checkState) + + painter.restore() + + def editorEvent(self, event, model, option, modelIndex): + # From the docs: + # Mouse events are sent to editorEvent() + # even if they don't start editing of the item. + if event.button() == qt.Qt.RightButton and self.contextMenu: + self.contextMenu.exec_(event.globalPos(), modelIndex) + return True + elif event.button() == qt.Qt.LeftButton: + # Check if checkbox was clicked + idx = modelIndex.row() + cbRect = self.cbDict[idx] + if cbRect.contains(event.pos()): + # Toggle checkbox + model.setData(modelIndex, + not modelIndex.data(qt.Qt.CheckStateRole), + qt.Qt.CheckStateRole) + event.ignore() + return True + else: + return super(LegendListItemWidget, self).editorEvent( + event, model, option, modelIndex) + + def createEditor(self, parent, option, idx): + _logger.info('### Editor request ###') + + def sizeHint(self, option, idx): + # return qt.QSize(68,24) + iconSize = self.icon.sizeHint() + legendSize = self.legend.sizeHint() + checkboxSize = self.checkbox.sizeHint() + height = max([iconSize.height(), + legendSize.height(), + checkboxSize.height()]) + 4 + width = iconSize.width() + legendSize.width() + checkboxSize.width() + return qt.QSize(width, height) + + +class LegendListView(qt.QListView): + """Widget displaying a list of curve legends, line style and symbol.""" + + sigLegendSignal = qt.Signal(object) + """Signal emitting a dict when an action is triggered by the user.""" + + __mouseClickedEvent = 'mouseClicked' + __checkBoxClickedEvent = 'checkBoxClicked' + __legendClickedEvent = 'legendClicked' + + def __init__(self, parent=None, model=None, contextMenu=None): + super(LegendListView, self).__init__(parent) + self.__lastButton = None + self.__lastClickPos = None + self.__lastModelIdx = None + # Set default delegate + self.setItemDelegate(LegendListItemWidget()) + # Set default editors + # self.setSizePolicy(qt.QSizePolicy.MinimumExpanding, + # qt.QSizePolicy.MinimumExpanding) + # Set edit triggers by hand using self.edit(QModelIndex) + # in mousePressEvent (better to control than signals) + self.setEditTriggers(qt.QAbstractItemView.NoEditTriggers) + + # Control layout + # self.setBatchSize(2) + # self.setLayoutMode(qt.QListView.Batched) + # self.setFlow(qt.QListView.LeftToRight) + + # Control selection + self.setSelectionMode(qt.QAbstractItemView.NoSelection) + + if model is None: + model = LegendModel() + self.setModel(model) + self.setContextMenu(contextMenu) + + def setLegendList(self, legendList, row=None): + self.clear() + if row is None: + row = 0 + model = self.model() + model.insertLegendList(row, legendList) + _logger.debug('LegendListView.setLegendList(legendList) finished') + + def clear(self): + model = self.model() + model.removeRows(0, model.rowCount()) + _logger.debug('LegendListView.clear() finished') + + def setContextMenu(self, contextMenu=None): + delegate = self.itemDelegate() + if isinstance(delegate, LegendListItemWidget) and self.model(): + if contextMenu is None: + delegate.contextMenu = LegendListContextMenu(self.model()) + delegate.contextMenu.sigContextMenu.connect( + self._contextMenuSlot) + else: + delegate.contextMenu = contextMenu + + def __getitem__(self, idx): + model = self.model() + try: + item = model[idx] + except ValueError: + item = None + return item + + def _contextMenuSlot(self, ddict): + self.sigLegendSignal.emit(ddict) + + def mousePressEvent(self, event): + self.__lastButton = event.button() + self.__lastPosition = event.pos() + super(LegendListView, self).mousePressEvent(event) + # call _handleMouseClick after editing was handled + # If right click (context menu) is aborted, no + # signal is emitted.. + self._handleMouseClick(self.indexAt(self.__lastPosition)) + + def mouseDoubleClickEvent(self, event): + self.__lastButton = event.button() + self.__lastPosition = event.pos() + super(LegendListView, self).mouseDoubleClickEvent(event) + # call _handleMouseClick after editing was handled + # If right click (context menu) is aborted, no + # signal is emitted.. + self._handleMouseClick(self.indexAt(self.__lastPosition)) + + def mouseMoveEvent(self, event): + # LegendListView.mouseMoveEvent is overwritten + # to suppress unwanted behavior in the delegate. + pass + + def mouseReleaseEvent(self, event): + # LegendListView.mouseReleaseEvent is overwritten + # to subpress unwanted behavior in the delegate. + pass + + def _handleMouseClick(self, modelIndex): + """ + Distinguish between mouse click on Legend + and mouse click on CheckBox by setting the + currentCheckState attribute in LegendListItem. + + Emits signal sigLegendSignal(ddict) + + :param QModelIndex modelIndex: index of the clicked item + """ + _logger.debug('self._handleMouseClick called') + if self.__lastButton not in [qt.Qt.LeftButton, + qt.Qt.RightButton]: + return + if not modelIndex.isValid(): + _logger.debug('_handleMouseClick -- Invalid QModelIndex') + return + # model = self.model() + idx = modelIndex.row() + + delegate = self.itemDelegate() + cbClicked = False + if isinstance(delegate, LegendListItemWidget): + for cbRect in delegate.cbDict.values(): + if cbRect.contains(self.__lastPosition): + cbClicked = True + break + + # TODO: Check for doubleclicks on legend/icon and spawn editors + + ddict = { + 'legend': str(modelIndex.data(qt.Qt.DisplayRole)), + 'icon': { + 'linewidth': str(modelIndex.data( + LegendModel.iconLineWidthRole)), + 'linestyle': str(modelIndex.data( + LegendModel.iconLineStyleRole)), + 'symbol': str(modelIndex.data(LegendModel.iconSymbolRole)) + }, + 'selected': modelIndex.data(qt.Qt.CheckStateRole), + 'type': str(modelIndex.data()) + } + if self.__lastButton == qt.Qt.RightButton: + _logger.debug('Right clicked') + ddict['button'] = "right" + ddict['event'] = self.__mouseClickedEvent + elif cbClicked: + _logger.debug('CheckBox clicked') + ddict['button'] = "left" + ddict['event'] = self.__checkBoxClickedEvent + else: + _logger.debug('Legend clicked') + ddict['button'] = "left" + ddict['event'] = self.__legendClickedEvent + _logger.debug(' idx: %d\n ddict: %s', idx, str(ddict)) + self.sigLegendSignal.emit(ddict) + + +class LegendListContextMenu(qt.QMenu): + """Contextual menu associated to items in a :class:`LegendListView`.""" + + sigContextMenu = qt.Signal(object) + """Signal emitting a dict upon contextual menu actions.""" + + def __init__(self, model): + super(LegendListContextMenu, self).__init__(parent=None) + self.model = model + + self.addAction('Set Active', self.setActiveAction) + self.addAction('Map to left', self.mapToLeftAction) + self.addAction('Map to right', self.mapToRightAction) + + self._pointsAction = self.addAction( + 'Points', self.togglePointsAction) + self._pointsAction.setCheckable(True) + + self._linesAction = self.addAction('Lines', self.toggleLinesAction) + self._linesAction.setCheckable(True) + + self.addAction('Remove curve', self.removeItemAction) + self.addAction('Rename curve', self.renameItemAction) + + def exec_(self, pos, idx): + self.__currentIdx = idx + + # Set checkable action state + modelIndex = self.currentIdx() + self._pointsAction.setChecked( + modelIndex.data(LegendModel.showSymbolRole)) + self._linesAction.setChecked( + modelIndex.data(LegendModel.showLineRole)) + + super(LegendListContextMenu, self).popup(pos) + + def currentIdx(self): + return self.__currentIdx + + def mapToLeftAction(self): + _logger.debug('LegendListContextMenu.mapToLeftAction called') + modelIndex = self.currentIdx() + legend = str(modelIndex.data(qt.Qt.DisplayRole)) + ddict = { + 'legend': legend, + 'label': legend, + 'selected': modelIndex.data(qt.Qt.CheckStateRole), + 'type': str(modelIndex.data()), + 'event': "mapToLeft" + } + self.sigContextMenu.emit(ddict) + + def mapToRightAction(self): + _logger.debug('LegendListContextMenu.mapToRightAction called') + modelIndex = self.currentIdx() + legend = str(modelIndex.data(qt.Qt.DisplayRole)) + ddict = { + 'legend': legend, + 'label': legend, + 'selected': modelIndex.data(qt.Qt.CheckStateRole), + 'type': str(modelIndex.data()), + 'event': "mapToRight" + } + self.sigContextMenu.emit(ddict) + + def removeItemAction(self): + _logger.debug('LegendListContextMenu.removeCurveAction called') + modelIndex = self.currentIdx() + legend = str(modelIndex.data(qt.Qt.DisplayRole)) + ddict = { + 'legend': legend, + 'label': legend, + 'selected': modelIndex.data(qt.Qt.CheckStateRole), + 'type': str(modelIndex.data()), + 'event': "removeCurve" + } + self.model.removeRow(modelIndex.row()) + self.sigContextMenu.emit(ddict) + + def renameItemAction(self): + _logger.debug('LegendListContextMenu.renameCurveAction called') + modelIndex = self.currentIdx() + legend = str(modelIndex.data(qt.Qt.DisplayRole)) + ddict = { + 'legend': legend, + 'label': legend, + 'selected': modelIndex.data(qt.Qt.CheckStateRole), + 'type': str(modelIndex.data()), + 'event': "renameCurve" + } + self.sigContextMenu.emit(ddict) + + def toggleLinesAction(self): + modelIndex = self.currentIdx() + legend = str(modelIndex.data(qt.Qt.DisplayRole)) + ddict = { + 'legend': legend, + 'label': legend, + 'selected': modelIndex.data(qt.Qt.CheckStateRole), + 'type': str(modelIndex.data()), + } + linestyle = modelIndex.data(LegendModel.iconLineStyleRole) + visible = not modelIndex.data(LegendModel.showLineRole) + _logger.debug('toggleLinesAction -- lines visible: %s', str(visible)) + ddict['event'] = "toggleLine" + ddict['line'] = visible + ddict['linestyle'] = linestyle if visible else '' + self.model.setData(modelIndex, visible, LegendModel.showLineRole) + self.sigContextMenu.emit(ddict) + + def togglePointsAction(self): + modelIndex = self.currentIdx() + legend = str(modelIndex.data(qt.Qt.DisplayRole)) + ddict = { + 'legend': legend, + 'label': legend, + 'selected': modelIndex.data(qt.Qt.CheckStateRole), + 'type': str(modelIndex.data()), + } + flag = modelIndex.data(LegendModel.showSymbolRole) + symbol = modelIndex.data(LegendModel.iconSymbolRole) + visible = not flag or symbol in NoSymbols + _logger.debug( + 'togglePointsAction -- Symbols visible: %s', str(visible)) + + ddict['event'] = "togglePoints" + ddict['points'] = visible + ddict['symbol'] = symbol if visible else '' + self.model.setData(modelIndex, visible, LegendModel.showSymbolRole) + self.sigContextMenu.emit(ddict) + + def setActiveAction(self): + modelIndex = self.currentIdx() + legend = str(modelIndex.data(qt.Qt.DisplayRole)) + _logger.debug('setActiveAction -- active curve: %s', legend) + ddict = { + 'legend': legend, + 'label': legend, + 'selected': modelIndex.data(qt.Qt.CheckStateRole), + 'type': str(modelIndex.data()), + 'event': "setActiveCurve", + } + self.sigContextMenu.emit(ddict) + + +class RenameCurveDialog(qt.QDialog): + """Dialog box to input the name of a curve.""" + + def __init__(self, parent=None, current="", curves=()): + super(RenameCurveDialog, self).__init__(parent) + self.setWindowTitle("Rename Curve %s" % current) + self.curves = curves + layout = qt.QVBoxLayout(self) + self.lineEdit = qt.QLineEdit(self) + self.lineEdit.setText(current) + self.hbox = qt.QWidget(self) + self.hboxLayout = qt.QHBoxLayout(self.hbox) + self.hboxLayout.addStretch(1) + self.okButton = qt.QPushButton(self.hbox) + self.okButton.setText('OK') + self.hboxLayout.addWidget(self.okButton) + self.cancelButton = qt.QPushButton(self.hbox) + self.cancelButton.setText('Cancel') + self.hboxLayout.addWidget(self.cancelButton) + self.hboxLayout.addStretch(1) + layout.addWidget(self.lineEdit) + layout.addWidget(self.hbox) + self.okButton.clicked.connect(self.preAccept) + self.cancelButton.clicked.connect(self.reject) + + def preAccept(self): + text = str(self.lineEdit.text()) + addedText = "" + if len(text): + if text not in self.curves: + self.accept() + return + else: + addedText = "Curve already exists." + text = "Invalid Curve Name" + msg = qt.QMessageBox(self) + msg.setIcon(qt.QMessageBox.Critical) + msg.setWindowTitle(text) + text += "\n%s" % addedText + msg.setText(text) + msg.exec_() + + def getText(self): + return str(self.lineEdit.text()) + + +class LegendsDockWidget(qt.QDockWidget): + """QDockWidget with a :class:`LegendSelector` connected to a PlotWindow. + + It makes the link between the LegendListView widget and the PlotWindow. + + :param parent: See :class:`QDockWidget` + :param plot: :class:`.PlotWindow` instance on which to operate + """ + + def __init__(self, parent=None, plot=None): + assert plot is not None + self._plotRef = weakref.ref(plot) + self._isConnected = False # True if widget connected to plot signals + + super(LegendsDockWidget, self).__init__("Legends", parent) + + self._legendWidget = LegendListView() + + self.layout().setContentsMargins(0, 0, 0, 0) + self.setWidget(self._legendWidget) + + self.visibilityChanged.connect( + self._visibilityChangedHandler) + + self._legendWidget.sigLegendSignal.connect(self._legendSignalHandler) + + @property + def plot(self): + """The :class:`.PlotWindow` this widget is attached to.""" + return self._plotRef() + + def renameCurve(self, oldLegend, newLegend): + """Change the name of a curve using remove and addCurve + + :param str oldLegend: The legend of the curve to be change + :param str newLegend: The new legend of the curve + """ + curve = self.plot.getCurve(oldLegend) + self.plot.remove(oldLegend, kind='curve') + self.plot.addCurve(curve.getXData(copy=False), + curve.getYData(copy=False), + legend=newLegend, + info=curve.getInfo(), + color=curve.getColor(), + symbol=curve.getSymbol(), + linewidth=curve.getLineWidth(), + linestyle=curve.getLineStyle(), + xlabel=curve.getXLabel(), + ylabel=curve.getYLabel(), + xerror=curve.getXErrorData(copy=False), + yerror=curve.getYErrorData(copy=False), + z=curve.getZValue(), + selectable=curve.isSelectable(), + fill=curve.isFill(), + resetzoom=False) + + def _legendSignalHandler(self, ddict): + """Handles events from the LegendListView signal""" + _logger.debug("Legend signal ddict = %s", str(ddict)) + + if ddict['event'] == "legendClicked": + if ddict['button'] == "left": + self.plot.setActiveCurve(ddict['legend']) + + elif ddict['event'] == "removeCurve": + self.plot.removeCurve(ddict['legend']) + + elif ddict['event'] == "renameCurve": + curveList = self.plot.getAllCurves(just_legend=True) + oldLegend = ddict['legend'] + dialog = RenameCurveDialog(self.plot, oldLegend, curveList) + ret = dialog.exec_() + if ret: + newLegend = dialog.getText() + self.renameCurve(oldLegend, newLegend) + + elif ddict['event'] == "setActiveCurve": + self.plot.setActiveCurve(ddict['legend']) + + elif ddict['event'] == "checkBoxClicked": + self.plot.hideCurve(ddict['legend'], not ddict['selected']) + + elif ddict['event'] in ["mapToRight", "mapToLeft"]: + legend = ddict['legend'] + curve = self.plot.getCurve(legend) + yaxis = 'right' if ddict['event'] == 'mapToRight' else 'left' + self.plot.addCurve(x=curve.getXData(copy=False), + y=curve.getYData(copy=False), + legend=curve.getLegend(), + info=curve.getInfo(), + yaxis=yaxis) + + elif ddict['event'] == "togglePoints": + legend = ddict['legend'] + curve = self.plot.getCurve(legend) + symbol = ddict['symbol'] if ddict['points'] else '' + self.plot.addCurve(x=curve.getXData(copy=False), + y=curve.getYData(copy=False), + legend=curve.getLegend(), + info=curve.getInfo(), + symbol=symbol) + + elif ddict['event'] == "toggleLine": + legend = ddict['legend'] + curve = self.plot.getCurve(legend) + linestyle = ddict['linestyle'] if ddict['line'] else '' + self.plot.addCurve(x=curve.getXData(copy=False), + y=curve.getYData(copy=False), + legend=curve.getLegend(), + info=curve.getInfo(), + linestyle=linestyle) + + else: + _logger.debug("unhandled event %s", str(ddict['event'])) + + def updateLegends(self, *args): + """Sync the LegendSelector widget displayed info with the plot. + """ + legendList = [] + for curve in self.plot.getAllCurves(withhidden=True): + legend = curve.getLegend() + # Use active color if curve is active + if legend == self.plot.getActiveCurve(just_legend=True): + color = qt.QColor(self.plot.getActiveCurveColor()) + else: + color = qt.QColor.fromRgbF(*curve.getColor()) + + curveInfo = { + 'color': color, + 'linewidth': curve.getLineWidth(), + 'linestyle': curve.getLineStyle(), + 'symbol': curve.getSymbol(), + 'selected': not self.plot.isCurveHidden(legend)} + legendList.append((legend, curveInfo)) + + self._legendWidget.setLegendList(legendList) + + def _visibilityChangedHandler(self, visible): + if visible: + self.updateLegends() + if not self._isConnected: + self.plot.sigContentChanged.connect(self.updateLegends) + self.plot.sigActiveCurveChanged.connect(self.updateLegends) + self._isConnected = True + else: + if self._isConnected: + self.plot.sigContentChanged.disconnect(self.updateLegends) + self.plot.sigActiveCurveChanged.disconnect(self.updateLegends) + self._isConnected = False + + def showEvent(self, event): + """Make sure this widget is raised when it is shown + (when it is first created as a tab in PlotWindow or when it is shown + again after hiding). + """ + self.raise_() diff --git a/silx/gui/plot/MPLColormap.py b/silx/gui/plot/MPLColormap.py new file mode 100644 index 0000000..49b11d7 --- /dev/null +++ b/silx/gui/plot/MPLColormap.py @@ -0,0 +1,1062 @@ +# New matplotlib colormaps by Nathaniel J. Smith, Stefan van der Walt, +# and (in the case of viridis) Eric Firing. +# +# This file and the colormaps in it are released under the CC0 license / +# public domain dedication. We would appreciate credit if you use or +# redistribute these colormaps, but do not impose any legal restrictions. +# +# To the extent possible under law, the persons who associated CC0 with +# mpl-colormaps have waived all copyright and related or neighboring rights +# to mpl-colormaps. +# +# You should have received a copy of the CC0 legalcode along with this +# work. If not, see <http://creativecommons.org/publicdomain/zero/1.0/>. +"""Matplotlib's new colormaps""" + + +from matplotlib.colors import ListedColormap + + +__all__ = ['magma', 'inferno', 'plasma', 'viridis'] + +_magma_data = [[0.001462, 0.000466, 0.013866], + [0.002258, 0.001295, 0.018331], + [0.003279, 0.002305, 0.023708], + [0.004512, 0.003490, 0.029965], + [0.005950, 0.004843, 0.037130], + [0.007588, 0.006356, 0.044973], + [0.009426, 0.008022, 0.052844], + [0.011465, 0.009828, 0.060750], + [0.013708, 0.011771, 0.068667], + [0.016156, 0.013840, 0.076603], + [0.018815, 0.016026, 0.084584], + [0.021692, 0.018320, 0.092610], + [0.024792, 0.020715, 0.100676], + [0.028123, 0.023201, 0.108787], + [0.031696, 0.025765, 0.116965], + [0.035520, 0.028397, 0.125209], + [0.039608, 0.031090, 0.133515], + [0.043830, 0.033830, 0.141886], + [0.048062, 0.036607, 0.150327], + [0.052320, 0.039407, 0.158841], + [0.056615, 0.042160, 0.167446], + [0.060949, 0.044794, 0.176129], + [0.065330, 0.047318, 0.184892], + [0.069764, 0.049726, 0.193735], + [0.074257, 0.052017, 0.202660], + [0.078815, 0.054184, 0.211667], + [0.083446, 0.056225, 0.220755], + [0.088155, 0.058133, 0.229922], + [0.092949, 0.059904, 0.239164], + [0.097833, 0.061531, 0.248477], + [0.102815, 0.063010, 0.257854], + [0.107899, 0.064335, 0.267289], + [0.113094, 0.065492, 0.276784], + [0.118405, 0.066479, 0.286321], + [0.123833, 0.067295, 0.295879], + [0.129380, 0.067935, 0.305443], + [0.135053, 0.068391, 0.315000], + [0.140858, 0.068654, 0.324538], + [0.146785, 0.068738, 0.334011], + [0.152839, 0.068637, 0.343404], + [0.159018, 0.068354, 0.352688], + [0.165308, 0.067911, 0.361816], + [0.171713, 0.067305, 0.370771], + [0.178212, 0.066576, 0.379497], + [0.184801, 0.065732, 0.387973], + [0.191460, 0.064818, 0.396152], + [0.198177, 0.063862, 0.404009], + [0.204935, 0.062907, 0.411514], + [0.211718, 0.061992, 0.418647], + [0.218512, 0.061158, 0.425392], + [0.225302, 0.060445, 0.431742], + [0.232077, 0.059889, 0.437695], + [0.238826, 0.059517, 0.443256], + [0.245543, 0.059352, 0.448436], + [0.252220, 0.059415, 0.453248], + [0.258857, 0.059706, 0.457710], + [0.265447, 0.060237, 0.461840], + [0.271994, 0.060994, 0.465660], + [0.278493, 0.061978, 0.469190], + [0.284951, 0.063168, 0.472451], + [0.291366, 0.064553, 0.475462], + [0.297740, 0.066117, 0.478243], + [0.304081, 0.067835, 0.480812], + [0.310382, 0.069702, 0.483186], + [0.316654, 0.071690, 0.485380], + [0.322899, 0.073782, 0.487408], + [0.329114, 0.075972, 0.489287], + [0.335308, 0.078236, 0.491024], + [0.341482, 0.080564, 0.492631], + [0.347636, 0.082946, 0.494121], + [0.353773, 0.085373, 0.495501], + [0.359898, 0.087831, 0.496778], + [0.366012, 0.090314, 0.497960], + [0.372116, 0.092816, 0.499053], + [0.378211, 0.095332, 0.500067], + [0.384299, 0.097855, 0.501002], + [0.390384, 0.100379, 0.501864], + [0.396467, 0.102902, 0.502658], + [0.402548, 0.105420, 0.503386], + [0.408629, 0.107930, 0.504052], + [0.414709, 0.110431, 0.504662], + [0.420791, 0.112920, 0.505215], + [0.426877, 0.115395, 0.505714], + [0.432967, 0.117855, 0.506160], + [0.439062, 0.120298, 0.506555], + [0.445163, 0.122724, 0.506901], + [0.451271, 0.125132, 0.507198], + [0.457386, 0.127522, 0.507448], + [0.463508, 0.129893, 0.507652], + [0.469640, 0.132245, 0.507809], + [0.475780, 0.134577, 0.507921], + [0.481929, 0.136891, 0.507989], + [0.488088, 0.139186, 0.508011], + [0.494258, 0.141462, 0.507988], + [0.500438, 0.143719, 0.507920], + [0.506629, 0.145958, 0.507806], + [0.512831, 0.148179, 0.507648], + [0.519045, 0.150383, 0.507443], + [0.525270, 0.152569, 0.507192], + [0.531507, 0.154739, 0.506895], + [0.537755, 0.156894, 0.506551], + [0.544015, 0.159033, 0.506159], + [0.550287, 0.161158, 0.505719], + [0.556571, 0.163269, 0.505230], + [0.562866, 0.165368, 0.504692], + [0.569172, 0.167454, 0.504105], + [0.575490, 0.169530, 0.503466], + [0.581819, 0.171596, 0.502777], + [0.588158, 0.173652, 0.502035], + [0.594508, 0.175701, 0.501241], + [0.600868, 0.177743, 0.500394], + [0.607238, 0.179779, 0.499492], + [0.613617, 0.181811, 0.498536], + [0.620005, 0.183840, 0.497524], + [0.626401, 0.185867, 0.496456], + [0.632805, 0.187893, 0.495332], + [0.639216, 0.189921, 0.494150], + [0.645633, 0.191952, 0.492910], + [0.652056, 0.193986, 0.491611], + [0.658483, 0.196027, 0.490253], + [0.664915, 0.198075, 0.488836], + [0.671349, 0.200133, 0.487358], + [0.677786, 0.202203, 0.485819], + [0.684224, 0.204286, 0.484219], + [0.690661, 0.206384, 0.482558], + [0.697098, 0.208501, 0.480835], + [0.703532, 0.210638, 0.479049], + [0.709962, 0.212797, 0.477201], + [0.716387, 0.214982, 0.475290], + [0.722805, 0.217194, 0.473316], + [0.729216, 0.219437, 0.471279], + [0.735616, 0.221713, 0.469180], + [0.742004, 0.224025, 0.467018], + [0.748378, 0.226377, 0.464794], + [0.754737, 0.228772, 0.462509], + [0.761077, 0.231214, 0.460162], + [0.767398, 0.233705, 0.457755], + [0.773695, 0.236249, 0.455289], + [0.779968, 0.238851, 0.452765], + [0.786212, 0.241514, 0.450184], + [0.792427, 0.244242, 0.447543], + [0.798608, 0.247040, 0.444848], + [0.804752, 0.249911, 0.442102], + [0.810855, 0.252861, 0.439305], + [0.816914, 0.255895, 0.436461], + [0.822926, 0.259016, 0.433573], + [0.828886, 0.262229, 0.430644], + [0.834791, 0.265540, 0.427671], + [0.840636, 0.268953, 0.424666], + [0.846416, 0.272473, 0.421631], + [0.852126, 0.276106, 0.418573], + [0.857763, 0.279857, 0.415496], + [0.863320, 0.283729, 0.412403], + [0.868793, 0.287728, 0.409303], + [0.874176, 0.291859, 0.406205], + [0.879464, 0.296125, 0.403118], + [0.884651, 0.300530, 0.400047], + [0.889731, 0.305079, 0.397002], + [0.894700, 0.309773, 0.393995], + [0.899552, 0.314616, 0.391037], + [0.904281, 0.319610, 0.388137], + [0.908884, 0.324755, 0.385308], + [0.913354, 0.330052, 0.382563], + [0.917689, 0.335500, 0.379915], + [0.921884, 0.341098, 0.377376], + [0.925937, 0.346844, 0.374959], + [0.929845, 0.352734, 0.372677], + [0.933606, 0.358764, 0.370541], + [0.937221, 0.364929, 0.368567], + [0.940687, 0.371224, 0.366762], + [0.944006, 0.377643, 0.365136], + [0.947180, 0.384178, 0.363701], + [0.950210, 0.390820, 0.362468], + [0.953099, 0.397563, 0.361438], + [0.955849, 0.404400, 0.360619], + [0.958464, 0.411324, 0.360014], + [0.960949, 0.418323, 0.359630], + [0.963310, 0.425390, 0.359469], + [0.965549, 0.432519, 0.359529], + [0.967671, 0.439703, 0.359810], + [0.969680, 0.446936, 0.360311], + [0.971582, 0.454210, 0.361030], + [0.973381, 0.461520, 0.361965], + [0.975082, 0.468861, 0.363111], + [0.976690, 0.476226, 0.364466], + [0.978210, 0.483612, 0.366025], + [0.979645, 0.491014, 0.367783], + [0.981000, 0.498428, 0.369734], + [0.982279, 0.505851, 0.371874], + [0.983485, 0.513280, 0.374198], + [0.984622, 0.520713, 0.376698], + [0.985693, 0.528148, 0.379371], + [0.986700, 0.535582, 0.382210], + [0.987646, 0.543015, 0.385210], + [0.988533, 0.550446, 0.388365], + [0.989363, 0.557873, 0.391671], + [0.990138, 0.565296, 0.395122], + [0.990871, 0.572706, 0.398714], + [0.991558, 0.580107, 0.402441], + [0.992196, 0.587502, 0.406299], + [0.992785, 0.594891, 0.410283], + [0.993326, 0.602275, 0.414390], + [0.993834, 0.609644, 0.418613], + [0.994309, 0.616999, 0.422950], + [0.994738, 0.624350, 0.427397], + [0.995122, 0.631696, 0.431951], + [0.995480, 0.639027, 0.436607], + [0.995810, 0.646344, 0.441361], + [0.996096, 0.653659, 0.446213], + [0.996341, 0.660969, 0.451160], + [0.996580, 0.668256, 0.456192], + [0.996775, 0.675541, 0.461314], + [0.996925, 0.682828, 0.466526], + [0.997077, 0.690088, 0.471811], + [0.997186, 0.697349, 0.477182], + [0.997254, 0.704611, 0.482635], + [0.997325, 0.711848, 0.488154], + [0.997351, 0.719089, 0.493755], + [0.997351, 0.726324, 0.499428], + [0.997341, 0.733545, 0.505167], + [0.997285, 0.740772, 0.510983], + [0.997228, 0.747981, 0.516859], + [0.997138, 0.755190, 0.522806], + [0.997019, 0.762398, 0.528821], + [0.996898, 0.769591, 0.534892], + [0.996727, 0.776795, 0.541039], + [0.996571, 0.783977, 0.547233], + [0.996369, 0.791167, 0.553499], + [0.996162, 0.798348, 0.559820], + [0.995932, 0.805527, 0.566202], + [0.995680, 0.812706, 0.572645], + [0.995424, 0.819875, 0.579140], + [0.995131, 0.827052, 0.585701], + [0.994851, 0.834213, 0.592307], + [0.994524, 0.841387, 0.598983], + [0.994222, 0.848540, 0.605696], + [0.993866, 0.855711, 0.612482], + [0.993545, 0.862859, 0.619299], + [0.993170, 0.870024, 0.626189], + [0.992831, 0.877168, 0.633109], + [0.992440, 0.884330, 0.640099], + [0.992089, 0.891470, 0.647116], + [0.991688, 0.898627, 0.654202], + [0.991332, 0.905763, 0.661309], + [0.990930, 0.912915, 0.668481], + [0.990570, 0.920049, 0.675675], + [0.990175, 0.927196, 0.682926], + [0.989815, 0.934329, 0.690198], + [0.989434, 0.941470, 0.697519], + [0.989077, 0.948604, 0.704863], + [0.988717, 0.955742, 0.712242], + [0.988367, 0.962878, 0.719649], + [0.988033, 0.970012, 0.727077], + [0.987691, 0.977154, 0.734536], + [0.987387, 0.984288, 0.742002], + [0.987053, 0.991438, 0.749504]] + +_inferno_data = [[0.001462, 0.000466, 0.013866], + [0.002267, 0.001270, 0.018570], + [0.003299, 0.002249, 0.024239], + [0.004547, 0.003392, 0.030909], + [0.006006, 0.004692, 0.038558], + [0.007676, 0.006136, 0.046836], + [0.009561, 0.007713, 0.055143], + [0.011663, 0.009417, 0.063460], + [0.013995, 0.011225, 0.071862], + [0.016561, 0.013136, 0.080282], + [0.019373, 0.015133, 0.088767], + [0.022447, 0.017199, 0.097327], + [0.025793, 0.019331, 0.105930], + [0.029432, 0.021503, 0.114621], + [0.033385, 0.023702, 0.123397], + [0.037668, 0.025921, 0.132232], + [0.042253, 0.028139, 0.141141], + [0.046915, 0.030324, 0.150164], + [0.051644, 0.032474, 0.159254], + [0.056449, 0.034569, 0.168414], + [0.061340, 0.036590, 0.177642], + [0.066331, 0.038504, 0.186962], + [0.071429, 0.040294, 0.196354], + [0.076637, 0.041905, 0.205799], + [0.081962, 0.043328, 0.215289], + [0.087411, 0.044556, 0.224813], + [0.092990, 0.045583, 0.234358], + [0.098702, 0.046402, 0.243904], + [0.104551, 0.047008, 0.253430], + [0.110536, 0.047399, 0.262912], + [0.116656, 0.047574, 0.272321], + [0.122908, 0.047536, 0.281624], + [0.129285, 0.047293, 0.290788], + [0.135778, 0.046856, 0.299776], + [0.142378, 0.046242, 0.308553], + [0.149073, 0.045468, 0.317085], + [0.155850, 0.044559, 0.325338], + [0.162689, 0.043554, 0.333277], + [0.169575, 0.042489, 0.340874], + [0.176493, 0.041402, 0.348111], + [0.183429, 0.040329, 0.354971], + [0.190367, 0.039309, 0.361447], + [0.197297, 0.038400, 0.367535], + [0.204209, 0.037632, 0.373238], + [0.211095, 0.037030, 0.378563], + [0.217949, 0.036615, 0.383522], + [0.224763, 0.036405, 0.388129], + [0.231538, 0.036405, 0.392400], + [0.238273, 0.036621, 0.396353], + [0.244967, 0.037055, 0.400007], + [0.251620, 0.037705, 0.403378], + [0.258234, 0.038571, 0.406485], + [0.264810, 0.039647, 0.409345], + [0.271347, 0.040922, 0.411976], + [0.277850, 0.042353, 0.414392], + [0.284321, 0.043933, 0.416608], + [0.290763, 0.045644, 0.418637], + [0.297178, 0.047470, 0.420491], + [0.303568, 0.049396, 0.422182], + [0.309935, 0.051407, 0.423721], + [0.316282, 0.053490, 0.425116], + [0.322610, 0.055634, 0.426377], + [0.328921, 0.057827, 0.427511], + [0.335217, 0.060060, 0.428524], + [0.341500, 0.062325, 0.429425], + [0.347771, 0.064616, 0.430217], + [0.354032, 0.066925, 0.430906], + [0.360284, 0.069247, 0.431497], + [0.366529, 0.071579, 0.431994], + [0.372768, 0.073915, 0.432400], + [0.379001, 0.076253, 0.432719], + [0.385228, 0.078591, 0.432955], + [0.391453, 0.080927, 0.433109], + [0.397674, 0.083257, 0.433183], + [0.403894, 0.085580, 0.433179], + [0.410113, 0.087896, 0.433098], + [0.416331, 0.090203, 0.432943], + [0.422549, 0.092501, 0.432714], + [0.428768, 0.094790, 0.432412], + [0.434987, 0.097069, 0.432039], + [0.441207, 0.099338, 0.431594], + [0.447428, 0.101597, 0.431080], + [0.453651, 0.103848, 0.430498], + [0.459875, 0.106089, 0.429846], + [0.466100, 0.108322, 0.429125], + [0.472328, 0.110547, 0.428334], + [0.478558, 0.112764, 0.427475], + [0.484789, 0.114974, 0.426548], + [0.491022, 0.117179, 0.425552], + [0.497257, 0.119379, 0.424488], + [0.503493, 0.121575, 0.423356], + [0.509730, 0.123769, 0.422156], + [0.515967, 0.125960, 0.420887], + [0.522206, 0.128150, 0.419549], + [0.528444, 0.130341, 0.418142], + [0.534683, 0.132534, 0.416667], + [0.540920, 0.134729, 0.415123], + [0.547157, 0.136929, 0.413511], + [0.553392, 0.139134, 0.411829], + [0.559624, 0.141346, 0.410078], + [0.565854, 0.143567, 0.408258], + [0.572081, 0.145797, 0.406369], + [0.578304, 0.148039, 0.404411], + [0.584521, 0.150294, 0.402385], + [0.590734, 0.152563, 0.400290], + [0.596940, 0.154848, 0.398125], + [0.603139, 0.157151, 0.395891], + [0.609330, 0.159474, 0.393589], + [0.615513, 0.161817, 0.391219], + [0.621685, 0.164184, 0.388781], + [0.627847, 0.166575, 0.386276], + [0.633998, 0.168992, 0.383704], + [0.640135, 0.171438, 0.381065], + [0.646260, 0.173914, 0.378359], + [0.652369, 0.176421, 0.375586], + [0.658463, 0.178962, 0.372748], + [0.664540, 0.181539, 0.369846], + [0.670599, 0.184153, 0.366879], + [0.676638, 0.186807, 0.363849], + [0.682656, 0.189501, 0.360757], + [0.688653, 0.192239, 0.357603], + [0.694627, 0.195021, 0.354388], + [0.700576, 0.197851, 0.351113], + [0.706500, 0.200728, 0.347777], + [0.712396, 0.203656, 0.344383], + [0.718264, 0.206636, 0.340931], + [0.724103, 0.209670, 0.337424], + [0.729909, 0.212759, 0.333861], + [0.735683, 0.215906, 0.330245], + [0.741423, 0.219112, 0.326576], + [0.747127, 0.222378, 0.322856], + [0.752794, 0.225706, 0.319085], + [0.758422, 0.229097, 0.315266], + [0.764010, 0.232554, 0.311399], + [0.769556, 0.236077, 0.307485], + [0.775059, 0.239667, 0.303526], + [0.780517, 0.243327, 0.299523], + [0.785929, 0.247056, 0.295477], + [0.791293, 0.250856, 0.291390], + [0.796607, 0.254728, 0.287264], + [0.801871, 0.258674, 0.283099], + [0.807082, 0.262692, 0.278898], + [0.812239, 0.266786, 0.274661], + [0.817341, 0.270954, 0.270390], + [0.822386, 0.275197, 0.266085], + [0.827372, 0.279517, 0.261750], + [0.832299, 0.283913, 0.257383], + [0.837165, 0.288385, 0.252988], + [0.841969, 0.292933, 0.248564], + [0.846709, 0.297559, 0.244113], + [0.851384, 0.302260, 0.239636], + [0.855992, 0.307038, 0.235133], + [0.860533, 0.311892, 0.230606], + [0.865006, 0.316822, 0.226055], + [0.869409, 0.321827, 0.221482], + [0.873741, 0.326906, 0.216886], + [0.878001, 0.332060, 0.212268], + [0.882188, 0.337287, 0.207628], + [0.886302, 0.342586, 0.202968], + [0.890341, 0.347957, 0.198286], + [0.894305, 0.353399, 0.193584], + [0.898192, 0.358911, 0.188860], + [0.902003, 0.364492, 0.184116], + [0.905735, 0.370140, 0.179350], + [0.909390, 0.375856, 0.174563], + [0.912966, 0.381636, 0.169755], + [0.916462, 0.387481, 0.164924], + [0.919879, 0.393389, 0.160070], + [0.923215, 0.399359, 0.155193], + [0.926470, 0.405389, 0.150292], + [0.929644, 0.411479, 0.145367], + [0.932737, 0.417627, 0.140417], + [0.935747, 0.423831, 0.135440], + [0.938675, 0.430091, 0.130438], + [0.941521, 0.436405, 0.125409], + [0.944285, 0.442772, 0.120354], + [0.946965, 0.449191, 0.115272], + [0.949562, 0.455660, 0.110164], + [0.952075, 0.462178, 0.105031], + [0.954506, 0.468744, 0.099874], + [0.956852, 0.475356, 0.094695], + [0.959114, 0.482014, 0.089499], + [0.961293, 0.488716, 0.084289], + [0.963387, 0.495462, 0.079073], + [0.965397, 0.502249, 0.073859], + [0.967322, 0.509078, 0.068659], + [0.969163, 0.515946, 0.063488], + [0.970919, 0.522853, 0.058367], + [0.972590, 0.529798, 0.053324], + [0.974176, 0.536780, 0.048392], + [0.975677, 0.543798, 0.043618], + [0.977092, 0.550850, 0.039050], + [0.978422, 0.557937, 0.034931], + [0.979666, 0.565057, 0.031409], + [0.980824, 0.572209, 0.028508], + [0.981895, 0.579392, 0.026250], + [0.982881, 0.586606, 0.024661], + [0.983779, 0.593849, 0.023770], + [0.984591, 0.601122, 0.023606], + [0.985315, 0.608422, 0.024202], + [0.985952, 0.615750, 0.025592], + [0.986502, 0.623105, 0.027814], + [0.986964, 0.630485, 0.030908], + [0.987337, 0.637890, 0.034916], + [0.987622, 0.645320, 0.039886], + [0.987819, 0.652773, 0.045581], + [0.987926, 0.660250, 0.051750], + [0.987945, 0.667748, 0.058329], + [0.987874, 0.675267, 0.065257], + [0.987714, 0.682807, 0.072489], + [0.987464, 0.690366, 0.079990], + [0.987124, 0.697944, 0.087731], + [0.986694, 0.705540, 0.095694], + [0.986175, 0.713153, 0.103863], + [0.985566, 0.720782, 0.112229], + [0.984865, 0.728427, 0.120785], + [0.984075, 0.736087, 0.129527], + [0.983196, 0.743758, 0.138453], + [0.982228, 0.751442, 0.147565], + [0.981173, 0.759135, 0.156863], + [0.980032, 0.766837, 0.166353], + [0.978806, 0.774545, 0.176037], + [0.977497, 0.782258, 0.185923], + [0.976108, 0.789974, 0.196018], + [0.974638, 0.797692, 0.206332], + [0.973088, 0.805409, 0.216877], + [0.971468, 0.813122, 0.227658], + [0.969783, 0.820825, 0.238686], + [0.968041, 0.828515, 0.249972], + [0.966243, 0.836191, 0.261534], + [0.964394, 0.843848, 0.273391], + [0.962517, 0.851476, 0.285546], + [0.960626, 0.859069, 0.298010], + [0.958720, 0.866624, 0.310820], + [0.956834, 0.874129, 0.323974], + [0.954997, 0.881569, 0.337475], + [0.953215, 0.888942, 0.351369], + [0.951546, 0.896226, 0.365627], + [0.950018, 0.903409, 0.380271], + [0.948683, 0.910473, 0.395289], + [0.947594, 0.917399, 0.410665], + [0.946809, 0.924168, 0.426373], + [0.946392, 0.930761, 0.442367], + [0.946403, 0.937159, 0.458592], + [0.946903, 0.943348, 0.474970], + [0.947937, 0.949318, 0.491426], + [0.949545, 0.955063, 0.507860], + [0.951740, 0.960587, 0.524203], + [0.954529, 0.965896, 0.540361], + [0.957896, 0.971003, 0.556275], + [0.961812, 0.975924, 0.571925], + [0.966249, 0.980678, 0.587206], + [0.971162, 0.985282, 0.602154], + [0.976511, 0.989753, 0.616760], + [0.982257, 0.994109, 0.631017], + [0.988362, 0.998364, 0.644924]] + +_plasma_data = [[0.050383, 0.029803, 0.527975], + [0.063536, 0.028426, 0.533124], + [0.075353, 0.027206, 0.538007], + [0.086222, 0.026125, 0.542658], + [0.096379, 0.025165, 0.547103], + [0.105980, 0.024309, 0.551368], + [0.115124, 0.023556, 0.555468], + [0.123903, 0.022878, 0.559423], + [0.132381, 0.022258, 0.563250], + [0.140603, 0.021687, 0.566959], + [0.148607, 0.021154, 0.570562], + [0.156421, 0.020651, 0.574065], + [0.164070, 0.020171, 0.577478], + [0.171574, 0.019706, 0.580806], + [0.178950, 0.019252, 0.584054], + [0.186213, 0.018803, 0.587228], + [0.193374, 0.018354, 0.590330], + [0.200445, 0.017902, 0.593364], + [0.207435, 0.017442, 0.596333], + [0.214350, 0.016973, 0.599239], + [0.221197, 0.016497, 0.602083], + [0.227983, 0.016007, 0.604867], + [0.234715, 0.015502, 0.607592], + [0.241396, 0.014979, 0.610259], + [0.248032, 0.014439, 0.612868], + [0.254627, 0.013882, 0.615419], + [0.261183, 0.013308, 0.617911], + [0.267703, 0.012716, 0.620346], + [0.274191, 0.012109, 0.622722], + [0.280648, 0.011488, 0.625038], + [0.287076, 0.010855, 0.627295], + [0.293478, 0.010213, 0.629490], + [0.299855, 0.009561, 0.631624], + [0.306210, 0.008902, 0.633694], + [0.312543, 0.008239, 0.635700], + [0.318856, 0.007576, 0.637640], + [0.325150, 0.006915, 0.639512], + [0.331426, 0.006261, 0.641316], + [0.337683, 0.005618, 0.643049], + [0.343925, 0.004991, 0.644710], + [0.350150, 0.004382, 0.646298], + [0.356359, 0.003798, 0.647810], + [0.362553, 0.003243, 0.649245], + [0.368733, 0.002724, 0.650601], + [0.374897, 0.002245, 0.651876], + [0.381047, 0.001814, 0.653068], + [0.387183, 0.001434, 0.654177], + [0.393304, 0.001114, 0.655199], + [0.399411, 0.000859, 0.656133], + [0.405503, 0.000678, 0.656977], + [0.411580, 0.000577, 0.657730], + [0.417642, 0.000564, 0.658390], + [0.423689, 0.000646, 0.658956], + [0.429719, 0.000831, 0.659425], + [0.435734, 0.001127, 0.659797], + [0.441732, 0.001540, 0.660069], + [0.447714, 0.002080, 0.660240], + [0.453677, 0.002755, 0.660310], + [0.459623, 0.003574, 0.660277], + [0.465550, 0.004545, 0.660139], + [0.471457, 0.005678, 0.659897], + [0.477344, 0.006980, 0.659549], + [0.483210, 0.008460, 0.659095], + [0.489055, 0.010127, 0.658534], + [0.494877, 0.011990, 0.657865], + [0.500678, 0.014055, 0.657088], + [0.506454, 0.016333, 0.656202], + [0.512206, 0.018833, 0.655209], + [0.517933, 0.021563, 0.654109], + [0.523633, 0.024532, 0.652901], + [0.529306, 0.027747, 0.651586], + [0.534952, 0.031217, 0.650165], + [0.540570, 0.034950, 0.648640], + [0.546157, 0.038954, 0.647010], + [0.551715, 0.043136, 0.645277], + [0.557243, 0.047331, 0.643443], + [0.562738, 0.051545, 0.641509], + [0.568201, 0.055778, 0.639477], + [0.573632, 0.060028, 0.637349], + [0.579029, 0.064296, 0.635126], + [0.584391, 0.068579, 0.632812], + [0.589719, 0.072878, 0.630408], + [0.595011, 0.077190, 0.627917], + [0.600266, 0.081516, 0.625342], + [0.605485, 0.085854, 0.622686], + [0.610667, 0.090204, 0.619951], + [0.615812, 0.094564, 0.617140], + [0.620919, 0.098934, 0.614257], + [0.625987, 0.103312, 0.611305], + [0.631017, 0.107699, 0.608287], + [0.636008, 0.112092, 0.605205], + [0.640959, 0.116492, 0.602065], + [0.645872, 0.120898, 0.598867], + [0.650746, 0.125309, 0.595617], + [0.655580, 0.129725, 0.592317], + [0.660374, 0.134144, 0.588971], + [0.665129, 0.138566, 0.585582], + [0.669845, 0.142992, 0.582154], + [0.674522, 0.147419, 0.578688], + [0.679160, 0.151848, 0.575189], + [0.683758, 0.156278, 0.571660], + [0.688318, 0.160709, 0.568103], + [0.692840, 0.165141, 0.564522], + [0.697324, 0.169573, 0.560919], + [0.701769, 0.174005, 0.557296], + [0.706178, 0.178437, 0.553657], + [0.710549, 0.182868, 0.550004], + [0.714883, 0.187299, 0.546338], + [0.719181, 0.191729, 0.542663], + [0.723444, 0.196158, 0.538981], + [0.727670, 0.200586, 0.535293], + [0.731862, 0.205013, 0.531601], + [0.736019, 0.209439, 0.527908], + [0.740143, 0.213864, 0.524216], + [0.744232, 0.218288, 0.520524], + [0.748289, 0.222711, 0.516834], + [0.752312, 0.227133, 0.513149], + [0.756304, 0.231555, 0.509468], + [0.760264, 0.235976, 0.505794], + [0.764193, 0.240396, 0.502126], + [0.768090, 0.244817, 0.498465], + [0.771958, 0.249237, 0.494813], + [0.775796, 0.253658, 0.491171], + [0.779604, 0.258078, 0.487539], + [0.783383, 0.262500, 0.483918], + [0.787133, 0.266922, 0.480307], + [0.790855, 0.271345, 0.476706], + [0.794549, 0.275770, 0.473117], + [0.798216, 0.280197, 0.469538], + [0.801855, 0.284626, 0.465971], + [0.805467, 0.289057, 0.462415], + [0.809052, 0.293491, 0.458870], + [0.812612, 0.297928, 0.455338], + [0.816144, 0.302368, 0.451816], + [0.819651, 0.306812, 0.448306], + [0.823132, 0.311261, 0.444806], + [0.826588, 0.315714, 0.441316], + [0.830018, 0.320172, 0.437836], + [0.833422, 0.324635, 0.434366], + [0.836801, 0.329105, 0.430905], + [0.840155, 0.333580, 0.427455], + [0.843484, 0.338062, 0.424013], + [0.846788, 0.342551, 0.420579], + [0.850066, 0.347048, 0.417153], + [0.853319, 0.351553, 0.413734], + [0.856547, 0.356066, 0.410322], + [0.859750, 0.360588, 0.406917], + [0.862927, 0.365119, 0.403519], + [0.866078, 0.369660, 0.400126], + [0.869203, 0.374212, 0.396738], + [0.872303, 0.378774, 0.393355], + [0.875376, 0.383347, 0.389976], + [0.878423, 0.387932, 0.386600], + [0.881443, 0.392529, 0.383229], + [0.884436, 0.397139, 0.379860], + [0.887402, 0.401762, 0.376494], + [0.890340, 0.406398, 0.373130], + [0.893250, 0.411048, 0.369768], + [0.896131, 0.415712, 0.366407], + [0.898984, 0.420392, 0.363047], + [0.901807, 0.425087, 0.359688], + [0.904601, 0.429797, 0.356329], + [0.907365, 0.434524, 0.352970], + [0.910098, 0.439268, 0.349610], + [0.912800, 0.444029, 0.346251], + [0.915471, 0.448807, 0.342890], + [0.918109, 0.453603, 0.339529], + [0.920714, 0.458417, 0.336166], + [0.923287, 0.463251, 0.332801], + [0.925825, 0.468103, 0.329435], + [0.928329, 0.472975, 0.326067], + [0.930798, 0.477867, 0.322697], + [0.933232, 0.482780, 0.319325], + [0.935630, 0.487712, 0.315952], + [0.937990, 0.492667, 0.312575], + [0.940313, 0.497642, 0.309197], + [0.942598, 0.502639, 0.305816], + [0.944844, 0.507658, 0.302433], + [0.947051, 0.512699, 0.299049], + [0.949217, 0.517763, 0.295662], + [0.951344, 0.522850, 0.292275], + [0.953428, 0.527960, 0.288883], + [0.955470, 0.533093, 0.285490], + [0.957469, 0.538250, 0.282096], + [0.959424, 0.543431, 0.278701], + [0.961336, 0.548636, 0.275305], + [0.963203, 0.553865, 0.271909], + [0.965024, 0.559118, 0.268513], + [0.966798, 0.564396, 0.265118], + [0.968526, 0.569700, 0.261721], + [0.970205, 0.575028, 0.258325], + [0.971835, 0.580382, 0.254931], + [0.973416, 0.585761, 0.251540], + [0.974947, 0.591165, 0.248151], + [0.976428, 0.596595, 0.244767], + [0.977856, 0.602051, 0.241387], + [0.979233, 0.607532, 0.238013], + [0.980556, 0.613039, 0.234646], + [0.981826, 0.618572, 0.231287], + [0.983041, 0.624131, 0.227937], + [0.984199, 0.629718, 0.224595], + [0.985301, 0.635330, 0.221265], + [0.986345, 0.640969, 0.217948], + [0.987332, 0.646633, 0.214648], + [0.988260, 0.652325, 0.211364], + [0.989128, 0.658043, 0.208100], + [0.989935, 0.663787, 0.204859], + [0.990681, 0.669558, 0.201642], + [0.991365, 0.675355, 0.198453], + [0.991985, 0.681179, 0.195295], + [0.992541, 0.687030, 0.192170], + [0.993032, 0.692907, 0.189084], + [0.993456, 0.698810, 0.186041], + [0.993814, 0.704741, 0.183043], + [0.994103, 0.710698, 0.180097], + [0.994324, 0.716681, 0.177208], + [0.994474, 0.722691, 0.174381], + [0.994553, 0.728728, 0.171622], + [0.994561, 0.734791, 0.168938], + [0.994495, 0.740880, 0.166335], + [0.994355, 0.746995, 0.163821], + [0.994141, 0.753137, 0.161404], + [0.993851, 0.759304, 0.159092], + [0.993482, 0.765499, 0.156891], + [0.993033, 0.771720, 0.154808], + [0.992505, 0.777967, 0.152855], + [0.991897, 0.784239, 0.151042], + [0.991209, 0.790537, 0.149377], + [0.990439, 0.796859, 0.147870], + [0.989587, 0.803205, 0.146529], + [0.988648, 0.809579, 0.145357], + [0.987621, 0.815978, 0.144363], + [0.986509, 0.822401, 0.143557], + [0.985314, 0.828846, 0.142945], + [0.984031, 0.835315, 0.142528], + [0.982653, 0.841812, 0.142303], + [0.981190, 0.848329, 0.142279], + [0.979644, 0.854866, 0.142453], + [0.977995, 0.861432, 0.142808], + [0.976265, 0.868016, 0.143351], + [0.974443, 0.874622, 0.144061], + [0.972530, 0.881250, 0.144923], + [0.970533, 0.887896, 0.145919], + [0.968443, 0.894564, 0.147014], + [0.966271, 0.901249, 0.148180], + [0.964021, 0.907950, 0.149370], + [0.961681, 0.914672, 0.150520], + [0.959276, 0.921407, 0.151566], + [0.956808, 0.928152, 0.152409], + [0.954287, 0.934908, 0.152921], + [0.951726, 0.941671, 0.152925], + [0.949151, 0.948435, 0.152178], + [0.946602, 0.955190, 0.150328], + [0.944152, 0.961916, 0.146861], + [0.941896, 0.968590, 0.140956], + [0.940015, 0.975158, 0.131326]] + +_viridis_data = [[0.267004, 0.004874, 0.329415], + [0.268510, 0.009605, 0.335427], + [0.269944, 0.014625, 0.341379], + [0.271305, 0.019942, 0.347269], + [0.272594, 0.025563, 0.353093], + [0.273809, 0.031497, 0.358853], + [0.274952, 0.037752, 0.364543], + [0.276022, 0.044167, 0.370164], + [0.277018, 0.050344, 0.375715], + [0.277941, 0.056324, 0.381191], + [0.278791, 0.062145, 0.386592], + [0.279566, 0.067836, 0.391917], + [0.280267, 0.073417, 0.397163], + [0.280894, 0.078907, 0.402329], + [0.281446, 0.084320, 0.407414], + [0.281924, 0.089666, 0.412415], + [0.282327, 0.094955, 0.417331], + [0.282656, 0.100196, 0.422160], + [0.282910, 0.105393, 0.426902], + [0.283091, 0.110553, 0.431554], + [0.283197, 0.115680, 0.436115], + [0.283229, 0.120777, 0.440584], + [0.283187, 0.125848, 0.444960], + [0.283072, 0.130895, 0.449241], + [0.282884, 0.135920, 0.453427], + [0.282623, 0.140926, 0.457517], + [0.282290, 0.145912, 0.461510], + [0.281887, 0.150881, 0.465405], + [0.281412, 0.155834, 0.469201], + [0.280868, 0.160771, 0.472899], + [0.280255, 0.165693, 0.476498], + [0.279574, 0.170599, 0.479997], + [0.278826, 0.175490, 0.483397], + [0.278012, 0.180367, 0.486697], + [0.277134, 0.185228, 0.489898], + [0.276194, 0.190074, 0.493001], + [0.275191, 0.194905, 0.496005], + [0.274128, 0.199721, 0.498911], + [0.273006, 0.204520, 0.501721], + [0.271828, 0.209303, 0.504434], + [0.270595, 0.214069, 0.507052], + [0.269308, 0.218818, 0.509577], + [0.267968, 0.223549, 0.512008], + [0.266580, 0.228262, 0.514349], + [0.265145, 0.232956, 0.516599], + [0.263663, 0.237631, 0.518762], + [0.262138, 0.242286, 0.520837], + [0.260571, 0.246922, 0.522828], + [0.258965, 0.251537, 0.524736], + [0.257322, 0.256130, 0.526563], + [0.255645, 0.260703, 0.528312], + [0.253935, 0.265254, 0.529983], + [0.252194, 0.269783, 0.531579], + [0.250425, 0.274290, 0.533103], + [0.248629, 0.278775, 0.534556], + [0.246811, 0.283237, 0.535941], + [0.244972, 0.287675, 0.537260], + [0.243113, 0.292092, 0.538516], + [0.241237, 0.296485, 0.539709], + [0.239346, 0.300855, 0.540844], + [0.237441, 0.305202, 0.541921], + [0.235526, 0.309527, 0.542944], + [0.233603, 0.313828, 0.543914], + [0.231674, 0.318106, 0.544834], + [0.229739, 0.322361, 0.545706], + [0.227802, 0.326594, 0.546532], + [0.225863, 0.330805, 0.547314], + [0.223925, 0.334994, 0.548053], + [0.221989, 0.339161, 0.548752], + [0.220057, 0.343307, 0.549413], + [0.218130, 0.347432, 0.550038], + [0.216210, 0.351535, 0.550627], + [0.214298, 0.355619, 0.551184], + [0.212395, 0.359683, 0.551710], + [0.210503, 0.363727, 0.552206], + [0.208623, 0.367752, 0.552675], + [0.206756, 0.371758, 0.553117], + [0.204903, 0.375746, 0.553533], + [0.203063, 0.379716, 0.553925], + [0.201239, 0.383670, 0.554294], + [0.199430, 0.387607, 0.554642], + [0.197636, 0.391528, 0.554969], + [0.195860, 0.395433, 0.555276], + [0.194100, 0.399323, 0.555565], + [0.192357, 0.403199, 0.555836], + [0.190631, 0.407061, 0.556089], + [0.188923, 0.410910, 0.556326], + [0.187231, 0.414746, 0.556547], + [0.185556, 0.418570, 0.556753], + [0.183898, 0.422383, 0.556944], + [0.182256, 0.426184, 0.557120], + [0.180629, 0.429975, 0.557282], + [0.179019, 0.433756, 0.557430], + [0.177423, 0.437527, 0.557565], + [0.175841, 0.441290, 0.557685], + [0.174274, 0.445044, 0.557792], + [0.172719, 0.448791, 0.557885], + [0.171176, 0.452530, 0.557965], + [0.169646, 0.456262, 0.558030], + [0.168126, 0.459988, 0.558082], + [0.166617, 0.463708, 0.558119], + [0.165117, 0.467423, 0.558141], + [0.163625, 0.471133, 0.558148], + [0.162142, 0.474838, 0.558140], + [0.160665, 0.478540, 0.558115], + [0.159194, 0.482237, 0.558073], + [0.157729, 0.485932, 0.558013], + [0.156270, 0.489624, 0.557936], + [0.154815, 0.493313, 0.557840], + [0.153364, 0.497000, 0.557724], + [0.151918, 0.500685, 0.557587], + [0.150476, 0.504369, 0.557430], + [0.149039, 0.508051, 0.557250], + [0.147607, 0.511733, 0.557049], + [0.146180, 0.515413, 0.556823], + [0.144759, 0.519093, 0.556572], + [0.143343, 0.522773, 0.556295], + [0.141935, 0.526453, 0.555991], + [0.140536, 0.530132, 0.555659], + [0.139147, 0.533812, 0.555298], + [0.137770, 0.537492, 0.554906], + [0.136408, 0.541173, 0.554483], + [0.135066, 0.544853, 0.554029], + [0.133743, 0.548535, 0.553541], + [0.132444, 0.552216, 0.553018], + [0.131172, 0.555899, 0.552459], + [0.129933, 0.559582, 0.551864], + [0.128729, 0.563265, 0.551229], + [0.127568, 0.566949, 0.550556], + [0.126453, 0.570633, 0.549841], + [0.125394, 0.574318, 0.549086], + [0.124395, 0.578002, 0.548287], + [0.123463, 0.581687, 0.547445], + [0.122606, 0.585371, 0.546557], + [0.121831, 0.589055, 0.545623], + [0.121148, 0.592739, 0.544641], + [0.120565, 0.596422, 0.543611], + [0.120092, 0.600104, 0.542530], + [0.119738, 0.603785, 0.541400], + [0.119512, 0.607464, 0.540218], + [0.119423, 0.611141, 0.538982], + [0.119483, 0.614817, 0.537692], + [0.119699, 0.618490, 0.536347], + [0.120081, 0.622161, 0.534946], + [0.120638, 0.625828, 0.533488], + [0.121380, 0.629492, 0.531973], + [0.122312, 0.633153, 0.530398], + [0.123444, 0.636809, 0.528763], + [0.124780, 0.640461, 0.527068], + [0.126326, 0.644107, 0.525311], + [0.128087, 0.647749, 0.523491], + [0.130067, 0.651384, 0.521608], + [0.132268, 0.655014, 0.519661], + [0.134692, 0.658636, 0.517649], + [0.137339, 0.662252, 0.515571], + [0.140210, 0.665859, 0.513427], + [0.143303, 0.669459, 0.511215], + [0.146616, 0.673050, 0.508936], + [0.150148, 0.676631, 0.506589], + [0.153894, 0.680203, 0.504172], + [0.157851, 0.683765, 0.501686], + [0.162016, 0.687316, 0.499129], + [0.166383, 0.690856, 0.496502], + [0.170948, 0.694384, 0.493803], + [0.175707, 0.697900, 0.491033], + [0.180653, 0.701402, 0.488189], + [0.185783, 0.704891, 0.485273], + [0.191090, 0.708366, 0.482284], + [0.196571, 0.711827, 0.479221], + [0.202219, 0.715272, 0.476084], + [0.208030, 0.718701, 0.472873], + [0.214000, 0.722114, 0.469588], + [0.220124, 0.725509, 0.466226], + [0.226397, 0.728888, 0.462789], + [0.232815, 0.732247, 0.459277], + [0.239374, 0.735588, 0.455688], + [0.246070, 0.738910, 0.452024], + [0.252899, 0.742211, 0.448284], + [0.259857, 0.745492, 0.444467], + [0.266941, 0.748751, 0.440573], + [0.274149, 0.751988, 0.436601], + [0.281477, 0.755203, 0.432552], + [0.288921, 0.758394, 0.428426], + [0.296479, 0.761561, 0.424223], + [0.304148, 0.764704, 0.419943], + [0.311925, 0.767822, 0.415586], + [0.319809, 0.770914, 0.411152], + [0.327796, 0.773980, 0.406640], + [0.335885, 0.777018, 0.402049], + [0.344074, 0.780029, 0.397381], + [0.352360, 0.783011, 0.392636], + [0.360741, 0.785964, 0.387814], + [0.369214, 0.788888, 0.382914], + [0.377779, 0.791781, 0.377939], + [0.386433, 0.794644, 0.372886], + [0.395174, 0.797475, 0.367757], + [0.404001, 0.800275, 0.362552], + [0.412913, 0.803041, 0.357269], + [0.421908, 0.805774, 0.351910], + [0.430983, 0.808473, 0.346476], + [0.440137, 0.811138, 0.340967], + [0.449368, 0.813768, 0.335384], + [0.458674, 0.816363, 0.329727], + [0.468053, 0.818921, 0.323998], + [0.477504, 0.821444, 0.318195], + [0.487026, 0.823929, 0.312321], + [0.496615, 0.826376, 0.306377], + [0.506271, 0.828786, 0.300362], + [0.515992, 0.831158, 0.294279], + [0.525776, 0.833491, 0.288127], + [0.535621, 0.835785, 0.281908], + [0.545524, 0.838039, 0.275626], + [0.555484, 0.840254, 0.269281], + [0.565498, 0.842430, 0.262877], + [0.575563, 0.844566, 0.256415], + [0.585678, 0.846661, 0.249897], + [0.595839, 0.848717, 0.243329], + [0.606045, 0.850733, 0.236712], + [0.616293, 0.852709, 0.230052], + [0.626579, 0.854645, 0.223353], + [0.636902, 0.856542, 0.216620], + [0.647257, 0.858400, 0.209861], + [0.657642, 0.860219, 0.203082], + [0.668054, 0.861999, 0.196293], + [0.678489, 0.863742, 0.189503], + [0.688944, 0.865448, 0.182725], + [0.699415, 0.867117, 0.175971], + [0.709898, 0.868751, 0.169257], + [0.720391, 0.870350, 0.162603], + [0.730889, 0.871916, 0.156029], + [0.741388, 0.873449, 0.149561], + [0.751884, 0.874951, 0.143228], + [0.762373, 0.876424, 0.137064], + [0.772852, 0.877868, 0.131109], + [0.783315, 0.879285, 0.125405], + [0.793760, 0.880678, 0.120005], + [0.804182, 0.882046, 0.114965], + [0.814576, 0.883393, 0.110347], + [0.824940, 0.884720, 0.106217], + [0.835270, 0.886029, 0.102646], + [0.845561, 0.887322, 0.099702], + [0.855810, 0.888601, 0.097452], + [0.866013, 0.889868, 0.095953], + [0.876168, 0.891125, 0.095250], + [0.886271, 0.892374, 0.095374], + [0.896320, 0.893616, 0.096335], + [0.906311, 0.894855, 0.098125], + [0.916242, 0.896091, 0.100717], + [0.926106, 0.897330, 0.104071], + [0.935904, 0.898570, 0.108131], + [0.945636, 0.899815, 0.112838], + [0.955300, 0.901065, 0.118128], + [0.964894, 0.902323, 0.123941], + [0.974417, 0.903590, 0.130215], + [0.983868, 0.904867, 0.136897], + [0.993248, 0.906157, 0.143936]] + + +cmaps = {} +for (name, data) in (('magma', _magma_data), + ('inferno', _inferno_data), + ('plasma', _plasma_data), + ('viridis', _viridis_data)): + + cmaps[name] = ListedColormap(data, name=name) + +magma = cmaps['magma'] +inferno = cmaps['inferno'] +plasma = cmaps['plasma'] +viridis = cmaps['viridis'] diff --git a/silx/gui/plot/MaskToolsWidget.py b/silx/gui/plot/MaskToolsWidget.py new file mode 100644 index 0000000..6407d44 --- /dev/null +++ b/silx/gui/plot/MaskToolsWidget.py @@ -0,0 +1,615 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Widget providing a set of tools to draw masks on a PlotWidget. + +This widget is meant to work with :class:`silx.gui.plot.PlotWidget`. + +- :class:`ImageMask`: Handle mask bitmap update and history +- :class:`MaskToolsWidget`: GUI for :class:`Mask` +- :class:`MaskToolsDockWidget`: DockWidget to integrate in :class:`PlotWindow` +""" +from __future__ import division + + +__authors__ = ["T. Vincent", "P. Knobel"] +__license__ = "MIT" +__date__ = "20/04/2017" + + +import os +import sys +import numpy +import logging + +from silx.image import shapes + +from ._BaseMaskToolsWidget import BaseMask, BaseMaskToolsWidget, BaseMaskToolsDockWidget +from . import items +from .Colors import cursorColorForColormap, rgba +from .. import qt + +from silx.third_party.EdfFile import EdfFile +from silx.third_party.TiffIO import TiffIO + +try: + import fabio +except ImportError: + fabio = None + + +_logger = logging.getLogger(__name__) + + +class ImageMask(BaseMask): + """A 2D mask field with update operations. + + Coords follows (row, column) convention and are in mask array coords. + + This is meant for internal use by :class:`MaskToolsWidget`. + """ + def __init__(self, image=None): + """ + + :param image: :class:`silx.gui.plot.items.ImageBase` instance + """ + BaseMask.__init__(self, image) + + def getDataValues(self): + """Return image data as a 2D or 3D array (if it is a RGBA image). + + :rtype: 2D or 3D numpy.ndarray + """ + return self._dataItem.getData(copy=False) + + def save(self, filename, kind): + """Save current mask in a file + + :param str filename: The file where to save to mask + :param str kind: The kind of file to save in 'edf', 'tif', 'npy', + or 'msk' (if FabIO is installed) + :raise Exception: Raised if the file writing fail + """ + if kind == 'edf': + edfFile = EdfFile(filename, access="w+") + edfFile.WriteImage({}, self.getMask(copy=False), Append=0) + + elif kind == 'tif': + tiffFile = TiffIO(filename, mode='w') + tiffFile.writeImage(self.getMask(copy=False), software='silx') + + elif kind == 'npy': + try: + numpy.save(filename, self.getMask(copy=False)) + except IOError: + raise RuntimeError("Mask file can't be written") + + elif kind == 'msk': + if fabio is None: + raise ImportError("Fit2d mask files can't be written: Fabio module is not available") + try: + data = self.getMask(copy=False) + image = fabio.fabioimage.FabioImage(data=data) + image = image.convert(fabio.fit2dmaskimage.Fit2dMaskImage) + image.save(filename) + except Exception: + _logger.debug("Backtrace", exc_info=True) + raise RuntimeError("Mask file can't be written") + + else: + raise ValueError("Format '%s' is not supported" % kind) + + # Drawing operations + def updateRectangle(self, level, row, col, height, width, mask=True): + """Mask/Unmask a rectangle of the given mask level. + + :param int level: Mask level to update. + :param int row: Starting row of the rectangle + :param int col: Starting column of the rectangle + :param int height: + :param int width: + :param bool mask: True to mask (default), False to unmask. + """ + assert 0 < level < 256 + selection = self._mask[max(0, row):row + height + 1, + max(0, col):col + width + 1] + if mask: + selection[:, :] = level + else: + selection[selection == level] = 0 + self._notify() + + def updatePolygon(self, level, vertices, mask=True): + """Mask/Unmask a polygon of the given mask level. + + :param int level: Mask level to update. + :param vertices: Nx2 array of polygon corners as (row, col) + :param bool mask: True to mask (default), False to unmask. + """ + fill = shapes.polygon_fill_mask(vertices, self._mask.shape) + if mask: + self._mask[fill != 0] = level + else: + self._mask[numpy.logical_and(fill != 0, + self._mask == level)] = 0 + self._notify() + + def updatePoints(self, level, rows, cols, mask=True): + """Mask/Unmask points with given coordinates. + + :param int level: Mask level to update. + :param rows: Rows of selected points + :type rows: 1D numpy.ndarray + :param cols: Columns of selected points + :type cols: 1D numpy.ndarray + :param bool mask: True to mask (default), False to unmask. + """ + valid = numpy.logical_and( + numpy.logical_and(rows >= 0, cols >= 0), + numpy.logical_and(rows < self._mask.shape[0], + cols < self._mask.shape[1])) + rows, cols = rows[valid], cols[valid] + + if mask: + self._mask[rows, cols] = level + else: + inMask = self._mask[rows, cols] == level + self._mask[rows[inMask], cols[inMask]] = 0 + self._notify() + + def updateDisk(self, level, crow, ccol, radius, mask=True): + """Mask/Unmask a disk of the given mask level. + + :param int level: Mask level to update. + :param int crow: Disk center row. + :param int ccol: Disk center column. + :param float radius: Radius of the disk in mask array unit + :param bool mask: True to mask (default), False to unmask. + """ + rows, cols = shapes.circle_fill(crow, ccol, radius) + self.updatePoints(level, rows, cols, mask) + + def updateLine(self, level, row0, col0, row1, col1, width, mask=True): + """Mask/Unmask a line of the given mask level. + + :param int level: Mask level to update. + :param int row0: Row of the starting point. + :param int col0: Column of the starting point. + :param int row1: Row of the end point. + :param int col1: Column of the end point. + :param int width: Width of the line in mask array unit. + :param bool mask: True to mask (default), False to unmask. + """ + rows, cols = shapes.draw_line(row0, col0, row1, col1, width) + self.updatePoints(level, rows, cols, mask) + + +class MaskToolsWidget(BaseMaskToolsWidget): + """Widget with tools for drawing mask on an image in a PlotWidget.""" + + _maxLevelNumber = 255 + + def __init__(self, parent=None, plot=None): + self._origin = (0., 0.) # Mask origin in plot + self._scale = (1., 1.) # Mask scale in plot + self._z = 1 # Mask layer in plot + self._data = numpy.zeros((0, 0), dtype=numpy.uint8) # Store image + + self._mask = ImageMask() + + super(MaskToolsWidget, self).__init__(parent, plot) + + self._initWidgets() + + def setSelectionMask(self, mask, copy=True): + """Set the mask to a new array. + + :param numpy.ndarray mask: The array to use for the mask. + :type mask: numpy.ndarray of uint8 of dimension 2, C-contiguous. + Array of other types are converted. + :param bool copy: True (the default) to copy the array, + False to use it as is if possible. + :return: None if failed, shape of mask as 2-tuple if successful. + The mask can be cropped or padded to fit active image, + the returned shape is that of the active image. + """ + mask = numpy.array(mask, copy=False, dtype=numpy.uint8) + if len(mask.shape) != 2: + _logger.error('Not an image, shape: %d', len(mask.shape)) + return None + + if self._data.shape[0:2] == (0, 0) or mask.shape == self._data.shape[0:2]: + self._mask.setMask(mask, copy=copy) + self._mask.commit() + return mask.shape + else: + _logger.warning('Mask has not the same size as current image.' + ' Mask will be cropped or padded to fit image' + ' dimensions. %s != %s', + str(mask.shape), str(self._data.shape)) + resizedMask = numpy.zeros(self._data.shape[0:2], + dtype=numpy.uint8) + height = min(self._data.shape[0], mask.shape[0]) + width = min(self._data.shape[1], mask.shape[1]) + resizedMask[:height, :width] = mask[:height, :width] + self._mask.setMask(resizedMask, copy=False) + self._mask.commit() + return resizedMask.shape + + # Handle mask refresh on the plot + def _updatePlotMask(self): + """Update mask image in plot""" + mask = self.getSelectionMask(copy=False) + if len(mask): + self.plot.addImage(mask, legend=self._maskName, + colormap=self._colormap, + origin=self._origin, + scale=self._scale, + z=self._z, + replace=False, resetzoom=False) + elif self.plot.getImage(self._maskName): + self.plot.remove(self._maskName, kind='image') + + def showEvent(self, event): + try: + self.plot.sigActiveImageChanged.disconnect( + self._activeImageChangedAfterCare) + except (RuntimeError, TypeError): + pass + self._activeImageChanged() # Init mask + enable/disable widget + self.plot.sigActiveImageChanged.connect(self._activeImageChanged) + + def hideEvent(self, event): + self.plot.sigActiveImageChanged.disconnect(self._activeImageChanged) + if not self.browseAction.isChecked(): + self.browseAction.trigger() # Disable drawing tool + + if len(self.getSelectionMask(copy=False)): + self.plot.sigActiveImageChanged.connect( + self._activeImageChangedAfterCare) + + def _setOverlayColorForImage(self, image): + """Set the color of overlay adapted to image + + :param image: :class:`.items.ImageBase` object to set color for. + """ + if isinstance(image, items.ColormapMixIn): + colormap = image.getColormap() + self._defaultOverlayColor = rgba( + cursorColorForColormap(colormap['name'])) + else: + self._defaultOverlayColor = rgba('black') + + def _activeImageChangedAfterCare(self, *args): + """Check synchro of active image and mask when mask widget is hidden. + + If active image has no more the same size as the mask, the mask is + removed, otherwise it is adjusted to origin, scale and z. + """ + activeImage = self.plot.getActiveImage() + if activeImage is None or activeImage.getLegend() == self._maskName: + # No active image or active image is the mask... + self.plot.sigActiveImageChanged.disconnect( + self._activeImageChangedAfterCare) + else: + self._setOverlayColorForImage(activeImage) + self._setMaskColors(self.levelSpinBox.value(), + self.transparencySlider.value() / + self.transparencySlider.maximum()) + + self._origin = activeImage.getOrigin() + self._scale = activeImage.getScale() + self._z = activeImage.getZValue() + 1 + self._data = activeImage.getData(copy=False) + if self._data.shape[:2] != self.getSelectionMask(copy=False).shape: + # Image has not the same size, remove mask and stop listening + if self.plot.getImage(self._maskName): + self.plot.remove(self._maskName, kind='image') + + self.plot.sigActiveImageChanged.disconnect( + self._activeImageChangedAfterCare) + else: + # Refresh in case origin, scale, z changed + self._mask.setDataItem(activeImage) + self._updatePlotMask() + + def _activeImageChanged(self, *args): + """Update widget and mask according to active image changes""" + activeImage = self.plot.getActiveImage() + if activeImage is None or activeImage.getLegend() == self._maskName: + # No active image or active image is the mask... + self.setEnabled(False) + + self._data = numpy.zeros((0, 0), dtype=numpy.uint8) + self._mask.reset() + self._mask.commit() + + else: # There is an active image + self.setEnabled(True) + + self._setOverlayColorForImage(activeImage) + + self._setMaskColors(self.levelSpinBox.value(), + self.transparencySlider.value() / + self.transparencySlider.maximum()) + + self._origin = activeImage.getOrigin() + self._scale = activeImage.getScale() + self._z = activeImage.getZValue() + 1 + self._data = activeImage.getData(copy=False) + self._mask.setDataItem(activeImage) + if self._data.shape[:2] != self.getSelectionMask(copy=False).shape: + self._mask.reset(self._data.shape[:2]) + self._mask.commit() + else: + # Refresh in case origin, scale, z changed + self._updatePlotMask() + + # Threshold tools only available for data with colormap + self.thresholdGroup.setEnabled(self._data.ndim == 2) + + self._updateInteractiveMode() + + # Handle whole mask operations + def load(self, filename): + """Load a mask from an image file. + + :param str filename: File name from which to load the mask + :raise Exception: An exception in case of failure + :raise RuntimeWarning: In case the mask was applied but with some + import changes to notice + """ + _, extension = os.path.splitext(filename) + extension = extension.lower()[1:] + + if extension == "npy": + try: + mask = numpy.load(filename) + except IOError: + _logger.error("Can't load filename '%s'", filename) + _logger.debug("Backtrace", exc_info=True) + raise RuntimeError('File "%s" is not a numpy file.', filename) + elif extension == "edf": + try: + mask = EdfFile(filename, access='r').GetData(0) + except Exception as e: + _logger.error("Can't load filename %s", filename) + _logger.debug("Backtrace", exc_info=True) + raise e + elif extension == "msk": + if fabio is None: + raise ImportError("Fit2d mask files can't be read: Fabio module is not available") + try: + mask = fabio.open(filename).data + except Exception as e: + _logger.error("Can't load fit2d mask file") + _logger.debug("Backtrace", exc_info=True) + raise e + else: + msg = "Extension '%s' is not supported." + raise RuntimeError(msg % extension) + + effectiveMaskShape = self.setSelectionMask(mask, copy=False) + if effectiveMaskShape is None: + return + if mask.shape != effectiveMaskShape: + msg = 'Mask was resized from %s to %s' + msg = msg % (str(mask.shape), str(effectiveMaskShape)) + raise RuntimeWarning(msg) + + def _loadMask(self): + """Open load mask dialog""" + dialog = qt.QFileDialog(self) + dialog.setWindowTitle("Load Mask") + dialog.setModal(1) + filters = [ + 'EDF (*.edf)', + 'TIFF (*.tif)', + 'NumPy binary file (*.npy)', + # Fit2D mask is displayed anyway fabio is here or not + # to show to the user that the option exists + 'Fit2D mask (*.msk)', + ] + dialog.setNameFilters(filters) + dialog.setFileMode(qt.QFileDialog.ExistingFile) + dialog.setDirectory(self.maskFileDir) + if not dialog.exec_(): + dialog.close() + return + + filename = dialog.selectedFiles()[0] + dialog.close() + + self.maskFileDir = os.path.dirname(filename) + try: + self.load(filename) + except RuntimeWarning as e: + message = e.args[0] + msg = qt.QMessageBox(self) + msg.setIcon(qt.QMessageBox.Warning) + msg.setText("Mask loaded but an operation was applied.\n" + message) + msg.exec_() + except Exception as e: + message = e.args[0] + msg = qt.QMessageBox(self) + msg.setIcon(qt.QMessageBox.Critical) + msg.setText("Cannot load mask from file. " + message) + msg.exec_() + + def _saveMask(self): + """Open Save mask dialog""" + dialog = qt.QFileDialog(self) + dialog.setWindowTitle("Save Mask") + dialog.setModal(1) + filters = [ + 'EDF (*.edf)', + 'TIFF (*.tif)', + 'NumPy binary file (*.npy)', + # Fit2D mask is displayed anyway fabio is here or not + # to show to the user that the option exists + 'Fit2D mask (*.msk)', + ] + dialog.setNameFilters(filters) + dialog.setFileMode(qt.QFileDialog.AnyFile) + dialog.setAcceptMode(qt.QFileDialog.AcceptSave) + dialog.setDirectory(self.maskFileDir) + if not dialog.exec_(): + dialog.close() + return + + # convert filter name to extension name with the . + extension = dialog.selectedNameFilter().split()[-1][2:-1] + filename = dialog.selectedFiles()[0] + dialog.close() + + if not filename.lower().endswith(extension): + filename += extension + + if os.path.exists(filename): + try: + os.remove(filename) + except IOError: + msg = qt.QMessageBox(self) + msg.setIcon(qt.QMessageBox.Critical) + msg.setText("Cannot save.\n" + "Input Output Error: %s" % (sys.exc_info()[1])) + msg.exec_() + return + + self.maskFileDir = os.path.dirname(filename) + try: + self.save(filename, extension[1:]) + except Exception as e: + msg = qt.QMessageBox(self) + msg.setIcon(qt.QMessageBox.Critical) + msg.setText("Cannot save file %s\n%s" % (filename, e.args[0])) + msg.exec_() + + def resetSelectionMask(self): + """Reset the mask""" + self._mask.reset(shape=self._data.shape[:2]) + self._mask.commit() + + def _plotDrawEvent(self, event): + """Handle draw events from the plot""" + if (self._drawingMode is None or + event['event'] not in ('drawingProgress', 'drawingFinished')): + return + + if not len(self._data): + return + + level = self.levelSpinBox.value() + + if (self._drawingMode == 'rectangle' and + event['event'] == 'drawingFinished'): + # Convert from plot to array coords + doMask = self._isMasking() + ox, oy = self._origin + sx, sy = self._scale + + height = int(abs(event['height'] / sy)) + width = int(abs(event['width'] / sx)) + + row = int((event['y'] - oy) / sy) + if sy < 0: + row -= height + + col = int((event['x'] - ox) / sx) + if sx < 0: + col -= width + + self._mask.updateRectangle( + level, + row=row, + col=col, + height=height, + width=width, + mask=doMask) + self._mask.commit() + + elif (self._drawingMode == 'polygon' and + event['event'] == 'drawingFinished'): + doMask = self._isMasking() + # Convert from plot to array coords + vertices = (event['points'] - self._origin) / self._scale + vertices = vertices.astype(numpy.int)[:, (1, 0)] # (row, col) + self._mask.updatePolygon(level, vertices, doMask) + self._mask.commit() + + elif self._drawingMode == 'pencil': + doMask = self._isMasking() + # convert from plot to array coords + col, row = (event['points'][-1] - self._origin) / self._scale + col, row = int(col), int(row) + brushSize = self.pencilSpinBox.value() + + if self._lastPencilPos != (row, col): + if self._lastPencilPos is not None: + # Draw the line + self._mask.updateLine( + level, + self._lastPencilPos[0], self._lastPencilPos[1], + row, col, + brushSize, + doMask) + + # Draw the very first, or last point + self._mask.updateDisk(level, row, col, brushSize / 2., doMask) + + if event['event'] == 'drawingFinished': + self._mask.commit() + self._lastPencilPos = None + else: + self._lastPencilPos = row, col + + def _loadRangeFromColormapTriggered(self): + """Set range from active image colormap range""" + activeImage = self.plot.getActiveImage() + if (isinstance(activeImage, items.ColormapMixIn) and + activeImage.getLegend() != self._maskName): + # Update thresholds according to colormap + colormap = activeImage.getColormap() + if colormap['autoscale']: + min_ = numpy.nanmin(activeImage.getData(copy=False)) + max_ = numpy.nanmax(activeImage.getData(copy=False)) + else: + min_, max_ = colormap['vmin'], colormap['vmax'] + self.minLineEdit.setText(str(min_)) + self.maxLineEdit.setText(str(max_)) + + +class MaskToolsDockWidget(BaseMaskToolsDockWidget): + """:class:`MaskToolsWidget` embedded in a QDockWidget. + + For integration in a :class:`PlotWindow`. + + :param parent: See :class:`QDockWidget` + :param plot: The PlotWidget this widget is operating on + :paran str name: The title of this widget + """ + def __init__(self, parent=None, plot=None, name='Mask'): + super(MaskToolsDockWidget, self).__init__(parent, name) + self.setWidget(MaskToolsWidget(plot=plot)) + self.widget().sigMaskChanged.connect(self._emitSigMaskChanged) diff --git a/silx/gui/plot/Plot.py b/silx/gui/plot/Plot.py new file mode 100644 index 0000000..fe0a7b8 --- /dev/null +++ b/silx/gui/plot/Plot.py @@ -0,0 +1,2925 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# ###########################################################################*/ +"""Plot API for 1D and 2D data. + +The :class:`Plot` implements the plot API initially provided in PyMca. + + +Colormap +-------- + +The :class:`Plot` uses a dictionary to describe a colormap. +This dictionary has the following keys: + +- 'name': str, name of the colormap. Available colormap are returned by + :meth:`Plot.getSupportedColormaps`. + At least 'gray', 'reversed gray', 'temperature', + 'red', 'green', 'blue' are supported. +- 'normalization': Either 'linear' or 'log' +- 'autoscale': bool, True to get bounds from the min and max of the + data, False to use [vmin, vmax] +- 'vmin': float, min value, ignored if autoscale is True +- 'vmax': float, max value, ignored if autoscale is True +- 'colors': optional, custom colormap. + Nx3 or Nx4 numpy array of RGB(A) colors, + either uint8 or float in [0, 1]. + If 'name' is None, then this array is used as the colormap. + + +Plot Events +----------- + +The Plot sends some event to the registered callback +(See :meth:`Plot.setCallback`). +Those events are sent as a dictionary with a key 'event' describing the kind +of event. + +Drawing events +.............. + +'drawingProgress' and 'drawingFinished' events are sent during drawing +interaction (See :meth:`Plot.setInteractiveMode`). + +- 'event': 'drawingProgress' or 'drawingFinished' +- 'parameters': dict of parameters used by the drawing mode. + It has the following keys: 'shape', 'label', 'color'. + See :meth:`Plot.setInteractiveMode`. +- 'points': Points (x, y) in data coordinates of the drawn shape. + For 'hline' and 'vline', it is the 2 points defining the line. + For 'line' and 'rectangle', it is the coordinates of the start + drawing point and the latest drawing point. + For 'polygon', it is the coordinates of all points of the shape. +- 'type': The type of drawing in 'line', 'hline', 'polygon', 'rectangle', + 'vline'. +- 'xdata' and 'ydata': X coords and Y coords of shape points in data + coordinates (as in 'points'). + +When the type is 'rectangle', the following additional keys are provided: + +- 'x' and 'y': The origin of the rectangle in data coordinates +- 'widht' and 'height': The size of the rectangle in data coordinates + + +Mouse events +............ + +'mouseMoved', 'mouseClicked' and 'mouseDoubleClicked' events are sent for +mouse events. + +They provide the following keys: + +- 'event': 'mouseMoved', 'mouseClicked' or 'mouseDoubleClicked' +- 'button': the mouse button that was pressed in 'left', 'middle', 'right' +- 'x' and 'y': The mouse position in data coordinates +- 'xpixel' and 'ypixel': The mouse position in pixels + + +Marker events +............. + +'hover', 'markerClicked', 'markerMoving' and 'markerMoved' events are +sent during interaction with markers. + +'hover' is sent when the mouse cursor is over a marker. +'markerClicker' is sent when the user click on a selectable marker. +'markerMoving' and 'markerMoved' are sent when a draggable marker is moved. + +They provide the following keys: + +- 'event': 'hover', 'markerClicked', 'markerMoving' or 'markerMoved' +- 'button': the mouse button that is pressed in 'left', 'middle', 'right' +- 'draggable': True if the marker is draggable, False otherwise +- 'label': The legend associated with the clicked image or curve +- 'selectable': True if the marker is selectable, False otherwise +- 'type': 'marker' +- 'x' and 'y': The mouse position in data coordinates +- 'xdata' and 'ydata': The marker position in data coordinates + +'markerClicked' and 'markerMoving' events have a 'xpixel' and a 'ypixel' +additional keys, that provide the mouse position in pixels. + + +Image and curve events +...................... + +'curveClicked' and 'imageClicked' events are sent when a selectable curve +or image is clicked. + +Both share the following keys: + +- 'event': 'curveClicked' or 'imageClicked' +- 'button': the mouse button that was pressed in 'left', 'middle', 'right' +- 'label': The legend associated with the clicked image or curve +- 'type': The type of item in 'curve', 'image' +- 'x' and 'y': The clicked position in data coordinates +- 'xpixel' and 'ypixel': The clicked position in pixels + +'curveClicked' events have a 'xdata' and a 'ydata' additional keys, that +provide the coordinates of the picked points of the curve. +There can be more than one point of the curve being picked, and if a line of +the curve is picked, only the first point of the line is included in the list. + +'imageClicked' have a 'col' and a 'row' additional keys, that provide +the column and row index in the image array that was clicked. + + +Limits changed events +..................... + +'limitsChanged' events are sent when the limits of the plot are changed. +This can results from user interaction or API calls. + +It provides the following keys: + +- 'event': 'limitsChanged' +- 'source': id of the widget that emitted this event. +- 'xdata': Range of X in graph coordinates: (xMin, xMax). +- 'ydata': Range of Y in graph coordinates: (yMin, yMax). +- 'y2data': Range of right axis in graph coordinates (y2Min, y2Max) or None. + +Plot state change events +........................ + +The following events are emitted when the plot is modified. +They provide the new state: + +- 'setGraphCursor' event with a 'state' key (bool) +- 'setGraphGrid' event with a 'which' key (str), see :meth:`setGraphGrid` +- 'setKeepDataAspectRatio' event with a 'state' key (bool) +- 'setXAxisAutoScale' event with a 'state' key (bool) +- 'setXAxisLogarithmic' event with a 'state' key (bool) +- 'setYAxisAutoScale' event with a 'state' key (bool) +- 'setYAxisInverted' event with a 'state' key (bool) +- 'setYAxisLogarithmic' event with a 'state' key (bool) + +A 'contentChanged' event is triggered when the content of the plot is updated. +It provides the following keys: + +- 'action': The change of the plot: 'add' or 'remove' +- 'kind': The kind of primitive changed: 'curve', 'image', 'item' or 'marker' +- 'legend': The legend of the primitive changed. + +'activeCurveChanged' and 'activeImageChanged' events with the following keys: + +- 'legend': Name (str) of the current active item or None if no active item. +- 'previous': Name (str) of the previous active item or None if no item was + active. It is the same as 'legend' if 'updated' == True +- 'updated': (bool) True if active item name did not changed, + but active item data or style was updated. + +'interactiveModeChanged' event with a 'source' key identifying the object +setting the interactive mode. +""" + +from __future__ import division + + +__authors__ = ["V.A. Sole", "T. Vincent"] +__license__ = "MIT" +__date__ = "16/02/2017" + + +from collections import OrderedDict, namedtuple +import itertools +import logging + +import numpy + +# Import matplotlib backend here to init matplotlib our way +from .backends.BackendMatplotlib import BackendMatplotlibQt + +try: + from matplotlib import cm as matplotlib_cm +except ImportError: + matplotlib_cm = None + +from . import Colors +from . import PlotInteraction +from . import PlotEvents +from . import _utils + +from . import items + + +_logger = logging.getLogger(__name__) + + +_COLORDICT = Colors.COLORDICT +_COLORLIST = [_COLORDICT['black'], + _COLORDICT['blue'], + _COLORDICT['red'], + _COLORDICT['green'], + _COLORDICT['pink'], + _COLORDICT['yellow'], + _COLORDICT['brown'], + _COLORDICT['cyan'], + _COLORDICT['magenta'], + _COLORDICT['orange'], + _COLORDICT['violet'], + # _COLORDICT['bluegreen'], + _COLORDICT['grey'], + _COLORDICT['darkBlue'], + _COLORDICT['darkRed'], + _COLORDICT['darkGreen'], + _COLORDICT['darkCyan'], + _COLORDICT['darkMagenta'], + _COLORDICT['darkYellow'], + _COLORDICT['darkBrown']] + + +""" +Object returned when requesting the data range. +""" +_PlotDataRange = namedtuple('PlotDataRange', + ['x', 'y', 'yright']) + + +class Plot(object): + """This class implements the plot API initially provided in PyMca. + + Supported backends: + + - 'matplotlib' and 'mpl': Matplotlib with Qt. + - 'opengl' and 'gl': OpenGL backend (requires PyOpenGL and OpenGL >= 2.1) + - 'none': No backend, to run headless for testing purpose. + + :param parent: The parent widget of the plot (Default: None) + :param backend: The backend to use. A str in: + 'matplotlib', 'mpl', 'opengl', 'gl', 'none' + or a :class:`BackendBase.BackendBase` class + """ + + DEFAULT_BACKEND = 'matplotlib' + """Class attribute setting the default backend for all instances.""" + + colorList = _COLORLIST + colorDict = _COLORDICT + + def __init__(self, parent=None, backend=None): + self._autoreplot = False + self._dirty = False + self._cursorInPlot = False + + if backend is None: + backend = self.DEFAULT_BACKEND + + if hasattr(backend, "__call__"): + self._backend = backend(self, parent) + + elif hasattr(backend, "lower"): + lowerCaseString = backend.lower() + if lowerCaseString in ("matplotlib", "mpl"): + backendClass = BackendMatplotlibQt + elif lowerCaseString in ('gl', 'opengl'): + from .backends.BackendOpenGL import BackendOpenGL + backendClass = BackendOpenGL + elif lowerCaseString == 'none': + from .backends.BackendBase import BackendBase as backendClass + else: + raise ValueError("Backend not supported %s" % backend) + self._backend = backendClass(self, parent) + + else: + raise ValueError("Backend not supported %s" % str(backend)) + + super(Plot, self).__init__() + + self.setCallback() # set _callback + + # Items handling + self._content = OrderedDict() + self._contentToUpdate = set() + + self._dataRange = None + + # line types + self._styleList = ['-', '--', '-.', ':'] + self._colorIndex = 0 + self._styleIndex = 0 + + self._activeCurveHandling = True + self._activeCurveColor = "#000000" + self._activeLegend = {'curve': None, 'image': None, + 'scatter': None} + + # default properties + self._cursorConfiguration = None + + self._logY = False + self._logX = False + self._xAutoScale = True + self._yAutoScale = True + self._grid = None + + # Store default labels provided to setGraph[X|Y]Label + self._defaultLabels = {'x': '', 'y': '', 'yright': ''} + # Store currently displayed labels + # Current label can differ from input one with active curve handling + self._currentLabels = {'x': '', 'y': '', 'yright': ''} + + self._graphTitle = '' + + self.setGraphTitle() + self.setGraphXLabel() + self.setGraphYLabel() + self.setGraphYLabel('', axis='right') + + self.setDefaultColormap() # Init default colormap + + self.setDefaultPlotPoints(False) + self.setDefaultPlotLines(True) + + self._eventHandler = PlotInteraction.PlotInteraction(self) + self._eventHandler.setInteractiveMode('zoom', color=(0., 0., 0., 1.)) + + self._pressedButtons = [] # Currently pressed mouse buttons + + self._defaultDataMargins = (0., 0., 0., 0.) + + # Only activate autoreplot at the end + # This avoids errors when loaded in Qt designer + self._dirty = False + self._autoreplot = True + + def _getDirtyPlot(self): + """Return the plot dirty flag. + + If False, the plot has not changed since last replot. + If True, the full plot need to be redrawn. + If 'overlay', only the overlay has changed since last replot. + + It can be accessed by backend to check the dirty state. + + :return: False, True, 'overlay' + """ + return self._dirty + + def _setDirtyPlot(self, overlayOnly=False): + """Mark the plot as needing redraw + + :param bool overlayOnly: True to redraw only the overlay, + False to redraw everything + """ + wasDirty = self._dirty + + if not self._dirty and overlayOnly: + self._dirty = 'overlay' + else: + self._dirty = True + + if self._autoreplot and not wasDirty: + self._backend.postRedisplay() + + def _invalidateDataRange(self): + """ + Notifies this Plot instance that the range has changed and will have + to be recomputed. + """ + self._dataRange = None + + def _updateDataRange(self): + """ + Recomputes the range of the data displayed on this Plot. + """ + xMin = yMinLeft = yMinRight = float('nan') + xMax = yMaxLeft = yMaxRight = float('nan') + + for item in self._content.values(): + if item.isVisible(): + bounds = item.getBounds() + if bounds is not None: + xMin = numpy.nanmin([xMin, bounds[0]]) + xMax = numpy.nanmax([xMax, bounds[1]]) + # Take care of right axis + if (isinstance(item, items.YAxisMixIn) and + item.getYAxis() == 'right'): + yMinRight = numpy.nanmin([yMinRight, bounds[2]]) + yMaxRight = numpy.nanmax([yMaxRight, bounds[3]]) + else: + yMinLeft = numpy.nanmin([yMinLeft, bounds[2]]) + yMaxLeft = numpy.nanmax([yMaxLeft, bounds[3]]) + + def lGetRange(x, y): + return None if numpy.isnan(x) and numpy.isnan(y) else (x, y) + xRange = lGetRange(xMin, xMax) + yLeftRange = lGetRange(yMinLeft, yMaxLeft) + yRightRange = lGetRange(yMinRight, yMaxRight) + + self._dataRange = _PlotDataRange(x=xRange, + y=yLeftRange, + yright=yRightRange) + + def getDataRange(self): + """ + Returns this Plot's data range. + + :return: a namedtuple with the following members : + x, y (left y axis), yright. Each member is a tuple (min, max) + or None if no data is associated with the axis. + :rtype: namedtuple + """ + if self._dataRange is None: + self._updateDataRange() + return self._dataRange + + # Content management + + @staticmethod + def _itemKey(item): + """Build the key of given :class:`Item` in the plot + + :param Item item: The item to make the key from + :return: (legend, kind) + :rtype: (str, str) + """ + if isinstance(item, items.Curve): + kind = 'curve' + elif isinstance(item, items.ImageBase): + kind = 'image' + elif isinstance(item, items.Scatter): + kind = 'scatter' + elif isinstance(item, (items.Marker, + items.XMarker, items.YMarker)): + kind = 'marker' + elif isinstance(item, items.Shape): + kind = 'item' + elif isinstance(item, items.Histogram): + kind = 'histogram' + else: + raise ValueError('Unsupported item type %s' % type(item)) + + return item.getLegend(), kind + + def _add(self, item): + """Add the given :class:`Item` to the plot. + + :param Item item: The item to append to the plot content + """ + key = self._itemKey(item) + if key in self._content: + raise RuntimeError('Item already in the plot') + + # Add item to plot + self._content[key] = item + item._setPlot(self) + if item.isVisible(): + self._itemRequiresUpdate(item) + if isinstance(item, (items.Curve, items.ImageBase)): + self._invalidateDataRange() # TODO handle this automatically + + def _remove(self, item): + """Remove the given :class:`Item` from the plot. + + :param Item item: The item to remove from the plot content + """ + key = self._itemKey(item) + if key not in self._content: + raise RuntimeError('Item not in the plot') + + # Remove item from plot + self._content.pop(key) + self._contentToUpdate.discard(item) + if item.isVisible(): + self._setDirtyPlot(overlayOnly=item.isOverlay()) + if item.getBounds() is not None: + self._invalidateDataRange() + item._removeBackendRenderer(self._backend) + item._setPlot(None) + + def _itemRequiresUpdate(self, item): + """Called by items in the plot for asynchronous update + + :param Item item: The item that required update + """ + assert item.getPlot() == self + self._contentToUpdate.add(item) + self._setDirtyPlot(overlayOnly=item.isOverlay()) + + # Add + + # add * input arguments management: + # If an arg is set, then use it. + # Else: + # If a curve with the same legend exists, then use its arg value + # Else, use a default value. + # Store used value. + # This value is used when curve is updated either internally or by user. + + def addCurve(self, x, y, legend=None, info=None, + replace=False, replot=None, + color=None, symbol=None, + linewidth=None, linestyle=None, + xlabel=None, ylabel=None, yaxis=None, + xerror=None, yerror=None, z=None, selectable=None, + fill=None, resetzoom=True, + histogram=None, copy=True, **kw): + """Add a 1D curve given by x an y to the graph. + + Curves are uniquely identified by their legend. + To add multiple curves, call :meth:`addCurve` multiple times with + different legend argument. + To replace an existing curve, call :meth:`addCurve` with the + existing curve legend. + If you want to display the curve values as an histogram see the + histogram parameter or :meth:`addHistogram`. + + When curve parameters are not provided, if a curve with the + same legend is displayed in the plot, its parameters are used. + + :param numpy.ndarray x: The data corresponding to the x coordinates. + If you attempt to plot an histogram you can set edges values in x. + In this case len(x) = len(y) + 1 + :param numpy.ndarray y: The data corresponding to the y coordinates + :param str legend: The legend to be associated to the curve (or None) + :param info: User-defined information associated to the curve + :param bool replace: True (the default) to delete already existing + curves + :param color: color(s) to be used + :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or + one of the predefined color names defined in Colors.py + :param str symbol: Symbol to be drawn at each (x, y) position:: + + - 'o' circle + - '.' point + - ',' pixel + - '+' cross + - 'x' x-cross + - 'd' diamond + - 's' square + - None (the default) to use default symbol + + :param float linewidth: The width of the curve in pixels (Default: 1). + :param str linestyle: Type of line:: + + - ' ' no line + - '-' solid line + - '--' dashed line + - '-.' dash-dot line + - ':' dotted line + - None (the default) to use default line style + + :param str xlabel: Label to show on the X axis when the curve is active + or None to keep default axis label. + :param str ylabel: Label to show on the Y axis when the curve is active + or None to keep default axis label. + :param str yaxis: The Y axis this curve is attached to. + Either 'left' (the default) or 'right' + :param xerror: Values with the uncertainties on the x values + :type xerror: A float, or a numpy.ndarray of float32. + If it is an array, it can either be a 1D array of + same length as the data or a 2D array with 2 rows + of same length as the data: row 0 for positive errors, + row 1 for negative errors. + :param yerror: Values with the uncertainties on the y values + :type yerror: A float, or a numpy.ndarray of float32. See xerror. + :param int z: Layer on which to draw the curve (default: 1) + This allows to control the overlay. + :param bool selectable: Indicate if the curve can be selected. + (Default: True) + :param bool fill: True to fill the curve, False otherwise (default). + :param bool resetzoom: True (the default) to reset the zoom. + :param str histogram: if not None then the curve will be draw as an + histogram. The step for each values of the curve can be set to the + left, center or right of the original x curve values. + If histogram is not None and len(x) == len(y)+1 then x is directly + take as edges of the histogram. + Type of histogram:: + + - None (default) + - 'left' + - 'right' + - 'center' + :param bool copy: True make a copy of the data (default), + False to use provided arrays. + :returns: The key string identify this curve + """ + # Deprecation warnings + if replot is not None: + _logger.warning( + 'addCurve deprecated replot argument, use resetzoom instead') + resetzoom = replot and resetzoom + + if kw: + _logger.warning('addCurve: deprecated extra arguments') + + # This is an histogram, use addHistogram + if histogram is not None: + histoLegend = self.addHistogram(histogram=y, + edges=x, + legend=legend, + color=color, + fill=fill, + align=histogram, + copy=copy) + histo = self.getHistogram(histoLegend) + + histo.setInfo(info) + if linewidth is not None: + histo.setLineWidth(linewidth) + if linestyle is not None: + histo.setLineStyle(linestyle) + if xlabel is not None: + _logger.warning( + 'addCurve: Histogram does not support xlabel argument') + if ylabel is not None: + _logger.warning( + 'addCurve: Histogram does not support ylabel argument') + if yaxis is not None: + histo.setYAxis(yaxis) + if z is not None: + histo.setZValue(z) + if selectable is not None: + _logger.warning( + 'addCurve: Histogram does not support selectable argument') + + return + + legend = 'Unnamed curve 1.1' if legend is None else str(legend) + + # Check if curve was previously active + wasActive = self.getActiveCurve(just_legend=True) == legend + + # Create/Update curve object + curve = self.getCurve(legend) + if curve is None: + # No previous curve, create a default one and add it to the plot + curve = items.Curve() if histogram is None else items.Histogram() + curve._setLegend(legend) + # Set default color, linestyle and symbol + default_color, default_linestyle = self._getColorAndStyle() + curve.setColor(default_color) + curve.setLineStyle(default_linestyle) + curve.setSymbol(self._defaultPlotPoints) + self._add(curve) + + # Override previous/default values with provided ones + curve.setInfo(info) + if color is not None: + curve.setColor(color) + if symbol is not None: + curve.setSymbol(symbol) + if linewidth is not None: + curve.setLineWidth(linewidth) + if linestyle is not None: + curve.setLineStyle(linestyle) + if xlabel is not None: + curve._setXLabel(xlabel) + if ylabel is not None: + curve._setYLabel(ylabel) + if yaxis is not None: + curve.setYAxis(yaxis) + if z is not None: + curve.setZValue(z) + if selectable is not None: + curve._setSelectable(selectable) + if fill is not None: + curve.setFill(fill) + + # Set curve data + # If errors not provided, reuse previous ones + # TODO: Issue if size of data change but not that of errors + if xerror is None: + xerror = curve.getXErrorData(copy=False) + if yerror is None: + yerror = curve.getYErrorData(copy=False) + + curve.setData(x, y, xerror, yerror, copy=copy) + + if replace: # Then remove all other curves + for c in self.getAllCurves(withhidden=True): + if c is not curve: + self._remove(c) + + self.notify( + 'contentChanged', action='add', kind='curve', legend=legend) + + if wasActive: + self.setActiveCurve(curve.getLegend()) + + if resetzoom: + # We ask for a zoom reset in order to handle the plot scaling + # if the user does not want that, autoscale of the different + # axes has to be set to off. + self.resetZoom() + + return legend + + def addHistogram(self, + histogram, + edges, + legend=None, + color=None, + fill=None, + align='center', + resetzoom=True, + copy=True): + """Add an histogram to the graph. + + This is NOT computing the histogram, this method takes as parameter + already computed histogram values. + + Histogram are uniquely identified by their legend. + To add multiple histograms, call :meth:`addHistogram` multiple times + with different legend argument. + + When histogram parameters are not provided, if an histogram with the + same legend is displayed in the plot, its parameters are used. + + :param numpy.ndarray histogram: The values of the histogram. + :param numpy.ndarray edges: + The bin edges of the histogram. + If histogram and edges have the same length, the bin edges + are computed according to the align parameter. + :param str legend: + The legend to be associated to the histogram (or None) + :param color: color to be used + :type color: str ("#RRGGBB") or RGB unsigned byte array or + one of the predefined color names defined in Colors.py + :param bool fill: True to fill the curve, False otherwise (default). + :param str align: + In case histogram values and edges have the same length N, + the N+1 bin edges are computed according to the alignment in: + 'center' (default), 'left', 'right'. + :param bool resetzoom: True (the default) to reset the zoom. + :param bool copy: True make a copy of the data (default), + False to use provided arrays. + :returns: The key string identify this histogram + """ + legend = 'Unnamed histogram' if legend is None else str(legend) + + # Create/Update histogram object + histo = self.getHistogram(legend) + if histo is None: + # No previous histogram, create a default one and + # add it to the plot + histo = items.Histogram() + histo._setLegend(legend) + histo.setColor(self._getColorAndStyle()[0]) + self._add(histo) + + # Override previous/default values with provided ones + if color is not None: + histo.setColor(color) + if fill is not None: + histo.setFill(fill) + + # Set histogram data + histo.setData(histogram, edges, align=align, copy=copy) + + self.notify( + 'contentChanged', action='add', kind='histogram', legend=legend) + + if resetzoom: + # We ask for a zoom reset in order to handle the plot scaling + # if the user does not want that, autoscale of the different + # axes has to be set to off. + self.resetZoom() + + return legend + + def addImage(self, data, legend=None, info=None, + replace=True, replot=None, + xScale=None, yScale=None, z=None, + selectable=None, draggable=None, + colormap=None, pixmap=None, + xlabel=None, ylabel=None, + origin=None, scale=None, + resetzoom=True, copy=True, **kw): + """Add a 2D dataset or an image to the plot. + + It displays either an array of data using a colormap or a RGB(A) image. + + Images are uniquely identified by their legend. + To add multiple images, call :meth:`addImage` multiple times with + different legend argument. + To replace/update an existing image, call :meth:`addImage` with the + existing image legend. + + When image parameters are not provided, if an image with the + same legend is displayed in the plot, its parameters are used. + + :param numpy.ndarray data: (nrows, ncolumns) data or + (nrows, ncolumns, RGBA) ubyte array + :param str legend: The legend to be associated to the image (or None) + :param info: User-defined information associated to the image + :param bool replace: True (default) to delete already existing images + :param int z: Layer on which to draw the image (default: 0) + This allows to control the overlay. + :param bool selectable: Indicate if the image can be selected. + (default: False) + :param bool draggable: Indicate if the image can be moved. + (default: False) + :param dict colormap: Description of the colormap to use (or None) + This is ignored if data is a RGB(A) image. + See :mod:`Plot` for the documentation + of the colormap dict. + :param pixmap: Pixmap representation of the data (if any) + :type pixmap: (nrows, ncolumns, RGBA) ubyte array or None (default) + :param str xlabel: X axis label to show when this curve is active, + or None to keep default axis label. + :param str ylabel: Y axis label to show when this curve is active, + or None to keep default axis label. + :param origin: (origin X, origin Y) of the data. + It is possible to pass a single float if both + coordinates are equal. + Default: (0., 0.) + :type origin: float or 2-tuple of float + :param scale: (scale X, scale Y) of the data. + It is possible to pass a single float if both + coordinates are equal. + Default: (1., 1.) + :type scale: float or 2-tuple of float + :param bool resetzoom: True (the default) to reset the zoom. + :param bool copy: True make a copy of the data (default), + False to use provided arrays. + :returns: The key string identify this image + """ + # Deprecation warnings + if xScale is not None or yScale is not None: + _logger.warning( + 'addImage deprecated xScale and yScale arguments,' + 'use origin, scale arguments instead.') + if origin is None and scale is None: + origin = xScale[0], yScale[0] + scale = xScale[1], yScale[1] + else: + _logger.warning( + 'addCurve: xScale, yScale and origin, scale arguments' + ' are conflicting. xScale and yScale are ignored.' + ' Use only origin, scale arguments.') + + if replot is not None: + _logger.warning( + 'addImage deprecated replot argument, use resetzoom instead') + resetzoom = replot and resetzoom + + if kw: + _logger.warning('addImage: deprecated extra arguments') + + legend = "Unnamed Image 1.1" if legend is None else str(legend) + + # Check if image was previously active + wasActive = self.getActiveImage(just_legend=True) == legend + + data = numpy.array(data, copy=False) + assert data.ndim in (2, 3) + + image = self.getImage(legend) + if image is not None and image.getData(copy=False).ndim != data.ndim: + # Update a data image with RGBA image or the other way around: + # Remove previous image + # In this case, we don't retrieve defaults from the previous image + self._remove(image) + image = None + + if image is None: + # No previous image, create a default one and add it to the plot + if data.ndim == 2: + image = items.ImageData() + image.setColormap(self.getDefaultColormap()) + else: + image = items.ImageRgba() + image._setLegend(legend) + self._add(image) + + # Override previous/default values with provided ones + image.setInfo(info) + if origin is not None: + image.setOrigin(origin) + if scale is not None: + image.setScale(scale) + if z is not None: + image.setZValue(z) + if selectable is not None: + image._setSelectable(selectable) + if draggable is not None: + image._setDraggable(draggable) + if colormap is not None and isinstance(image, items.ColormapMixIn): + image.setColormap(colormap) + if xlabel is not None: + image._setXLabel(xlabel) + if ylabel is not None: + image._setYLabel(ylabel) + + if data.ndim == 2: + image.setData(data, alternative=pixmap, copy=copy) + else: # RGB(A) image + if pixmap is not None: + _logger.warning( + 'addImage: pixmap argument ignored when data is RGB(A)') + image.setData(data, copy=copy) + + if replace: + for img in self.getAllImages(): + if img is not image: + self._remove(img) + + if len(self.getAllImages()) == 1 or wasActive: + self.setActiveImage(legend) + + self.notify( + 'contentChanged', action='add', kind='image', legend=legend) + + if resetzoom: + # We ask for a zoom reset in order to handle the plot scaling + # if the user does not want that, autoscale of the different + # axes has to be set to off. + self.resetZoom() + + return legend + + def addScatter(self, x, y, value, legend=None, colormap=None, + info=None, symbol=None, xerror=None, yerror=None, + z=None, copy=True): + """Add a (x, y, value) scatter to the graph. + + Scatters are uniquely identified by their legend. + To add multiple scatters, call :meth:`addScatter` multiple times with + different legend argument. + To replace/update an existing scatter, call :meth:`addScatter` with the + existing scatter legend. + + When scatter parameters are not provided, if a scatter with the + same legend is displayed in the plot, its parameters are used. + + :param numpy.ndarray x: The data corresponding to the x coordinates. + :param numpy.ndarray y: The data corresponding to the y coordinates + :param numpy.ndarray value: The data value associated with each point + :param str legend: The legend to be associated to the scatter (or None) + :param dict colormap: The colormap to be used for the scatter (or None) + See :mod:`Plot` for the documentation + of the colormap dict. + :param info: User-defined information associated to the curve + :param str symbol: Symbol to be drawn at each (x, y) position:: + + - 'o' circle + - '.' point + - ',' pixel + - '+' cross + - 'x' x-cross + - 'd' diamond + - 's' square + - None (the default) to use default symbol + + :param xerror: Values with the uncertainties on the x values + :type xerror: A float, or a numpy.ndarray of float32. + If it is an array, it can either be a 1D array of + same length as the data or a 2D array with 2 rows + of same length as the data: row 0 for positive errors, + row 1 for negative errors. + :param yerror: Values with the uncertainties on the y values + :type yerror: A float, or a numpy.ndarray of float32. See xerror. + :param int z: Layer on which to draw the scatter (default: 1) + This allows to control the overlay. + + :param bool copy: True make a copy of the data (default), + False to use provided arrays. + :returns: The key string identify this scatter + """ + legend = 'Unnamed scatter 1.1' if legend is None else str(legend) + + # Check if scatter was previously active + wasActive = self._getActiveItem(kind='scatter', + just_legend=True) == legend + + # Create/Update curve object + scatter = self._getItem(kind='scatter', legend=legend) + if scatter is None: + # No previous scatter, create a default one and add it to the plot + scatter = items.Scatter() + scatter._setLegend(legend) + scatter.setColormap(self.getDefaultColormap()) + self._add(scatter) + + # Override previous/default values with provided ones + scatter.setInfo(info) + if symbol is not None: + scatter.setSymbol(symbol) + if z is not None: + scatter.setZValue(z) + if colormap is not None: + scatter.setColormap(colormap) + + # Set scatter data + # If errors not provided, reuse previous ones + if xerror is None: + xerror = scatter.getXErrorData(copy=False) + if xerror is not None and len(xerror) != len(x): + xerror = None + if yerror is None: + yerror = scatter.getYErrorData(copy=False) + if yerror is not None and len(yerror) != len(y): + yerror = None + + scatter.setData(x, y, value, xerror, yerror, copy=copy) + + self.notify( + 'contentChanged', action='add', kind='scatter', legend=legend) + + if len(self._getItems(kind="scatter")) == 1 or wasActive: + self._setActiveItem('scatter', scatter.getLegend()) + + return legend + + def addItem(self, xdata, ydata, legend=None, info=None, + replace=False, + shape="polygon", color='black', fill=True, + overlay=False, z=None, **kw): + """Add an item (i.e. a shape) to the plot. + + Items are uniquely identified by their legend. + To add multiple items, call :meth:`addItem` multiple times with + different legend argument. + To replace/update an existing item, call :meth:`addItem` with the + existing item legend. + + :param numpy.ndarray xdata: The X coords of the points of the shape + :param numpy.ndarray ydata: The Y coords of the points of the shape + :param str legend: The legend to be associated to the item + :param info: User-defined information associated to the item + :param bool replace: True (default) to delete already existing images + :param str shape: Type of item to be drawn in + hline, polygon (the default), rectangle, vline, + polylines + :param str color: Color of the item, e.g., 'blue', 'b', '#FF0000' + (Default: 'black') + :param bool fill: True (the default) to fill the shape + :param bool overlay: True if item is an overlay (Default: False). + This allows for rendering optimization if this + item is changed often. + :param int z: Layer on which to draw the item (default: 2) + :returns: The key string identify this item + """ + # expected to receive the same parameters as the signal + + if kw: + _logger.warning('addItem deprecated parameters: %s', str(kw)) + + legend = "Unnamed Item 1.1" if legend is None else str(legend) + + z = int(z) if z is not None else 2 + + if replace: + self.remove(kind='item') + else: + self.remove(legend, kind='item') + + item = items.Shape(shape) + item._setLegend(legend) + item.setInfo(info) + item.setColor(color) + item.setFill(fill) + item.setOverlay(overlay) + item.setZValue(z) + item.setPoints(numpy.array((xdata, ydata)).T) + + self._add(item) + + self.notify('contentChanged', action='add', kind='item', legend=legend) + + return legend + + def addXMarker(self, x, legend=None, + text=None, + color=None, + selectable=False, + draggable=False, + constraint=None, + **kw): + """Add a vertical line marker to the plot. + + Markers are uniquely identified by their legend. + As opposed to curves, images and items, two calls to + :meth:`addXMarker` without legend argument adds two markers with + different identifying legends. + + :param float x: Position of the marker on the X axis in data + coordinates + :param str legend: Legend associated to the marker to identify it + :param str text: Text to display on the marker. + :param str color: Color of the marker, e.g., 'blue', 'b', '#FF0000' + (Default: 'black') + :param bool selectable: Indicate if the marker can be selected. + (default: False) + :param bool draggable: Indicate if the marker can be moved. + (default: False) + :param constraint: A function filtering marker displacement by + dragging operations or None for no filter. + This function is called each time a marker is + moved. + This parameter is only used if draggable is True. + :type constraint: None or a callable that takes the coordinates of + the current cursor position in the plot as input + and that returns the filtered coordinates. + :return: The key string identify this marker + """ + if kw: + _logger.warning( + 'addXMarker deprecated extra parameters: %s', str(kw)) + + return self._addMarker(x=x, y=None, legend=legend, + text=text, color=color, + selectable=selectable, draggable=draggable, + symbol=None, constraint=constraint) + + def addYMarker(self, y, + legend=None, + text=None, + color=None, + selectable=False, + draggable=False, + constraint=None, + **kw): + """Add a horizontal line marker to the plot. + + Markers are uniquely identified by their legend. + As opposed to curves, images and items, two calls to + :meth:`addYMarker` without legend argument adds two markers with + different identifying legends. + + :param float y: Position of the marker on the Y axis in data + coordinates + :param str legend: Legend associated to the marker to identify it + :param str text: Text to display next to the marker. + :param str color: Color of the marker, e.g., 'blue', 'b', '#FF0000' + (Default: 'black') + :param bool selectable: Indicate if the marker can be selected. + (default: False) + :param bool draggable: Indicate if the marker can be moved. + (default: False) + :param constraint: A function filtering marker displacement by + dragging operations or None for no filter. + This function is called each time a marker is + moved. + This parameter is only used if draggable is True. + :type constraint: None or a callable that takes the coordinates of + the current cursor position in the plot as input + and that returns the filtered coordinates. + :return: The key string identify this marker + """ + if kw: + _logger.warning( + 'addYMarker deprecated extra parameters: %s', str(kw)) + + return self._addMarker(x=None, y=y, legend=legend, + text=text, color=color, + selectable=selectable, draggable=draggable, + symbol=None, constraint=constraint) + + def addMarker(self, x, y, legend=None, + text=None, + color=None, + selectable=False, + draggable=False, + symbol='+', + constraint=None, + **kw): + """Add a point marker to the plot. + + Markers are uniquely identified by their legend. + As opposed to curves, images and items, two calls to + :meth:`addMarker` without legend argument adds two markers with + different identifying legends. + + :param float x: Position of the marker on the X axis in data + coordinates + :param float y: Position of the marker on the Y axis in data + coordinates + :param str legend: Legend associated to the marker to identify it + :param str text: Text to display next to the marker + :param str color: Color of the marker, e.g., 'blue', 'b', '#FF0000' + (Default: 'black') + :param bool selectable: Indicate if the marker can be selected. + (default: False) + :param bool draggable: Indicate if the marker can be moved. + (default: False) + :param str symbol: Symbol representing the marker in:: + + - 'o' circle + - '.' point + - ',' pixel + - '+' cross (the default) + - 'x' x-cross + - 'd' diamond + - 's' square + + :param constraint: A function filtering marker displacement by + dragging operations or None for no filter. + This function is called each time a marker is + moved. + This parameter is only used if draggable is True. + :type constraint: None or a callable that takes the coordinates of + the current cursor position in the plot as input + and that returns the filtered coordinates. + :return: The key string identify this marker + """ + if kw: + _logger.warning( + 'addMarker deprecated extra parameters: %s', str(kw)) + + if x is None: + xmin, xmax = self.getGraphXLimits() + x = 0.5 * (xmax + xmin) + + if y is None: + ymin, ymax = self.getGraphYLimits() + y = 0.5 * (ymax + ymin) + + return self._addMarker(x=x, y=y, legend=legend, + text=text, color=color, + selectable=selectable, draggable=draggable, + symbol=symbol, constraint=constraint) + + def _addMarker(self, x, y, legend, + text, color, + selectable, draggable, + symbol, constraint): + """Common method for adding point, vline and hline marker. + + See :meth:`addMarker` for argument documentation. + """ + assert (x, y) != (None, None) + + if legend is None: # Find an unused legend + markerLegends = self._getAllMarkers(just_legend=True) + for index in itertools.count(): + legend = "Unnamed Marker %d" % index + if legend not in markerLegends: + break # Keep this legend + legend = str(legend) + + if x is None: + markerClass = items.YMarker + elif y is None: + markerClass = items.XMarker + else: + markerClass = items.Marker + + # Create/Update marker object + marker = self._getMarker(legend) + if marker is not None and not isinstance(marker, markerClass): + _logger.warning('Adding marker with same legend' + ' but different type replaces it') + self._remove(marker) + marker = None + + if marker is None: + # No previous marker, create one + marker = markerClass() + marker._setLegend(legend) + self._add(marker) + + if text is not None: + marker.setText(text) + if color is not None: + marker.setColor(color) + if selectable is not None: + marker._setSelectable(selectable) + if draggable is not None: + marker._setDraggable(draggable) + if symbol is not None: + marker.setSymbol(symbol) + + # TODO to improve, but this ensure constraint is applied + marker.setPosition(x, y) + if constraint is not None: + marker._setConstraint(constraint) + marker.setPosition(x, y) + + self.notify( + 'contentChanged', action='add', kind='marker', legend=legend) + + return legend + + # Hide + + def isCurveHidden(self, legend): + """Returns True if the curve associated to legend is hidden, else False + + :param str legend: The legend key identifying the curve + :return: True if the associated curve is hidden, False otherwise + """ + curve = self._getItem('curve', legend) + return curve is not None and not curve.isVisible() + + def hideCurve(self, legend, flag=True, replot=None): + """Show/Hide the curve associated to legend. + + Even when hidden, the curve is kept in the list of curves. + + :param str legend: The legend associated to the curve to be hidden + :param bool flag: True (default) to hide the curve, False to show it + """ + if replot is not None: + _logger.warning('hideCurve deprecated replot parameter') + + curve = self._getItem('curve', legend) + if curve is None: + _logger.warning('Curve not in plot: %s', legend) + return + + isVisible = not flag + if isVisible != curve.isVisible(): + curve.setVisible(isVisible) + + # Remove + + ITEM_KINDS = 'curve', 'image', 'scatter', 'item', 'marker', 'histogram' + + def remove(self, legend=None, kind=ITEM_KINDS): + """Remove one or all element(s) of the given legend and kind. + + Examples: + + - ``remove()`` clears the plot + - ``remove(kind='curve')`` removes all curves from the plot + - ``remove('myCurve', kind='curve')`` removes the curve with + legend 'myCurve' from the plot. + - ``remove('myImage, kind='image')`` removes the image with + legend 'myImage' from the plot. + - ``remove('myImage')`` removes elements (for instance curve, image, + item and marker) with legend 'myImage'. + + :param str legend: The legend associated to the element to remove, + or None to remove + :param kind: The kind of elements to remove from the plot. + In: 'all', 'curve', 'image', 'item', 'marker'. + By default, it removes all kind of elements. + :type kind: str or tuple of str to specify multiple kinds. + """ + if kind is 'all': # Replace all by tuple of all kinds + kind = self.ITEM_KINDS + + if kind in self.ITEM_KINDS: # Kind is a str, make it a tuple + kind = (kind,) + + for aKind in kind: + assert aKind in self.ITEM_KINDS + + if legend is None: # This is a clear + # Clear each given kind + for aKind in kind: + for legend in self._getItems( + kind=aKind, just_legend=True, withhidden=True): + self.remove(legend=legend, kind=aKind) + + else: # This is removing a single element + # Remove each given kind + for aKind in kind: + item = self._getItem(aKind, legend) + if item is not None: + if aKind in ('curve', 'image'): + if self._getActiveItem(aKind) == item: + # Reset active item + self._setActiveItem(aKind, None) + + self._remove(item) + + if (aKind == 'curve' and + not self.getAllCurves(just_legend=True, + withhidden=True)): + self._colorIndex = 0 + self._styleIndex = 0 + + self.notify('contentChanged', action='remove', + kind=aKind, legend=legend) + + def removeCurve(self, legend): + """Remove the curve associated to legend from the graph. + + :param str legend: The legend associated to the curve to be deleted + """ + if legend is None: + return + self.remove(legend, kind='curve') + + def removeImage(self, legend): + """Remove the image associated to legend from the graph. + + :param str legend: The legend associated to the image to be deleted + """ + if legend is None: + return + self.remove(legend, kind='image') + + def removeItem(self, legend): + """Remove the item associated to legend from the graph. + + :param str legend: The legend associated to the item to be deleted + """ + if legend is None: + return + self.remove(legend, kind='item') + + def removeMarker(self, legend): + """Remove the marker associated to legend from the graph. + + :param str legend: The legend associated to the marker to be deleted + """ + if legend is None: + return + self.remove(legend, kind='marker') + + # Clear + + def clear(self): + """Remove everything from the plot.""" + self.remove() + + def clearCurves(self): + """Remove all the curves from the plot.""" + self.remove(kind='curve') + + def clearImages(self): + """Remove all the images from the plot.""" + self.remove(kind='image') + + def clearItems(self): + """Remove all the items from the plot. """ + self.remove(kind='item') + + def clearMarkers(self): + """Remove all the markers from the plot.""" + self.remove(kind='marker') + + # Interaction + + def getGraphCursor(self): + """Returns the state of the crosshair cursor. + + See :meth:`setGraphCursor`. + + :return: None if the crosshair cursor is not active, + else a tuple (color, linewidth, linestyle). + """ + return self._cursorConfiguration + + def setGraphCursor(self, flag=False, color='black', + linewidth=1, linestyle='-'): + """Toggle the display of a crosshair cursor and set its attributes. + + :param bool flag: Toggle the display of a crosshair cursor. + The crosshair cursor is hidden by default. + :param color: The color to use for the crosshair. + :type color: A string (either a predefined color name in Colors.py + or "#RRGGBB")) or a 4 columns unsigned byte array + (Default: black). + :param int linewidth: The width of the lines of the crosshair + (Default: 1). + :param str linestyle: Type of line:: + + - ' ' no line + - '-' solid line (the default) + - '--' dashed line + - '-.' dash-dot line + - ':' dotted line + """ + if flag: + self._cursorConfiguration = color, linewidth, linestyle + else: + self._cursorConfiguration = None + + self._backend.setGraphCursor(flag=flag, color=color, + linewidth=linewidth, linestyle=linestyle) + self._setDirtyPlot() + self.notify('setGraphCursor', + state=self._cursorConfiguration is not None) + + def pan(self, direction, factor=0.1): + """Pan the graph in the given direction by the given factor. + + Warning: Pan of right Y axis not implemented! + + :param str direction: One of 'up', 'down', 'left', 'right'. + :param float factor: Proportion of the range used to pan the graph. + Must be strictly positive. + """ + assert direction in ('up', 'down', 'left', 'right') + assert factor > 0. + + if direction in ('left', 'right'): + xFactor = factor if direction == 'right' else - factor + xMin, xMax = self.getGraphXLimits() + + xMin, xMax = _utils.applyPan(xMin, xMax, xFactor, + self.isXAxisLogarithmic()) + self.setGraphXLimits(xMin, xMax) + + else: # direction in ('up', 'down') + sign = -1. if self.isYAxisInverted() else 1. + yFactor = sign * (factor if direction == 'up' else -factor) + yMin, yMax = self.getGraphYLimits() + yIsLog = self.isYAxisLogarithmic() + + yMin, yMax = _utils.applyPan(yMin, yMax, yFactor, yIsLog) + self.setGraphYLimits(yMin, yMax, axis='left') + + y2Min, y2Max = self.getGraphYLimits(axis='right') + + y2Min, y2Max = _utils.applyPan(y2Min, y2Max, yFactor, yIsLog) + self.setGraphYLimits(y2Min, y2Max, axis='right') + + # Active Curve/Image + + def isActiveCurveHandling(self): + """Returns True if active curve selection is enabled.""" + return self._activeCurveHandling + + def setActiveCurveHandling(self, flag=True): + """Enable/Disable active curve selection. + + :param bool flag: True (the default) to enable active curve selection. + """ + if not flag: + self.setActiveCurve(None) # Reset active curve + + self._activeCurveHandling = bool(flag) + + def getActiveCurveColor(self): + """Get the color used to display the currently active curve. + + See :meth:`setActiveCurveColor`. + """ + return self._activeCurveColor + + def setActiveCurveColor(self, color="#000000"): + """Set the color to use to display the currently active curve. + + :param str color: Color of the active curve, + e.g., 'blue', 'b', '#FF0000' (Default: 'black') + """ + if color is None: + color = "black" + if color in self.colorDict: + color = self.colorDict[color] + self._activeCurveColor = color + + def getActiveCurve(self, just_legend=False): + """Return the currently active curve. + + It returns None in case of not having an active curve. + + :param bool just_legend: True to get the legend of the curve, + False (the default) to get the curve data + and info. + :return: Active curve's legend or corresponding + :class:`.items.Curve` + :rtype: str or :class:`.items.Curve` or None + """ + if not self.isActiveCurveHandling(): + return None + + return self._getActiveItem(kind='curve', just_legend=just_legend) + + def setActiveCurve(self, legend, replot=None): + """Make the curve associated to legend the active curve. + + :param legend: The legend associated to the curve + or None to have no active curve. + :type legend: str or None + """ + if replot is not None: + _logger.warning('setActiveCurve deprecated replot parameter') + + if not self.isActiveCurveHandling(): + return + + return self._setActiveItem(kind='curve', legend=legend) + + def getActiveImage(self, just_legend=False): + """Returns the currently active image. + + It returns None in case of not having an active image. + + :param bool just_legend: True to get the legend of the image, + False (the default) to get the image data + and info. + :return: Active image's legend or corresponding image object + :rtype: str, :class:`.items.ImageData`, :class:`.items.ImageRgba` + or None + """ + return self._getActiveItem(kind='image', just_legend=just_legend) + + def setActiveImage(self, legend, replot=None): + """Make the image associated to legend the active image. + + :param str legend: The legend associated to the image + or None to have no active image. + """ + if replot is not None: + _logger.warning('setActiveImage deprecated replot parameter') + + return self._setActiveItem(kind='image', legend=legend) + + def _getActiveItem(self, kind, just_legend=False): + """Return the currently active item of that kind if any + + :param str kind: Type of item: 'curve', 'scatter' or 'image' + :param bool just_legend: True to get the legend, + False (default) to get the item + :return: legend or item or None if no active item + """ + assert kind in ('curve', 'scatter', 'image') + + if self._activeLegend[kind] is None: + return None + + if (self._activeLegend[kind], kind) not in self._content: + self._activeLegend[kind] = None + return None + + if just_legend: + return self._activeLegend[kind] + else: + return self._getItem(kind, self._activeLegend[kind]) + + def _setActiveItem(self, kind, legend): + """Make the curve associated to legend the active curve. + + :param str kind: Type of item: 'curve' or 'image' + :param legend: The legend associated to the curve + or None to have no active curve. + :type legend: str or None + """ + assert kind in ('curve', 'image', 'scatter') + + xLabel = self._defaultLabels['x'] + yLabel = self._defaultLabels['y'] + yRightLabel = self._defaultLabels['yright'] + + oldActiveItem = self._getActiveItem(kind=kind) + + # Curve specific: Reset highlight of previous active curve + if kind == 'curve' and oldActiveItem is not None: + oldActiveItem.setHighlighted(False) + + if legend is None: + self._activeLegend[kind] = None + else: + legend = str(legend) + item = self._getItem(kind, legend) + if item is None: + _logger.warning("This %s does not exist: %s", kind, legend) + self._activeLegend[kind] = None + else: + self._activeLegend[kind] = legend + + # Curve specific: handle highlight + if kind == 'curve': + item.setHighlightedColor(self.getActiveCurveColor()) + item.setHighlighted(True) + + if isinstance(item, items.LabelsMixIn): + if item.getXLabel() is not None: + xLabel = item.getXLabel() + if item.getYLabel() is not None: + if (isinstance(item, items.YAxisMixIn) and + item.getYAxis() == 'right'): + yRightLabel = item.getYLabel() + else: + yLabel = item.getYLabel() + + # Store current labels and update plot + self._currentLabels['x'] = xLabel + self._currentLabels['y'] = yLabel + self._currentLabels['yright'] = yRightLabel + + self._backend.setGraphXLabel(xLabel) + self._backend.setGraphYLabel(yLabel, axis='left') + self._backend.setGraphYLabel(yRightLabel, axis='right') + + self._setDirtyPlot() + + activeLegend = self._activeLegend[kind] + if oldActiveItem is not None or activeLegend is not None: + if oldActiveItem is None: + oldActiveLegend = None + else: + oldActiveLegend = oldActiveItem.getLegend() + self.notify( + 'active' + kind[0].upper() + kind[1:] + 'Changed', + updated=oldActiveLegend != activeLegend, + previous=oldActiveLegend, + legend=activeLegend) + + return activeLegend + + # Getters + + def getAllCurves(self, just_legend=False, withhidden=False): + """Returns all curves legend or info and data. + + It returns an empty list in case of not having any curve. + + If just_legend is False, it returns a list of :class:`items.Curve` + objects describing the curves. + If just_legend is True, it returns a list of curves' legend. + + :param bool just_legend: True to get the legend of the curves, + False (the default) to get the curves' data + and info. + :param bool withhidden: False (default) to skip hidden curves. + :return: list of curves' legend or :class:`.items.Curve` + :rtype: list of str or list of :class:`.items.Curve` + """ + return self._getItems(kind='curve', + just_legend=just_legend, + withhidden=withhidden) + + def getCurve(self, legend=None): + """Get the object describing a specific curve. + + It returns None in case no matching curve is found. + + :param str legend: + The legend identifying the curve. + If not provided or None (the default), the active curve is returned + or if there is no active curve, the latest updated curve that is + not hidden is returned if there are curves in the plot. + :return: None or :class:`.items.Curve` object + """ + return self._getItem(kind='curve', legend=legend) + + def getAllImages(self, just_legend=False): + """Returns all images legend or objects. + + It returns an empty list in case of not having any image. + + If just_legend is False, it returns a list of :class:`items.ImageBase` + objects describing the images. + If just_legend is True, it returns a list of legends. + + :param bool just_legend: True to get the legend of the images, + False (the default) to get the images' + object. + :return: list of images' legend or :class:`.items.ImageBase` + :rtype: list of str or list of :class:`.items.ImageBase` + """ + return self._getItems(kind='image', + just_legend=just_legend, + withhidden=True) + + def getImage(self, legend=None): + """Get the object describing a specific image. + + It returns None in case no matching image is found. + + :param str legend: + The legend identifying the image. + If not provided or None (the default), the active image is returned + or if there is no active image, the latest updated image + is returned if there are images in the plot. + :return: None or :class:`.items.ImageBase` object + """ + return self._getItem(kind='image', legend=legend) + + def getScatter(self, legend=None): + """Get the object describing a specific scatter. + + It returns None in case no matching scatter is found. + + :param str legend: + The legend identifying the scatter. + If not provided or None (the default), the active scatter is + returned or if there is no active scatter, the latest updated + scatter is returned if there are scatters in the plot. + :return: None or :class:`.items.Scatter` object + """ + return self._getItem(kind='scatter', legend=legend) + + def getHistogram(self, legend=None): + """Get the object describing a specific histogram. + + It returns None in case no matching histogram is found. + + :param str legend: + The legend identifying the histogram. + If not provided or None (the default), the latest updated scatter + is returned if there are histograms in the plot. + :return: None or :class:`.items.Histogram` object + """ + return self._getItem(kind='histogram', legend=legend) + + def _getItems(self, kind, just_legend=False, withhidden=False): + """Retrieve all items of a kind in the plot + + :param str kind: Type of item: 'curve' or 'image' + :param bool just_legend: True to get the legend of the curves, + False (the default) to get the curves' data + and info. + :param bool withhidden: False (default) to skip hidden curves. + :return: list of legends or item objects + """ + assert kind in self.ITEM_KINDS + output = [] + for (legend, type_), item in self._content.items(): + if type_ == kind and (withhidden or item.isVisible()): + output.append(legend if just_legend else item) + return output + + def _getItem(self, kind, legend=None): + """Get an item from the plot: either an image or a curve. + + Returns None if no match found + + :param str kind: Type of item: 'curve' or 'image' + :param str legend: Legend of the item or + None to get active or last item + :return: Object describing the item or None + """ + assert kind in self.ITEM_KINDS + + if legend is not None: + return self._content.get((legend, kind), None) + else: + if kind in ('curve', 'image', 'scatter'): + item = self._getActiveItem(kind=kind) + if item is not None: # Return active item if available + return item + # Return last visible item if any + allItems = self._getItems( + kind=kind, just_legend=False, withhidden=False) + return allItems[-1] if allItems else None + + # Limits + + def _notifyLimitsChanged(self): + """Send an event when plot area limits are changed.""" + xRange = self.getGraphXLimits() + yRange = self.getGraphYLimits(axis='left') + y2Range = self.getGraphYLimits(axis='right') + event = PlotEvents.prepareLimitsChangedSignal( + id(self.getWidgetHandle()), xRange, yRange, y2Range) + self.notify(**event) + + def _checkLimits(self, min_, max_, axis): + """Makes sure axis range is not empty + + :param float min_: Min axis value + :param float max_: Max axis value + :param str axis: 'x', 'y' or 'y2' the axis to deal with + :return: (min, max) making sure min < max + :rtype: 2-tuple of float + """ + if max_ < min_: + _logger.info('%s axis: max < min, inverting limits.', axis) + min_, max_ = max_, min_ + elif max_ == min_: + _logger.info('%s axis: max == min, expanding limits.', axis) + if min_ == 0.: + min_, max_ = -0.1, 0.1 + elif min_ < 0: + min_, max_ = min_ * 1.1, min_ * 0.9 + else: # xmin > 0 + min_, max_ = min_ * 0.9, min_ * 1.1 + + return min_, max_ + + def getGraphXLimits(self): + """Get the graph X (bottom) limits. + + :return: Minimum and maximum values of the X axis + """ + return self._backend.getGraphXLimits() + + def setGraphXLimits(self, xmin, xmax, replot=None): + """Set the graph X (bottom) limits. + + :param float xmin: minimum bottom axis value + :param float xmax: maximum bottom axis value + """ + if replot is not None: + _logger.warning('setGraphXLimits deprecated replot parameter') + + xmin, xmax = self._checkLimits(xmin, xmax, axis='x') + + self._backend.setGraphXLimits(xmin, xmax) + self._setDirtyPlot() + + self._notifyLimitsChanged() + + def getGraphYLimits(self, axis='left'): + """Get the graph Y limits. + + :param str axis: The axis for which to get the limits: + Either 'left' or 'right' + :return: Minimum and maximum values of the X axis + """ + assert axis in ('left', 'right') + return self._backend.getGraphYLimits(axis) + + def setGraphYLimits(self, ymin, ymax, axis='left', replot=None): + """Set the graph Y limits. + + :param float ymin: minimum bottom axis value + :param float ymax: maximum bottom axis value + :param str axis: The axis for which to get the limits: + Either 'left' or 'right' + """ + if replot is not None: + _logger.warning('setGraphYLimits deprecated replot parameter') + + assert axis in ('left', 'right') + + ymin, ymax = self._checkLimits(ymin, ymax, + axis='y' if axis == 'left' else 'y2') + + self._backend.setGraphYLimits(ymin, ymax, axis) + self._setDirtyPlot() + + self._notifyLimitsChanged() + + def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None): + """Set the limits of the X and Y axes at once. + + If y2min or y2max is None, the right Y axis limits are not updated. + + :param float xmin: minimum bottom axis value + :param float xmax: maximum bottom axis value + :param float ymin: minimum left axis value + :param float ymax: maximum left axis value + :param float y2min: minimum right axis value or None (the default) + :param float y2max: maximum right axis value or None (the default) + """ + # Deal with incorrect values + xmin, xmax = self._checkLimits(xmin, xmax, axis='x') + ymin, ymax = self._checkLimits(ymin, ymax, axis='y') + + if y2min is None or y2max is None: + # if one limit is None, both are ignored + y2min, y2max = None, None + else: + y2min, y2max = self._checkLimits(y2min, y2max, axis='y2') + + self._backend.setLimits(xmin, xmax, ymin, ymax, y2min, y2max) + self._setDirtyPlot() + self._notifyLimitsChanged() + + # Title and labels + + def getGraphTitle(self): + """Return the plot main title as a str.""" + return self._graphTitle + + def setGraphTitle(self, title=""): + """Set the plot main title. + + :param str title: Main title of the plot (default: '') + """ + self._graphTitle = str(title) + self._backend.setGraphTitle(title) + self._setDirtyPlot() + + def getGraphXLabel(self): + """Return the current X axis label as a str.""" + return self._currentLabels['x'] + + def setGraphXLabel(self, label="X"): + """Set the plot X axis label. + + The provided label can be temporarily replaced by the X label of the + active curve if any. + + :param str label: The X axis label (default: 'X') + """ + self._defaultLabels['x'] = label + self._currentLabels['x'] = label + self._backend.setGraphXLabel(label) + self._setDirtyPlot() + + def getGraphYLabel(self, axis='left'): + """Return the current Y axis label as a str. + + :param str axis: The Y axis for which to get the label (left or right) + """ + assert axis in ('left', 'right') + + return self._currentLabels['y' if axis == 'left' else 'yright'] + + def setGraphYLabel(self, label="Y", axis='left'): + """Set the plot Y axis label. + + The provided label can be temporarily replaced by the Y label of the + active curve if any. + + :param str label: The Y axis label (default: 'Y') + :param str axis: The Y axis for which to set the label (left or right) + """ + assert axis in ('left', 'right') + + if axis == 'left': + self._defaultLabels['y'] = label + self._currentLabels['y'] = label + else: + self._defaultLabels['yright'] = label + self._currentLabels['yright'] = label + + self._backend.setGraphYLabel(label, axis=axis) + self._setDirtyPlot() + + # Axes + + def setYAxisInverted(self, flag=True): + """Set the Y axis orientation. + + :param bool flag: True for Y axis going from top to bottom, + False for Y axis going from bottom to top + """ + flag = bool(flag) + self._backend.setYAxisInverted(flag) + self._setDirtyPlot() + self.notify('setYAxisInverted', state=flag) + + def isYAxisInverted(self): + """Return True if Y axis goes from top to bottom, False otherwise.""" + return self._backend.isYAxisInverted() + + def isXAxisLogarithmic(self): + """Return True if X axis scale is logarithmic, False if linear.""" + return self._logX + + def setXAxisLogarithmic(self, flag): + """Set the bottom X axis scale (either linear or logarithmic). + + :param bool flag: True to use a logarithmic scale, False for linear. + """ + if bool(flag) == self._logX: + return + self._logX = bool(flag) + + self._backend.setXAxisLogarithmic(self._logX) + + # TODO hackish way of forcing update of curves and images + for curve in self.getAllCurves(): + curve._updated() + for image in self.getAllImages(): + image._updated() + self._invalidateDataRange() + + self.resetZoom() + self.notify('setXAxisLogarithmic', state=self._logX) + + def isYAxisLogarithmic(self): + """Return True if Y axis scale is logarithmic, False if linear.""" + return self._logY + + def setYAxisLogarithmic(self, flag): + """Set the Y axes scale (either linear or logarithmic). + + :param bool flag: True to use a logarithmic scale, False for linear. + """ + if bool(flag) == self._logY: + return + self._logY = bool(flag) + + self._backend.setYAxisLogarithmic(self._logY) + + # TODO hackish way of forcing update of curves and images + for curve in self.getAllCurves(): + curve._updated() + for image in self.getAllImages(): + image._updated() + self._invalidateDataRange() + + self.resetZoom() + self.notify('setYAxisLogarithmic', state=self._logY) + + def isXAxisAutoScale(self): + """Return True if X axis is automatically adjusting its limits.""" + return self._xAutoScale + + def setXAxisAutoScale(self, flag=True): + """Set the X axis limits adjusting behavior of :meth:`resetZoom`. + + :param bool flag: True to resize limits automatically, + False to disable it. + """ + self._xAutoScale = bool(flag) + self.notify('setXAxisAutoScale', state=self._xAutoScale) + + def isYAxisAutoScale(self): + """Return True if Y axes are automatically adjusting its limits.""" + return self._yAutoScale + + def setYAxisAutoScale(self, flag=True): + """Set the Y axis limits adjusting behavior of :meth:`resetZoom`. + + :param bool flag: True to resize limits automatically, + False to disable it. + """ + self._yAutoScale = bool(flag) + self.notify('setYAxisAutoScale', state=self._yAutoScale) + + def isKeepDataAspectRatio(self): + """Returns whether the plot is keeping data aspect ratio or not.""" + return self._backend.isKeepDataAspectRatio() + + def setKeepDataAspectRatio(self, flag=True): + """Set whether the plot keeps data aspect ratio or not. + + :param bool flag: True to respect data aspect ratio + """ + flag = bool(flag) + self._backend.setKeepDataAspectRatio(flag=flag) + self._setDirtyPlot() + self.resetZoom() + self.notify('setKeepDataAspectRatio', state=flag) + + def getGraphGrid(self): + """Return the current grid mode, either None, 'major' or 'both'. + + See :meth:`setGraphGrid`. + """ + return self._grid + + def setGraphGrid(self, which=True): + """Set the type of grid to display. + + :param which: None or False to disable the grid, + 'major' or True for grid on major ticks (the default), + 'both' for grid on both major and minor ticks. + :type which: str of bool + """ + assert which in (None, True, False, 'both', 'major') + if not which: + which = None + elif which is True: + which = 'major' + self._grid = which + self._backend.setGraphGrid(which) + self._setDirtyPlot() + self.notify('setGraphGrid', which=str(which)) + + # Defaults + + def isDefaultPlotPoints(self): + """Return True if default Curve symbol is 'o', False for no symbol.""" + return self._defaultPlotPoints == 'o' + + def setDefaultPlotPoints(self, flag): + """Set the default symbol of all curves. + + When called, this reset the symbol of all existing curves. + + :param bool flag: True to use 'o' as the default curve symbol, + False to use no symbol. + """ + self._defaultPlotPoints = 'o' if flag else '' + + # Reset symbol of all curves + curves = self.getAllCurves(just_legend=False, withhidden=True) + + if curves: + for curve in curves: + curve.setSymbol(self._defaultPlotPoints) + + def isDefaultPlotLines(self): + """Return True for line as default line style, False for no line.""" + return self._plotLines + + def setDefaultPlotLines(self, flag): + """Toggle the use of lines as the default curve line style. + + :param bool flag: True to use a line as the default line style, + False to use no line as the default line style. + """ + self._plotLines = bool(flag) + + linestyle = '-' if self._plotLines else ' ' + + # Reset linestyle of all curves + curves = self.getAllCurves(withhidden=True) + + if curves: + for curve in curves: + curve.setLineStyle(linestyle) + + def getDefaultColormap(self): + """Return the default colormap used by :meth:`addImage` as a dict. + + See :mod:`Plot` for the documentation of the colormap dict. + """ + return self._defaultColormap.copy() + + def setDefaultColormap(self, colormap=None): + """Set the default colormap used by :meth:`addImage`. + + Setting the default colormap do not change any currently displayed + image. + It only affects future calls to :meth:`addImage` without the colormap + parameter. + + :param dict colormap: The description of the default colormap, or + None to set the colormap to a linear autoscale + gray colormap. + See :mod:`Plot` for the documentation + of the colormap dict. + """ + if colormap is None: + colormap = {'name': 'gray', 'normalization': 'linear', + 'autoscale': True, 'vmin': 0.0, 'vmax': 1.0} + self._defaultColormap = colormap.copy() + + def getSupportedColormaps(self): + """Get the supported colormap names as a tuple of str. + + The list should at least contain and start by: + ('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue') + """ + default = ('gray', 'reversed gray', + 'temperature', + 'red', 'green', 'blue') + if matplotlib_cm is None: + return default + else: + maps = [m for m in matplotlib_cm.datad] + maps.sort() + return default + tuple(maps) + + def _getColorAndStyle(self): + color = self.colorList[self._colorIndex] + style = self._styleList[self._styleIndex] + + # Loop over color and then styles + self._colorIndex += 1 + if self._colorIndex >= len(self.colorList): + self._colorIndex = 0 + self._styleIndex = (self._styleIndex + 1) % len(self._styleList) + + # If color is the one of active curve, take the next one + if color == self.getActiveCurveColor(): + color, style = self._getColorAndStyle() + + if not self._plotLines: + style = ' ' + + return color, style + + # Misc. + + def getWidgetHandle(self): + """Return the widget the plot is displayed in. + + This widget is owned by the backend. + """ + return self._backend.getWidgetHandle() + + def notify(self, event, **kwargs): + """Send an event to the listeners. + + Event are passed to the registered callback as a dict with an 'event' + key for backward compatibility with PyMca. + + :param str event: The type of event + :param kwargs: The information of the event. + """ + eventDict = kwargs.copy() + eventDict['event'] = event + self._callback(eventDict) + + def setCallback(self, callbackFunction=None): + """Attach a listener to the backend. + + Limitation: Only one listener at a time. + + :param callbackFunction: function accepting a dictionary as input + to handle the graph events + If None (default), use a default listener. + """ + # TODO allow multiple listeners, keep a weakref on it + # allow register listener by event type + if callbackFunction is None: + callbackFunction = self.graphCallback + self._callback = callbackFunction + + def graphCallback(self, ddict=None): + """This callback is going to receive all the events from the plot. + + Those events will consist on a dictionary and among the dictionary + keys the key 'event' is mandatory to describe the type of event. + This default implementation only handles setting the active curve. + """ + + if ddict is None: + ddict = {} + _logger.debug("Received dict keys = %s", str(ddict.keys())) + _logger.debug(str(ddict)) + if ddict['event'] in ["legendClicked", "curveClicked"]: + if ddict['button'] == "left": + self.setActiveCurve(ddict['label']) + + def saveGraph(self, filename, fileFormat=None, dpi=None, **kw): + """Save a snapshot of the plot. + + Supported file formats: "png", "svg", "pdf", "ps", "eps", + "tif", "tiff", "jpeg", "jpg". + + :param filename: Destination + :type filename: str, StringIO or BytesIO + :param str fileFormat: String specifying the format + :return: False if cannot save the plot, True otherwise + """ + if kw: + _logger.warning('Extra parameters ignored: %s', str(kw)) + + if fileFormat is None: + if not hasattr(filename, 'lower'): + _logger.warning( + 'saveGraph cancelled, cannot define file format.') + return False + else: + fileFormat = (filename.split(".")[-1]).lower() + + supportedFormats = ("png", "svg", "pdf", "ps", "eps", + "tif", "tiff", "jpeg", "jpg") + + if fileFormat not in supportedFormats: + _logger.warning('Unsupported format %s', fileFormat) + return False + else: + self._backend.saveGraph(filename, + fileFormat=fileFormat, + dpi=dpi) + return True + + def getDataMargins(self): + """Get the default data margin ratios, see :meth:`setDataMargins`. + + :return: The margin ratios for each side (xMin, xMax, yMin, yMax). + :rtype: A 4-tuple of floats. + """ + return self._defaultDataMargins + + def setDataMargins(self, xMinMargin=0., xMaxMargin=0., + yMinMargin=0., yMaxMargin=0.): + """Set the default data margins to use in :meth:`resetZoom`. + + Set the default ratios of margins (as floats) to add around the data + inside the plot area for each side. + """ + self._defaultDataMargins = (xMinMargin, xMaxMargin, + yMinMargin, yMaxMargin) + + def getAutoReplot(self): + """Return True if replot is automatically handled, False otherwise. + + See :meth`setAutoReplot`. + """ + return self._autoreplot + + def setAutoReplot(self, autoreplot=True): + """Set automatic replot mode. + + When enabled, the plot is redrawn automatically when changed. + When disabled, the plot is not redrawn when its content change. + Instead, it :meth:`replot` must be called. + + :param bool autoreplot: True to enable it (default), + False to disable it. + """ + self._autoreplot = bool(autoreplot) + + # If the plot is dirty before enabling autoreplot, + # then _backend.postRedisplay will never be called from _setDirtyPlot + if self._autoreplot and self._getDirtyPlot(): + self._backend.postRedisplay() + + def replot(self): + """Redraw the plot immediately.""" + for item in self._contentToUpdate: + item._update(self._backend) + self._contentToUpdate.clear() + self._backend.replot() + self._dirty = False # reset dirty flag + + def resetZoom(self, dataMargins=None): + """Reset the plot limits to the bounds of the data and redraw the plot. + + It automatically scale limits of axes that are in autoscale mode + (See :meth:`setXAxisAutoScale`, :meth:`setYAxisAutoScale`). + It keeps current limits on axes that are not in autoscale mode. + + Extra margins can be added around the data inside the plot area. + Margins are given as one ratio of the data range per limit of the + data (xMin, xMax, yMin and yMax limits). + For log scale, extra margins are applied in log10 of the data. + + :param dataMargins: Ratios of margins to add around the data inside + the plot area for each side (Default: no margins). + :type dataMargins: A 4-tuple of float as (xMin, xMax, yMin, yMax). + """ + if dataMargins is None: + dataMargins = self._defaultDataMargins + + xLimits = self.getGraphXLimits() + yLimits = self.getGraphYLimits(axis='left') + y2Limits = self.getGraphYLimits(axis='right') + + xAuto = self.isXAxisAutoScale() + yAuto = self.isYAxisAutoScale() + + if not xAuto and not yAuto: + _logger.debug("Nothing to autoscale") + else: # Some axes to autoscale + + # Get data range + ranges = self.getDataRange() + xmin, xmax = (1., 100.) if ranges.x is None else ranges.x + ymin, ymax = (1., 100.) if ranges.y is None else ranges.y + if ranges.yright is None: + ymin2, ymax2 = None, None + else: + ymin2, ymax2 = ranges.yright + + # Add margins around data inside the plot area + newLimits = list(_utils.addMarginsToLimits( + dataMargins, + self.isXAxisLogarithmic(), + self.isYAxisLogarithmic(), + xmin, xmax, ymin, ymax, ymin2, ymax2)) + + if self.isKeepDataAspectRatio(): + # Use limits with margins to keep ratio + xmin, xmax, ymin, ymax = newLimits[:4] + + # Compute bbox wth figure aspect ratio + plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:] + plotRatio = plotHeight / plotWidth + + if plotRatio > 0.: + dataRatio = (ymax - ymin) / (xmax - xmin) + if dataRatio < plotRatio: + # Increase y range + ycenter = 0.5 * (ymax + ymin) + yrange = (xmax - xmin) * plotRatio + newLimits[2] = ycenter - 0.5 * yrange + newLimits[3] = ycenter + 0.5 * yrange + + elif dataRatio > plotRatio: + # Increase x range + xcenter = 0.5 * (xmax + xmin) + xrange_ = (ymax - ymin) / plotRatio + newLimits[0] = xcenter - 0.5 * xrange_ + newLimits[1] = xcenter + 0.5 * xrange_ + + self.setLimits(*newLimits) + + if not xAuto and yAuto: + self.setGraphXLimits(*xLimits) + elif xAuto and not yAuto: + if y2Limits is not None: + self.setGraphYLimits( + y2Limits[0], y2Limits[1], axis='right') + if yLimits is not None: + self.setGraphYLimits(yLimits[0], yLimits[1], axis='left') + + self._setDirtyPlot() + + if (xLimits != self.getGraphXLimits() or + yLimits != self.getGraphYLimits(axis='left') or + y2Limits != self.getGraphYLimits(axis='right')): + self._notifyLimitsChanged() + + # Coord conversion + + def dataToPixel(self, x=None, y=None, axis="left", check=True): + """Convert a position in data coordinates to a position in pixels. + + :param float x: The X coordinate in data space. If None (default) + the middle position of the displayed data is used. + :param float y: The Y coordinate in data space. If None (default) + the middle position of the displayed data is used. + :param str axis: The Y axis to use for the conversion + ('left' or 'right'). + :param bool check: True to return None if outside displayed area, + False to convert to pixels anyway + :returns: The corresponding position in pixels or + None if the data position is not in the displayed area and + check is True. + :rtype: A tuple of 2 floats: (xPixel, yPixel) or None. + """ + assert axis in ("left", "right") + + xmin, xmax = self.getGraphXLimits() + ymin, ymax = self.getGraphYLimits(axis=axis) + + if x is None: + x = 0.5 * (xmax + xmin) + if y is None: + y = 0.5 * (ymax + ymin) + + if check: + if x > xmax or x < xmin: + return None + + if y > ymax or y < ymin: + return None + + return self._backend.dataToPixel(x, y, axis=axis) + + def pixelToData(self, x, y, axis="left", check=False): + """Convert a position in pixels to a position in data coordinates. + + :param float x: The X coordinate in pixels. If None (default) + the center of the widget is used. + :param float y: The Y coordinate in pixels. If None (default) + the center of the widget is used. + :param str axis: The Y axis to use for the conversion + ('left' or 'right'). + :param bool check: Toggle checking if pixel is in plot area. + If False, this method never returns None. + :returns: The corresponding position in data space or + None if the pixel position is not in the plot area. + :rtype: A tuple of 2 floats: (xData, yData) or None. + """ + assert axis in ("left", "right") + return self._backend.pixelToData(x, y, axis=axis, check=check) + + def getPlotBoundsInPixels(self): + """Plot area bounds in widget coordinates in pixels. + + :return: bounds as a 4-tuple of int: (left, top, width, height) + """ + return self._backend.getPlotBoundsInPixels() + + # Interaction support + + def setGraphCursorShape(self, cursor=None): + """Set the cursor shape. + + :param str cursor: Name of the cursor shape + """ + self._backend.setGraphCursorShape(cursor) + + def _pickMarker(self, x, y, test=None): + """Pick a marker at the given position. + + To use for interaction implementation. + + :param float x: X position in pixels. + :param float y: Y position in pixels. + :param test: A callable to call for each picked marker to filter + picked markers. If None (default), do not filter markers. + """ + if test is None: + def test(mark): + return True + + markers = self._backend.pickItems(x, y) + legends = [m['legend'] for m in markers if m['kind'] == 'marker'] + + for legend in reversed(legends): + marker = self._getMarker(legend) + if marker is not None and test(marker): + return marker + return None + + def _getAllMarkers(self, just_legend=False): + """Returns all markers' legend or objects + + :param bool just_legend: True to get the legend of the markers, + False (the default) to get marker objects. + :return: list of legend of list of marker objects + :rtype: list of str or list of marker objects + """ + return self._getItems( + kind='marker', just_legend=just_legend, withhidden=True) + + def _getMarker(self, legend=None): + """Get the object describing a specific marker. + + It returns None in case no matching marker is found + + :param str legend: The legend of the marker to retrieve + :rtype: None of marker object + """ + return self._getItem(kind='marker', legend=legend) + + def _pickImageOrCurve(self, x, y, test=None): + """Pick an image or a curve at the given position. + + To use for interaction implementation. + + :param float x: X position in pixelsparam float y: Y position in pixels + :param test: A callable to call for each picked item to filter + picked items. If None (default), do not filter items. + """ + if test is None: + def test(i): + return True + + allItems = self._backend.pickItems(x, y) + allItems = [item for item in allItems + if item['kind'] in ['curve', 'image']] + + for item in reversed(allItems): + kind, legend = item['kind'], item['legend'] + if kind == 'curve': + curve = self.getCurve(legend) + if curve is not None and test(curve): + return kind, curve, item['xdata'], item['ydata'] + + elif kind == 'image': + image = self.getImage(legend) + if image is not None and test(image): + return kind, image, None + + else: + _logger.warning('Unsupported kind: %s', kind) + + return None + + # User event handling # + + def _isPositionInPlotArea(self, x, y): + """Project position in pixel to the closest point in the plot area + + :param float x: X coordinate in widget coordinate (in pixel) + :param float y: Y coordinate in widget coordinate (in pixel) + :return: (x, y) in widget coord (in pixel) in the plot area + """ + left, top, width, height = self.getPlotBoundsInPixels() + xPlot = numpy.clip(x, left, left + width) + yPlot = numpy.clip(y, top, top + height) + return xPlot, yPlot + + def onMousePress(self, xPixel, yPixel, btn): + """Handle mouse press event. + + :param float xPixel: X mouse position in pixels + :param float yPixel: Y mouse position in pixels + :param str btn: Mouse button in 'left', 'middle', 'right' + """ + if self._isPositionInPlotArea(xPixel, yPixel) == (xPixel, yPixel): + self._pressedButtons.append(btn) + self._eventHandler.handleEvent('press', xPixel, yPixel, btn) + + def onMouseMove(self, xPixel, yPixel): + """Handle mouse move event. + + :param float xPixel: X mouse position in pixels + :param float yPixel: Y mouse position in pixels + """ + inXPixel, inYPixel = self._isPositionInPlotArea(xPixel, yPixel) + isCursorInPlot = inXPixel == xPixel and inYPixel == yPixel + + if self._cursorInPlot != isCursorInPlot: + self._cursorInPlot = isCursorInPlot + self._eventHandler.handleEvent( + 'enter' if self._cursorInPlot else 'leave') + + if isCursorInPlot: + # Signal mouse move event + dataPos = self.pixelToData(inXPixel, inYPixel) + assert dataPos is not None + + btn = self._pressedButtons[-1] if self._pressedButtons else None + event = PlotEvents.prepareMouseSignal( + 'mouseMoved', btn, dataPos[0], dataPos[1], xPixel, yPixel) + self.notify(**event) + + # Either button was pressed in the plot or cursor is in the plot + if isCursorInPlot or self._pressedButtons: + self._eventHandler.handleEvent('move', inXPixel, inYPixel) + + def onMouseRelease(self, xPixel, yPixel, btn): + """Handle mouse release event. + + :param float xPixel: X mouse position in pixels + :param float yPixel: Y mouse position in pixels + :param str btn: Mouse button in 'left', 'middle', 'right' + """ + try: + self._pressedButtons.remove(btn) + except ValueError: + pass + else: + xPixel, yPixel = self._isPositionInPlotArea(xPixel, yPixel) + self._eventHandler.handleEvent('release', xPixel, yPixel, btn) + + def onMouseWheel(self, xPixel, yPixel, angleInDegrees): + """Handle mouse wheel event. + + :param float xPixel: X mouse position in pixels + :param float yPixel: Y mouse position in pixels + :param float angleInDegrees: Angle corresponding to wheel motion. + Positive for movement away from the user, + negative for movement toward the user. + """ + if self._isPositionInPlotArea(xPixel, yPixel) == (xPixel, yPixel): + self._eventHandler.handleEvent( + 'wheel', xPixel, yPixel, angleInDegrees) + + def onMouseLeaveWidget(self): + """Handle mouse leave widget event.""" + if self._cursorInPlot: + self._cursorInPlot = False + self._eventHandler.handleEvent('leave') + + # Interaction modes # + + def getInteractiveMode(self): + """Returns the current interactive mode as a dict. + + The returned dict contains at least the key 'mode'. + Mode can be: 'draw', 'pan', 'select', 'zoom'. + It can also contains extra keys (e.g., 'color') specific to a mode + as provided to :meth:`setInteractiveMode`. + """ + return self._eventHandler.getInteractiveMode() + + def setInteractiveMode(self, mode, color='black', + shape='polygon', label=None, + zoomOnWheel=True, source=None, width=None): + """Switch the interactive mode. + + :param str mode: The name of the interactive mode. + In 'draw', 'pan', 'select', 'zoom'. + :param color: Only for 'draw' and 'zoom' modes. + Color to use for drawing selection area. Default black. + :type color: Color description: The name as a str or + a tuple of 4 floats. + :param str shape: Only for 'draw' mode. The kind of shape to draw. + In 'polygon', 'rectangle', 'line', 'vline', 'hline', + 'freeline'. + Default is 'polygon'. + :param str label: Only for 'draw' mode, sent in drawing events. + :param bool zoomOnWheel: Toggle zoom on wheel support + :param source: A user-defined object (typically the caller object) + that will be send in the interactiveModeChanged event, + to identify which object required a mode change. + Default: None + :param float width: Width of the pencil. Only for draw pencil mode. + """ + self._eventHandler.setInteractiveMode(mode, color, shape, label, width) + self._eventHandler.zoomOnWheel = zoomOnWheel + + self.notify( + 'interactiveModeChanged', source=source) + + # Deprecated # + + def isDrawModeEnabled(self): + """Deprecated, use :meth:`getInteractiveMode` instead. + + Return True if the current interactive state is drawing.""" + _logger.warning( + 'isDrawModeEnabled deprecated, use getInteractiveMode instead') + return self.getInteractiveMode()['mode'] == 'draw' + + def setDrawModeEnabled(self, flag=True, shape='polygon', label=None, + color=None, **kwargs): + """Deprecated, use :meth:`setInteractiveMode` instead. + + Set the drawing mode if flag is True and its parameters. + + If flag is False, only item selection is enabled. + + Warning: Zoom and drawing are not compatible and cannot be enabled + simultaneously. + + :param bool flag: True to enable drawing and disable zoom and select. + :param str shape: Type of item to be drawn in: + hline, vline, rectangle, polygon (default) + :param str label: Associated text for identifying draw signals + :param color: The color to use to draw the selection area + :type color: string ("#RRGGBB") or 4 column unsigned byte array or + one of the predefined color names defined in Colors.py + """ + _logger.warning( + 'setDrawModeEnabled deprecated, use setInteractiveMode instead') + + if kwargs: + _logger.warning('setDrawModeEnabled ignores additional parameters') + + if color is None: + color = 'black' + + if flag: + self.setInteractiveMode('draw', shape=shape, + label=label, color=color) + elif self.getInteractiveMode()['mode'] == 'draw': + self.setInteractiveMode('select') + + def getDrawMode(self): + """Deprecated, use :meth:`getInteractiveMode` instead. + + Return the draw mode parameters as a dict of None. + + It returns None if the interactive mode is not a drawing mode, + otherwise, it returns a dict containing the drawing mode parameters + as provided to :meth:`setDrawModeEnabled`. + """ + _logger.warning( + 'getDrawMode deprecated, use getInteractiveMode instead') + mode = self.getInteractiveMode() + return mode if mode['mode'] == 'draw' else None + + def isZoomModeEnabled(self): + """Deprecated, use :meth:`getInteractiveMode` instead. + + Return True if the current interactive state is zooming.""" + _logger.warning( + 'isZoomModeEnabled deprecated, use getInteractiveMode instead') + return self.getInteractiveMode()['mode'] == 'zoom' + + def setZoomModeEnabled(self, flag=True, color=None): + """Deprecated, use :meth:`setInteractiveMode` instead. + + Set the zoom mode if flag is True, else item selection is enabled. + + Warning: Zoom and drawing are not compatible and cannot be enabled + simultaneously + + :param bool flag: If True, enable zoom and select mode. + :param color: The color to use to draw the selection area. + (Default: 'black') + :param color: The color to use to draw the selection area + :type color: string ("#RRGGBB") or 4 column unsigned byte array or + one of the predefined color names defined in Colors.py + """ + _logger.warning( + 'setZoomModeEnabled deprecated, use setInteractiveMode instead') + if color is None: + color = 'black' + + if flag: + self.setInteractiveMode('zoom', color=color) + elif self.getInteractiveMode()['mode'] == 'zoom': + self.setInteractiveMode('select') + + def insertMarker(self, *args, **kwargs): + """Deprecated, use :meth:`addMarker` instead.""" + _logger.warning( + 'insertMarker deprecated, use addMarker instead.') + return self.addMarker(*args, **kwargs) + + def insertXMarker(self, *args, **kwargs): + """Deprecated, use :meth:`addXMarker` instead.""" + _logger.warning( + 'insertXMarker deprecated, use addXMarker instead.') + return self.addXMarker(*args, **kwargs) + + def insertYMarker(self, *args, **kwargs): + """Deprecated, use :meth:`addYMarker` instead.""" + _logger.warning( + 'insertYMarker deprecated, use addYMarker instead.') + return self.addYMarker(*args, **kwargs) + + def isActiveCurveHandlingEnabled(self): + """Deprecated, use :meth:`isActiveCurveHandling` instead.""" + _logger.warning( + 'isActiveCurveHandlingEnabled deprecated, ' + 'use isActiveCurveHandling instead.') + return self.isActiveCurveHandling() + + def enableActiveCurveHandling(self, *args, **kwargs): + """Deprecated, use :meth:`setActiveCurveHandling` instead.""" + _logger.warning( + 'enableActiveCurveHandling deprecated, ' + 'use setActiveCurveHandling instead.') + return self.setActiveCurveHandling(*args, **kwargs) + + def invertYAxis(self, *args, **kwargs): + """Deprecated, use :meth:`setYAxisInverted` instead.""" + _logger.warning('invertYAxis deprecated, ' + 'use setYAxisInverted instead.') + return self.setYAxisInverted(*args, **kwargs) + + def showGrid(self, flag=True): + """Deprecated, use :meth:`setGraphGrid` instead.""" + _logger.warning("showGrid deprecated, use setGraphGrid instead") + if flag in (0, False): + flag = None + elif flag in (1, True): + flag = 'major' + else: + flag = 'both' + return self.setGraphGrid(flag) + + def keepDataAspectRatio(self, *args, **kwargs): + """Deprecated, use :meth:`setKeepDataAspectRatio`.""" + _logger.warning('keepDataAspectRatio deprecated,' + 'use setKeepDataAspectRatio instead') + return self.setKeepDataAspectRatio(*args, **kwargs) diff --git a/silx/gui/plot/PlotActions.py b/silx/gui/plot/PlotActions.py new file mode 100644 index 0000000..aad27d2 --- /dev/null +++ b/silx/gui/plot/PlotActions.py @@ -0,0 +1,1386 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides a set of QAction to use with :class:`.PlotWidget`. + +The following QAction are available: + +- :class:`ColormapAction` +- :class:`CopyAction` +- :class:`CrosshairAction` +- :class:`CurveStyleAction` +- :class:`FitAction` +- :class:`GridAction` +- :class:`KeepAspectRatioAction` +- :class:`PanWithArrowKeysAction` +- :class:`PrintAction` +- :class:`ResetZoomAction` +- :class:`SaveAction` +- :class:`XAxisLogarithmicAction` +- :class:`XAxisAutoScaleAction` +- :class:`YAxisInvertedAction` +- :class:`YAxisLogarithmicAction` +- :class:`YAxisAutoScaleAction` +- :class:`ZoomInAction` +- :class:`ZoomOutAction` +""" + +from __future__ import division + + +__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"] +__license__ = "MIT" +__date__ = "20/04/2017" + + +from collections import OrderedDict +import logging +import sys +import traceback +import weakref + +if sys.version_info[0] == 3: + from io import BytesIO +else: + import cStringIO as _StringIO + BytesIO = _StringIO.StringIO + +import numpy + +from .. import icons +from .. import qt +from .._utils import convertArrayToQImage +from . import Colors, items +from .ColormapDialog import ColormapDialog +from ._utils import applyZoomToPlot as _applyZoomToPlot +from silx.third_party.EdfFile import EdfFile +from silx.third_party.TiffIO import TiffIO +from silx.math.histogram import Histogramnd +from silx.math.medianfilter import medfilt2d +from silx.gui.widgets.MedianFilterDialog import MedianFilterDialog + +from silx.io.utils import save1D, savespec + + +_logger = logging.getLogger(__name__) + + +class PlotAction(qt.QAction): + """Base class for QAction that operates on a PlotWidget. + + :param plot: :class:`.PlotWidget` instance on which to operate. + :param icon: QIcon or str name of icon to use + :param str text: The name of this action to be used for menu label + :param str tooltip: The text of the tooltip + :param triggered: The callback to connect to the action's triggered + signal or None for no callback. + :param bool checkable: True for checkable action, False otherwise (default) + :param parent: See :class:`QAction`. + """ + + def __init__(self, plot, icon, text, tooltip=None, + triggered=None, checkable=False, parent=None): + assert plot is not None + self._plotRef = weakref.ref(plot) + + if not isinstance(icon, qt.QIcon): + # Try with icon as a string and load corresponding icon + icon = icons.getQIcon(icon) + + super(PlotAction, self).__init__(icon, text, parent) + + if tooltip is not None: + self.setToolTip(tooltip) + + self.setCheckable(checkable) + + if triggered is not None: + self.triggered[bool].connect(triggered) + + @property + def plot(self): + """The :class:`.PlotWidget` this action group is controlling.""" + return self._plotRef() + + +class ResetZoomAction(PlotAction): + """QAction controlling reset zoom on a :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + super(ResetZoomAction, self).__init__( + plot, icon='zoom-original', text='Reset Zoom', + tooltip='Auto-scale the graph', + triggered=self._actionTriggered, + checkable=False, parent=parent) + self._autoscaleChanged(True) + plot.sigSetXAxisAutoScale.connect(self._autoscaleChanged) + plot.sigSetYAxisAutoScale.connect(self._autoscaleChanged) + + def _autoscaleChanged(self, enabled): + self.setEnabled( + self.plot.isXAxisAutoScale() or self.plot.isYAxisAutoScale()) + + if self.plot.isXAxisAutoScale() and self.plot.isYAxisAutoScale(): + tooltip = 'Auto-scale the graph' + elif self.plot.isXAxisAutoScale(): # And not Y axis + tooltip = 'Auto-scale the x-axis of the graph only' + elif self.plot.isYAxisAutoScale(): # And not X axis + tooltip = 'Auto-scale the y-axis of the graph only' + else: # no axis in autoscale + tooltip = 'Auto-scale the graph' + self.setToolTip(tooltip) + + def _actionTriggered(self, checked=False): + self.plot.resetZoom() + + +class ZoomInAction(PlotAction): + """QAction performing a zoom-in on a :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + super(ZoomInAction, self).__init__( + plot, icon='zoom-in', text='Zoom In', + tooltip='Zoom in the plot', + triggered=self._actionTriggered, + checkable=False, parent=parent) + self.setShortcut(qt.QKeySequence.ZoomIn) + self.setShortcutContext(qt.Qt.WidgetShortcut) + + def _actionTriggered(self, checked=False): + _applyZoomToPlot(self.plot, 1.1) + + +class ZoomOutAction(PlotAction): + """QAction performing a zoom-out on a :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + super(ZoomOutAction, self).__init__( + plot, icon='zoom-out', text='Zoom Out', + tooltip='Zoom out the plot', + triggered=self._actionTriggered, + checkable=False, parent=parent) + self.setShortcut(qt.QKeySequence.ZoomOut) + self.setShortcutContext(qt.Qt.WidgetShortcut) + + def _actionTriggered(self, checked=False): + _applyZoomToPlot(self.plot, 1. / 1.1) + + +class XAxisAutoScaleAction(PlotAction): + """QAction controlling X axis autoscale on a :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + super(XAxisAutoScaleAction, self).__init__( + plot, icon='plot-xauto', text='X Autoscale', + tooltip='Enable x-axis auto-scale when checked.\n' + 'If unchecked, x-axis does not change when reseting zoom.', + triggered=self._actionTriggered, + checkable=True, parent=parent) + self.setChecked(plot.isXAxisAutoScale()) + plot.sigSetXAxisAutoScale.connect(self.setChecked) + + def _actionTriggered(self, checked=False): + self.plot.setXAxisAutoScale(checked) + if checked: + self.plot.resetZoom() + + +class YAxisAutoScaleAction(PlotAction): + """QAction controlling Y axis autoscale on a :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + super(YAxisAutoScaleAction, self).__init__( + plot, icon='plot-yauto', text='Y Autoscale', + tooltip='Enable y-axis auto-scale when checked.\n' + 'If unchecked, y-axis does not change when reseting zoom.', + triggered=self._actionTriggered, + checkable=True, parent=parent) + self.setChecked(plot.isXAxisAutoScale()) + plot.sigSetYAxisAutoScale.connect(self.setChecked) + + def _actionTriggered(self, checked=False): + self.plot.setYAxisAutoScale(checked) + if checked: + self.plot.resetZoom() + + +class XAxisLogarithmicAction(PlotAction): + """QAction controlling X axis log scale on a :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + super(XAxisLogarithmicAction, self).__init__( + plot, icon='plot-xlog', text='X Log. scale', + tooltip='Logarithmic x-axis when checked', + triggered=self._actionTriggered, + checkable=True, parent=parent) + self.setChecked(plot.isXAxisLogarithmic()) + plot.sigSetXAxisLogarithmic.connect(self.setChecked) + + def _actionTriggered(self, checked=False): + self.plot.setXAxisLogarithmic(checked) + + +class YAxisLogarithmicAction(PlotAction): + """QAction controlling Y axis log scale on a :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + super(YAxisLogarithmicAction, self).__init__( + plot, icon='plot-ylog', text='Y Log. scale', + tooltip='Logarithmic y-axis when checked', + triggered=self._actionTriggered, + checkable=True, parent=parent) + self.setChecked(plot.isYAxisLogarithmic()) + plot.sigSetYAxisLogarithmic.connect(self.setChecked) + + def _actionTriggered(self, checked=False): + self.plot.setYAxisLogarithmic(checked) + + +class GridAction(PlotAction): + """QAction controlling grid mode on a :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param str gridMode: The grid mode to use in 'both', 'major'. + See :meth:`.PlotWidget.setGraphGrid` + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, gridMode='both', parent=None): + assert gridMode in ('both', 'major') + self._gridMode = gridMode + + super(GridAction, self).__init__( + plot, icon='plot-grid', text='Grid', + tooltip='Toggle grid (on/off)', + triggered=self._actionTriggered, + checkable=True, parent=parent) + self.setChecked(plot.getGraphGrid() is not None) + plot.sigSetGraphGrid.connect(self._gridChanged) + + def _gridChanged(self, which): + """Slot listening for PlotWidget grid mode change.""" + self.setChecked(which != 'None') + + def _actionTriggered(self, checked=False): + self.plot.setGraphGrid(self._gridMode if checked else None) + + +class CurveStyleAction(PlotAction): + """QAction controlling curve style on a :class:`.PlotWidget`. + + It changes the default line and markers style which updates all + curves on the plot. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + super(CurveStyleAction, self).__init__( + plot, icon='plot-toggle-points', text='Curve style', + tooltip='Change curve line and markers style', + triggered=self._actionTriggered, + checkable=False, parent=parent) + + def _actionTriggered(self, checked=False): + currentState = (self.plot.isDefaultPlotLines(), + self.plot.isDefaultPlotPoints()) + + # line only, line and symbol, symbol only + states = (True, False), (True, True), (False, True) + newState = states[(states.index(currentState) + 1) % 3] + + self.plot.setDefaultPlotLines(newState[0]) + self.plot.setDefaultPlotPoints(newState[1]) + + +class ColormapAction(PlotAction): + """QAction opening a ColormapDialog to update the colormap. + + Both the active image colormap and the default colormap are updated. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + def __init__(self, plot, parent=None): + self._dialog = None # To store an instance of ColormapDialog + super(ColormapAction, self).__init__( + plot, icon='colormap', text='Colormap', + tooltip="Change colormap", + triggered=self._actionTriggered, + checkable=False, parent=parent) + + def _actionTriggered(self, checked=False): + """Create a cmap dialog and update active image and default cmap.""" + # Create the dialog if not already existing + if self._dialog is None: + self._dialog = ColormapDialog() + + image = self.plot.getActiveImage() + if not isinstance(image, items.ColormapMixIn): + # No active image or active image is RGBA, + # set dialog from default info + colormap = self.plot.getDefaultColormap() + + self._dialog.setHistogram() # Reset histogram and range if any + + else: + # Set dialog from active image + colormap = image.getColormap() + + data = image.getData(copy=False) + + goodData = data[numpy.isfinite(data)] + if goodData.size > 0: + dataMin = goodData.min() + dataMax = goodData.max() + else: + qt.QMessageBox.warning( + self, "No Data", + "Image data does not contain any real value") + dataMin, dataMax = 1., 10. + + self._dialog.setHistogram() # Reset histogram if any + self._dialog.setDataRange(dataMin, dataMax) + # The histogram should be done in a worker thread + # hist, bin_edges = numpy.histogram(goodData, bins=256) + # self._dialog.setHistogram(hist, bin_edges) + + self._dialog.setColormap(**colormap) + + # Run the dialog listening to colormap change + self._dialog.sigColormapChanged.connect(self._colormapChanged) + result = self._dialog.exec_() + self._dialog.sigColormapChanged.disconnect(self._colormapChanged) + + if not result: # Restore the previous colormap + self._colormapChanged(colormap) + + def _colormapChanged(self, colormap): + # Update default colormap + self.plot.setDefaultColormap(colormap) + + # Update active image colormap + activeImage = self.plot.getActiveImage() + if isinstance(activeImage, items.ColormapMixIn): + activeImage.setColormap(colormap) + + +class KeepAspectRatioAction(PlotAction): + """QAction controlling aspect ratio on a :class:`.PlotWidget`. + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + # Uses two images for checked/unchecked states + self._states = { + False: (icons.getQIcon('shape-circle-solid'), + "Keep data aspect ratio"), + True: (icons.getQIcon('shape-ellipse-solid'), + "Do no keep data aspect ratio") + } + + icon, tooltip = self._states[plot.isKeepDataAspectRatio()] + super(KeepAspectRatioAction, self).__init__( + plot, + icon=icon, + text='Toggle keep aspect ratio', + tooltip=tooltip, + triggered=self._actionTriggered, + checkable=False, + parent=parent) + plot.sigSetKeepDataAspectRatio.connect( + self._keepDataAspectRatioChanged) + + def _keepDataAspectRatioChanged(self, aspectRatio): + """Handle Plot set keep aspect ratio signal""" + icon, tooltip = self._states[aspectRatio] + self.setIcon(icon) + self.setToolTip(tooltip) + + def _actionTriggered(self, checked=False): + # This will trigger _keepDataAspectRatioChanged + self.plot.setKeepDataAspectRatio(not self.plot.isKeepDataAspectRatio()) + + +class YAxisInvertedAction(PlotAction): + """QAction controlling Y orientation on a :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + # Uses two images for checked/unchecked states + self._states = { + False: (icons.getQIcon('plot-ydown'), + "Orient Y axis downward"), + True: (icons.getQIcon('plot-yup'), + "Orient Y axis upward"), + } + + icon, tooltip = self._states[plot.isYAxisInverted()] + super(YAxisInvertedAction, self).__init__( + plot, + icon=icon, + text='Invert Y Axis', + tooltip=tooltip, + triggered=self._actionTriggered, + checkable=False, + parent=parent) + plot.sigSetYAxisInverted.connect(self._yAxisInvertedChanged) + + def _yAxisInvertedChanged(self, inverted): + """Handle Plot set y axis inverted signal""" + icon, tooltip = self._states[inverted] + self.setIcon(icon) + self.setToolTip(tooltip) + + def _actionTriggered(self, checked=False): + # This will trigger _yAxisInvertedChanged + self.plot.setYAxisInverted(not self.plot.isYAxisInverted()) + + +class SaveAction(PlotAction): + """QAction for saving Plot content. + + It opens a Save as... dialog. + + :param plot: :class:`.PlotWidget` instance on which to operate. + :param parent: See :class:`QAction`. + """ + # TODO find a way to make the filter list selectable and extensible + + SNAPSHOT_FILTER_SVG = 'Plot Snapshot as SVG (*.svg)' + + SNAPSHOT_FILTERS = ('Plot Snapshot as PNG (*.png)', + 'Plot Snapshot as JPEG (*.jpg)', + SNAPSHOT_FILTER_SVG) + + # Dict of curve filters with CSV-like format + # Using ordered dict to guarantee filters order + # Note: '%.18e' is numpy.savetxt default format + CURVE_FILTERS_TXT = OrderedDict(( + ('Curve as Raw ASCII (*.txt)', + {'fmt': '%.18e', 'delimiter': ' ', 'header': False}), + ('Curve as ";"-separated CSV (*.csv)', + {'fmt': '%.18e', 'delimiter': ';', 'header': True}), + ('Curve as ","-separated CSV (*.csv)', + {'fmt': '%.18e', 'delimiter': ',', 'header': True}), + ('Curve as tab-separated CSV (*.csv)', + {'fmt': '%.18e', 'delimiter': '\t', 'header': True}), + ('Curve as OMNIC CSV (*.csv)', + {'fmt': '%.7E', 'delimiter': ',', 'header': False}), + ('Curve as SpecFile (*.dat)', + {'fmt': '%.7g', 'delimiter': '', 'header': False}) + )) + + CURVE_FILTER_NPY = 'Curve as NumPy binary file (*.npy)' + + CURVE_FILTERS = list(CURVE_FILTERS_TXT.keys()) + [CURVE_FILTER_NPY] + + ALL_CURVES_FILTERS = ("All curves as SpecFile (*.dat)", ) + + IMAGE_FILTER_EDF = 'Image data as EDF (*.edf)' + IMAGE_FILTER_TIFF = 'Image data as TIFF (*.tif)' + IMAGE_FILTER_NUMPY = 'Image data as NumPy binary file (*.npy)' + IMAGE_FILTER_ASCII = 'Image data as ASCII (*.dat)' + IMAGE_FILTER_CSV_COMMA = 'Image data as ,-separated CSV (*.csv)' + IMAGE_FILTER_CSV_SEMICOLON = 'Image data as ;-separated CSV (*.csv)' + IMAGE_FILTER_CSV_TAB = 'Image data as tab-separated CSV (*.csv)' + IMAGE_FILTER_RGB_PNG = 'Image as PNG (*.png)' + IMAGE_FILTER_RGB_TIFF = 'Image as TIFF (*.tif)' + IMAGE_FILTERS = (IMAGE_FILTER_EDF, + IMAGE_FILTER_TIFF, + IMAGE_FILTER_NUMPY, + IMAGE_FILTER_ASCII, + IMAGE_FILTER_CSV_COMMA, + IMAGE_FILTER_CSV_SEMICOLON, + IMAGE_FILTER_CSV_TAB, + IMAGE_FILTER_RGB_PNG, + IMAGE_FILTER_RGB_TIFF) + + def __init__(self, plot, parent=None): + super(SaveAction, self).__init__( + plot, icon='document-save', text='Save as...', + tooltip='Save curve/image/plot snapshot dialog', + triggered=self._actionTriggered, + checkable=False, parent=parent) + self.setShortcut(qt.QKeySequence.Save) + self.setShortcutContext(qt.Qt.WidgetShortcut) + + def _errorMessage(self, informativeText=''): + """Display an error message.""" + # TODO issue with QMessageBox size fixed and too small + msg = qt.QMessageBox(self.plot) + msg.setIcon(qt.QMessageBox.Critical) + msg.setInformativeText(informativeText + ' ' + str(sys.exc_info()[1])) + msg.setDetailedText(traceback.format_exc()) + msg.exec_() + + def _saveSnapshot(self, filename, nameFilter): + """Save a snapshot of the :class:`PlotWindow` widget. + + :param str filename: The name of the file to write + :param str nameFilter: The selected name filter + :return: False if format is not supported or save failed, + True otherwise. + """ + if nameFilter == self.SNAPSHOT_FILTER_SVG: + self.plot.saveGraph(filename, fileFormat='svg') + + else: + if hasattr(qt.QPixmap, "grabWidget"): + # Qt 4 + pixmap = qt.QPixmap.grabWidget(self.plot.getWidgetHandle()) + else: + # Qt 5 + pixmap = self.plot.getWidgetHandle().grab() + if not pixmap.save(filename): + self._errorMessage() + return False + return True + + def _saveCurve(self, filename, nameFilter): + """Save a curve from the plot. + + :param str filename: The name of the file to write + :param str nameFilter: The selected name filter + :return: False if format is not supported or save failed, + True otherwise. + """ + if nameFilter not in self.CURVE_FILTERS: + return False + + # Check if a curve is to be saved + curve = self.plot.getActiveCurve() + # before calling _saveCurve, if there is no selected curve, we + # make sure there is only one curve on the graph + if curve is None: + curves = self.plot.getAllCurves() + if not curves: + self._errorMessage("No curve to be saved") + return False + curve = curves[0] + + if nameFilter in self.CURVE_FILTERS_TXT: + filter_ = self.CURVE_FILTERS_TXT[nameFilter] + fmt = filter_['fmt'] + csvdelim = filter_['delimiter'] + autoheader = filter_['header'] + else: + # .npy + fmt, csvdelim, autoheader = ("", "", False) + + # If curve has no associated label, get the default from the plot + xlabel = curve.getXLabel() + if xlabel is None: + xlabel = self.plot.getGraphXLabel() + ylabel = curve.getYLabel() + if ylabel is None: + ylabel = self.plot.getGraphYLabel() + + try: + save1D(filename, + curve.getXData(copy=False), + curve.getYData(copy=False), + xlabel, [ylabel], + fmt=fmt, csvdelim=csvdelim, + autoheader=autoheader) + except IOError: + self._errorMessage('Save failed\n') + return False + + return True + + def _saveCurves(self, filename, nameFilter): + """Save all curves from the plot. + + :param str filename: The name of the file to write + :param str nameFilter: The selected name filter + :return: False if format is not supported or save failed, + True otherwise. + """ + if nameFilter not in self.ALL_CURVES_FILTERS: + return False + + curves = self.plot.getAllCurves() + if not curves: + self._errorMessage("No curves to be saved") + return False + + curve = curves[0] + scanno = 1 + try: + specfile = savespec(filename, + curve.getXData(copy=False), + curve.getYData(copy=False), + curve.getXLabel(), + curve.getYLabel(), + fmt="%.7g", scan_number=1, mode="w", + write_file_header=True, + close_file=False) + except IOError: + self._errorMessage('Save failed\n') + return False + + for curve in curves[1:]: + try: + scanno += 1 + specfile = savespec(specfile, + curve.getXData(copy=False), + curve.getYData(copy=False), + curve.getXLabel(), + curve.getYLabel(), + fmt="%.7g", scan_number=scanno, mode="w", + write_file_header=False, + close_file=False) + except IOError: + self._errorMessage('Save failed\n') + return False + specfile.close() + + return True + + def _saveImage(self, filename, nameFilter): + """Save an image from the plot. + + :param str filename: The name of the file to write + :param str nameFilter: The selected name filter + :return: False if format is not supported or save failed, + True otherwise. + """ + if nameFilter not in self.IMAGE_FILTERS: + return False + + image = self.plot.getActiveImage() + if image is None: + qt.QMessageBox.warning( + self.plot, "No Data", "No image to be saved") + return False + + data = image.getData(copy=False) + + # TODO Use silx.io for writing files + if nameFilter == self.IMAGE_FILTER_EDF: + edfFile = EdfFile(filename, access="w+") + edfFile.WriteImage({}, data, Append=0) + return True + + elif nameFilter == self.IMAGE_FILTER_TIFF: + tiffFile = TiffIO(filename, mode='w') + tiffFile.writeImage(data, software='silx') + return True + + elif nameFilter == self.IMAGE_FILTER_NUMPY: + try: + numpy.save(filename, data) + except IOError: + self._errorMessage('Save failed\n') + return False + return True + + elif nameFilter in (self.IMAGE_FILTER_ASCII, + self.IMAGE_FILTER_CSV_COMMA, + self.IMAGE_FILTER_CSV_SEMICOLON, + self.IMAGE_FILTER_CSV_TAB): + csvdelim, filetype = { + self.IMAGE_FILTER_ASCII: (' ', 'txt'), + self.IMAGE_FILTER_CSV_COMMA: (',', 'csv'), + self.IMAGE_FILTER_CSV_SEMICOLON: (';', 'csv'), + self.IMAGE_FILTER_CSV_TAB: ('\t', 'csv'), + }[nameFilter] + + height, width = data.shape + rows, cols = numpy.mgrid[0:height, 0:width] + try: + save1D(filename, rows.ravel(), (cols.ravel(), data.ravel()), + filetype=filetype, + xlabel='row', + ylabels=['column', 'value'], + csvdelim=csvdelim, + autoheader=True) + + except IOError: + self._errorMessage('Save failed\n') + return False + return True + + elif nameFilter in (self.IMAGE_FILTER_RGB_PNG, + self.IMAGE_FILTER_RGB_TIFF): + # Get displayed image + rgbaImage = image.getRbgaImageData(copy=False) + # Convert RGB QImage + qimage = convertArrayToQImage(rgbaImage[:, :, :3]) + + if nameFilter == self.IMAGE_FILTER_RGB_PNG: + fileFormat = 'PNG' + else: + fileFormat = 'TIFF' + + if qimage.save(filename, fileFormat): + return True + else: + _logger.error('Failed to save image as %s', filename) + qt.QMessageBox.critical( + self.parent(), + 'Save image as', + 'Failed to save image') + + return False + + def _actionTriggered(self, checked=False): + """Handle save action.""" + # Set-up filters + filters = [] + + # Add image filters if there is an active image + if self.plot.getActiveImage() is not None: + filters.extend(self.IMAGE_FILTERS) + + # Add curve filters if there is a curve to save + if (self.plot.getActiveCurve() is not None or + len(self.plot.getAllCurves()) == 1): + filters.extend(self.CURVE_FILTERS) + if len(self.plot.getAllCurves()) > 1: + filters.extend(self.ALL_CURVES_FILTERS) + + filters.extend(self.SNAPSHOT_FILTERS) + + # Create and run File dialog + dialog = qt.QFileDialog(self.plot) + dialog.setWindowTitle("Output File Selection") + dialog.setModal(1) + dialog.setNameFilters(filters) + + dialog.setFileMode(dialog.AnyFile) + dialog.setAcceptMode(dialog.AcceptSave) + + if not dialog.exec_(): + return False + + nameFilter = dialog.selectedNameFilter() + filename = dialog.selectedFiles()[0] + dialog.close() + + # Forces the filename extension to match the chosen filter + extension = nameFilter.split()[-1][2:-1] + if (len(filename) <= len(extension) or + filename[-len(extension):].lower() != extension.lower()): + filename += extension + + # Handle save + if nameFilter in self.SNAPSHOT_FILTERS: + return self._saveSnapshot(filename, nameFilter) + elif nameFilter in self.CURVE_FILTERS: + return self._saveCurve(filename, nameFilter) + elif nameFilter in self.ALL_CURVES_FILTERS: + return self._saveCurves(filename, nameFilter) + elif nameFilter in self.IMAGE_FILTERS: + return self._saveImage(filename, nameFilter) + else: + _logger.warning('Unsupported file filter: %s', nameFilter) + return False + + +def _plotAsPNG(plot): + """Save a :class:`Plot` as PNG and return the payload. + + :param plot: The :class:`Plot` to save + """ + pngFile = BytesIO() + plot.saveGraph(pngFile, fileFormat='png') + pngFile.flush() + pngFile.seek(0) + data = pngFile.read() + pngFile.close() + return data + + +class PrintAction(PlotAction): + """QAction for printing the plot. + + It opens a Print dialog. + + Current implementation print a bitmap of the plot area and not vector + graphics, so printing quality is not great. + + :param plot: :class:`.PlotWidget` instance on which to operate. + :param parent: See :class:`QAction`. + """ + + # Share QPrinter instance to propose latest used as default + _printer = None + + def __init__(self, plot, parent=None): + super(PrintAction, self).__init__( + plot, icon='document-print', text='Print...', + tooltip='Open print dialog', + triggered=self.printPlot, + checkable=False, parent=parent) + self.setShortcut(qt.QKeySequence.Print) + self.setShortcutContext(qt.Qt.WidgetShortcut) + + @property + def printer(self): + """The QPrinter instance used by the actions. + + This is shared accross all instances of PrintAct + """ + if self._printer is None: + PrintAction._printer = qt.QPrinter() + return self._printer + + def printPlotAsWidget(self): + """Open the print dialog and print the plot. + + Use :meth:`QWidget.render` to print the plot + + :return: True if successful + """ + dialog = qt.QPrintDialog(self.printer, self.plot) + dialog.setWindowTitle('Print Plot') + if not dialog.exec_(): + return False + + # Print a snapshot of the plot widget at the top of the page + widget = self.plot.centralWidget() + + painter = qt.QPainter() + if not painter.begin(self.printer): + return False + + pageRect = self.printer.pageRect() + xScale = pageRect.width() / widget.width() + yScale = pageRect.height() / widget.height() + scale = min(xScale, yScale) + + painter.translate(pageRect.width() / 2., 0.) + painter.scale(scale, scale) + painter.translate(-widget.width() / 2., 0.) + widget.render(painter) + painter.end() + + return True + + def printPlot(self): + """Open the print dialog and print the plot. + + Use :meth:`Plot.saveGraph` to print the plot. + + :return: True if successful + """ + # Init printer and start printer dialog + dialog = qt.QPrintDialog(self.printer, self.plot) + dialog.setWindowTitle('Print Plot') + if not dialog.exec_(): + return False + + # Save Plot as PNG and make a pixmap from it with default dpi + pngData = _plotAsPNG(self.plot) + + pixmap = qt.QPixmap() + pixmap.loadFromData(pngData, 'png') + + xScale = self.printer.pageRect().width() / pixmap.width() + yScale = self.printer.pageRect().height() / pixmap.height() + scale = min(xScale, yScale) + + # Draw pixmap with painter + painter = qt.QPainter() + if not painter.begin(self.printer): + return False + + painter.drawPixmap(0, 0, + pixmap.width() * scale, + pixmap.height() * scale, + pixmap) + painter.end() + + return True + + +class CopyAction(PlotAction): + """QAction to copy :class:`.PlotWidget` content to clipboard. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + super(CopyAction, self).__init__( + plot, icon='edit-copy', text='Copy plot', + tooltip='Copy a snapshot of the plot into the clipboard', + triggered=self.copyPlot, + checkable=False, parent=parent) + self.setShortcut(qt.QKeySequence.Copy) + self.setShortcutContext(qt.Qt.WidgetShortcut) + + def copyPlot(self): + """Copy plot content to the clipboard as a bitmap.""" + # Save Plot as PNG and make a QImage from it with default dpi + pngData = _plotAsPNG(self.plot) + image = qt.QImage.fromData(pngData, 'png') + qt.QApplication.clipboard().setImage(image) + + +class CrosshairAction(PlotAction): + """QAction toggling crosshair cursor on a :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param str color: Color to use to draw the crosshair + :param int linewidth: Width of the crosshair cursor + :param str linestyle: Style of line. See :meth:`.Plot.setGraphCursor` + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, color='black', linewidth=1, linestyle='-', + parent=None): + self.color = color + """Color used to draw the crosshair (str).""" + + self.linewidth = linewidth + """Width of the crosshair cursor (int).""" + + self.linestyle = linestyle + """Style of line of the cursor (str).""" + + super(CrosshairAction, self).__init__( + plot, icon='crosshair', text='Crosshair Cursor', + tooltip='Enable crosshair cursor when checked', + triggered=self._actionTriggered, + checkable=True, parent=parent) + self.setChecked(plot.getGraphCursor() is not None) + plot.sigSetGraphCursor.connect(self.setChecked) + + def _actionTriggered(self, checked=False): + self.plot.setGraphCursor(checked, + color=self.color, + linestyle=self.linestyle, + linewidth=self.linewidth) + + +class PanWithArrowKeysAction(PlotAction): + """QAction toggling pan with arrow keys on a :class:`.PlotWidget`. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + + super(PanWithArrowKeysAction, self).__init__( + plot, icon='arrow-keys', text='Pan with arrow keys', + tooltip='Enable pan with arrow keys when checked', + triggered=self._actionTriggered, + checkable=True, parent=parent) + self.setChecked(plot.isPanWithArrowKeys()) + plot.sigSetPanWithArrowKeys.connect(self.setChecked) + + def _actionTriggered(self, checked=False): + self.plot.setPanWithArrowKeys(checked) + + +def _warningMessage(informativeText='', detailedText='', parent=None): + """Display a popup warning message.""" + msg = qt.QMessageBox(parent) + msg.setIcon(qt.QMessageBox.Warning) + msg.setInformativeText(informativeText) + msg.setDetailedText(detailedText) + msg.exec_() + + +def _getOneCurve(plt, mode="unique"): + """Get a single curve from the plot. + By default, get the active curve if any, else if a single curve is plotted + get it, else return None and display a warning popup. + + This behavior can be adjusted by modifying the *mode* parameter: always + return the active curve if any, but adjust the behavior in case no curve + is active. + + :param plt: :class:`.PlotWidget` instance on which to operate + :param mode: Parameter defining the behavior when no curve is active. + Possible modes: + - "none": return None (enforce curve activation) + - "unique": return the unique curve or None if multiple curves + - "first": return first curve + - "last": return last curve (most recently added one) + :return: return value of plt.getActiveCurve(), or plt.getAllCurves()[0], + or plt.getAllCurves()[-1], or None + """ + curve = plt.getActiveCurve() + if curve is not None: + return curve + + if mode is None or mode.lower() == "none": + _warningMessage("You must activate a curve!", + parent=plt) + return None + + curves = plt.getAllCurves() + if len(curves) == 0: + _warningMessage("No curve on this plot.", + parent=plt) + return None + + if len(curves) == 1: + return curves[0] + + if len(curves) > 1: + if mode == "unique": + _warningMessage("Multiple curves are plotted. " + + "Please activate the one you want to use.", + parent=plt) + return None + if mode.lower() == "first": + return curves[0] + if mode.lower() == "last": + return curves[-1] + + raise ValueError("Illegal value for parameter 'mode'." + + " Allowed values: 'none', 'unique', 'first', 'last'.") + + +class FitAction(PlotAction): + """QAction to open a :class:`FitWidget` and set its data to the + active curve if any, or to the first curve. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + def __init__(self, plot, parent=None): + super(FitAction, self).__init__( + plot, icon='math-fit', text='Fit curve', + tooltip='Open a fit dialog', + triggered=self._getFitWindow, + checkable=False, parent=parent) + self.fit_window = None + + def _getFitWindow(self): + curve = _getOneCurve(self.plot) + if curve is None: + return + self.xlabel = self.plot.getGraphXLabel() + self.ylabel = self.plot.getGraphYLabel() + self.x = curve.getXData(copy=False) + self.y = curve.getYData(copy=False) + self.legend = curve.getLegend() + self.xmin, self.xmax = self.plot.getGraphXLimits() + + # open a window with a FitWidget + if self.fit_window is None: + self.fit_window = qt.QMainWindow() + # import done here rather than at module level to avoid circular import + # FitWidget -> BackgroundWidget -> PlotWindow -> PlotActions -> FitWidget + from ..fit.FitWidget import FitWidget + self.fit_widget = FitWidget(parent=self.fit_window) + self.fit_window.setCentralWidget( + self.fit_widget) + self.fit_widget.guibuttons.DismissButton.clicked.connect( + self.fit_window.close) + self.fit_widget.sigFitWidgetSignal.connect( + self.handle_signal) + self.fit_window.show() + else: + if self.fit_window.isHidden(): + self.fit_window.show() + self.fit_widget.show() + self.fit_window.raise_() + + self.fit_widget.setData(self.x, self.y, + xmin=self.xmin, xmax=self.xmax) + self.fit_window.setWindowTitle( + "Fitting " + self.legend + + " on x range %f-%f" % (self.xmin, self.xmax)) + + def handle_signal(self, ddict): + x_fit = self.x[self.xmin <= self.x] + x_fit = x_fit[x_fit <= self.xmax] + fit_legend = "Fit <%s>" % self.legend + fit_curve = self.plot.getCurve(fit_legend) + + if ddict["event"] == "FitFinished": + y_fit = self.fit_widget.fitmanager.gendata() + if fit_curve is None: + self.plot.addCurve(x_fit, y_fit, + fit_legend, + xlabel=self.xlabel, ylabel=self.ylabel, + resetzoom=False) + else: + fit_curve.setData(x_fit, y_fit) + fit_curve.setVisible(True) + + if ddict["event"] in ["FitStarted", "FitFailed"]: + if fit_curve is not None: + fit_curve.setVisible(False) + + +class PixelIntensitiesHistoAction(PlotAction): + """QAction to plot the pixels intensities diagram + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + PlotAction.__init__(self, + plot, + icon='pixel-intensities', + text='pixels intensity', + tooltip='Compute image intensity distribution', + triggered=self._triggered, + parent=parent, + checkable=True) + self._plotHistogram = None + self._connectedToActiveImage = False + self._histo = None + + def _triggered(self, checked): + """Update the plot of the histogram visibility status + + :param bool checked: status of the action button + """ + if checked: + if not self._connectedToActiveImage: + self.plot.sigActiveImageChanged.connect( + self._activeImageChanged) + self._connectedToActiveImage = True + self.computeIntensityDistribution() + + self.getHistogramPlotWidget().show() + + else: + if self._connectedToActiveImage: + self.plot.sigActiveImageChanged.disconnect( + self._activeImageChanged) + self._connectedToActiveImage = False + + self.getHistogramPlotWidget().hide() + + def _activeImageChanged(self, previous, legend): + """Handle active image change: toggle enabled toolbar, update curve""" + if self.isChecked(): + self.computeIntensityDistribution() + + def computeIntensityDistribution(self): + """Get the active image and compute the image intensity distribution + """ + activeImage = self.plot.getActiveImage() + + if activeImage is not None: + image = activeImage.getData(copy=False) + if image.ndim == 3: # RGB(A) images + _logger.info('Converting current image from RGB(A) to grayscale\ + in order to compute the intensity distribution') + image = (image[:, :, 0] * 0.299 + + image[:, :, 1] * 0.587 + + image[:, :, 2] * 0.114) + + xmin = numpy.nanmin(image) + xmax = numpy.nanmax(image) + nbins = min(1024, int(numpy.sqrt(image.size))) + data_range = xmin, xmax + + # bad hack: get 256 bins in the case we have a B&W + if numpy.issubdtype(image.dtype, numpy.integer): + if nbins > xmax - xmin: + nbins = xmax - xmin + + nbins = max(2, nbins) + + data = image.ravel().astype(numpy.float32) + histogram = Histogramnd(data, n_bins=nbins, histo_range=data_range) + assert len(histogram.edges) == 1 + self._histo = histogram.histo + edges = histogram.edges[0] + plot = self.getHistogramPlotWidget() + plot.addHistogram(histogram=self._histo, + edges=edges, + legend='pixel intensity', + fill=True, + color='red') + plot.resetZoom() + + def eventFilter(self, qobject, event): + """Observe when the close event is emitted then + simply uncheck the action button + + :param qobject: the object observe + :param event: the event received by qobject + """ + if event.type() == qt.QEvent.Close: + if self._plotHistogram is not None: + self._plotHistogram.hide() + self.setChecked(False) + + return PlotAction.eventFilter(self, qobject, event) + + def getHistogramPlotWidget(self): + """Create the plot histogram if needed, otherwise create it + + :return: the PlotWidget showing the histogram of the pixel intensities + """ + from silx.gui.plot.PlotWindow import Plot1D + if self._plotHistogram is None: + self._plotHistogram = Plot1D(parent=self.plot) + self._plotHistogram.setWindowFlags(qt.Qt.Window) + self._plotHistogram.setWindowTitle('Image Intensity Histogram') + self._plotHistogram.installEventFilter(self) + self._plotHistogram.setGraphXLabel("Value") + self._plotHistogram.setGraphYLabel("Count") + + return self._plotHistogram + + def getHistogram(self): + """Return the last computed histogram + + :return: the histogram displayed in the HistogramPlotWiget + """ + return self._histo + + +class MedianFilterAction(PlotAction): + """QAction to plot the pixels intensities diagram + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + PlotAction.__init__(self, + plot, + icon='median-filter', + text='median filter', + tooltip='Apply a median filter on the image', + triggered=self._triggered, + parent=parent) + self._originalImage = None + self._legend = None + self._filteredImage = None + self._popup = MedianFilterDialog(parent=None) + self._popup.sigFilterOptChanged.connect(self._updateFilter) + self.plot.sigActiveImageChanged.connect( self._updateActiveImage) + self._updateActiveImage() + + def _triggered(self, checked): + """Update the plot of the histogram visibility status + + :param bool checked: status of the action button + """ + self._popup.show() + + def _updateActiveImage(self): + """Set _activeImageLegend and _originalImage from the active image""" + self._activeImageLegend = self.plot.getActiveImage(just_legend=True) + if self._activeImageLegend is None: + self._originalImage = None + self._legend = None + else: + self._originalImage = self.plot.getImage(self._activeImageLegend).getData(copy=False) + self._legend = self.plot.getImage(self._activeImageLegend).getLegend() + + def _updateFilter(self, kernelWidth, conditional=False): + if self._originalImage is None: + return + + self.plot.sigActiveImageChanged.disconnect(self._updateActiveImage) + filteredImage = self._computeFilteredImage(kernelWidth, conditional) + self.plot.addImage(data=filteredImage, + legend=self._legend, + replace=True) + self.plot.sigActiveImageChanged.connect(self._updateActiveImage) + + def _computeFilteredImage(self, kernelWidth, conditional): + raise NotImplemented('MedianFilterAction is a an abstract class') + + def getFilteredImage(self): + """ + :return: the image with the median filter apply on""" + return self._filteredImage + + +class MedianFilter1DAction(MedianFilterAction): + """Define the MedianFilterAction for 1D + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + def __init__(self, plot, parent=None): + MedianFilterAction.__init__(self, + plot, + parent=parent) + + def _computeFilteredImage(self, kernelWidth, conditional): + assert(self.plot is not None) + return medfilt2d(self._originalImage, + (kernelWidth, 1), + conditional) + + +class MedianFilter2DAction(MedianFilterAction): + """Define the MedianFilterAction for 2D + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + def __init__(self, plot, parent=None): + MedianFilterAction.__init__(self, + plot, + parent=parent) + + def _computeFilteredImage(self, kernelWidth, conditional): + assert(self.plot is not None) + return medfilt2d(self._originalImage, + (kernelWidth, kernelWidth), + conditional) diff --git a/silx/gui/plot/PlotEvents.py b/silx/gui/plot/PlotEvents.py new file mode 100644 index 0000000..83f253c --- /dev/null +++ b/silx/gui/plot/PlotEvents.py @@ -0,0 +1,166 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Functions to prepare events to be sent to Plot callback.""" + +__author__ = ["V.A. Sole", "T. Vincent"] +__license__ = "MIT" +__date__ = "18/02/2016" + + +import numpy as np + + +def prepareDrawingSignal(event, type_, points, parameters=None): + """See Plot documentation for content of events""" + assert event in ('drawingProgress', 'drawingFinished') + + if parameters is None: + parameters = {} + + eventDict = {} + eventDict['event'] = event + eventDict['type'] = type_ + points = np.array(points, dtype=np.float32) + points.shape = -1, 2 + eventDict['points'] = points + eventDict['xdata'] = points[:, 0] + eventDict['ydata'] = points[:, 1] + if type_ in ('rectangle',): + eventDict['x'] = eventDict['xdata'].min() + eventDict['y'] = eventDict['ydata'].min() + eventDict['width'] = eventDict['xdata'].max() - eventDict['x'] + eventDict['height'] = eventDict['ydata'].max() - eventDict['y'] + eventDict['parameters'] = parameters.copy() + return eventDict + + +def prepareMouseSignal(eventType, button, xData, yData, xPixel, yPixel): + """See Plot documentation for content of events""" + assert eventType in ('mouseMoved', 'mouseClicked', 'mouseDoubleClicked') + assert button in (None, 'left', 'middle', 'right') + + return {'event': eventType, + 'x': xData, + 'y': yData, + 'xpixel': xPixel, + 'ypixel': yPixel, + 'button': button} + + +def prepareHoverSignal(label, type_, posData, posPixel, draggable, selectable): + """See Plot documentation for content of events""" + return {'event': 'hover', + 'label': label, + 'type': type_, + 'x': posData[0], + 'y': posData[1], + 'xpixel': posPixel[0], + 'ypixel': posPixel[1], + 'draggable': draggable, + 'selectable': selectable} + + +def prepareMarkerSignal(eventType, button, label, type_, + draggable, selectable, + posDataMarker, + posPixelCursor=None, posDataCursor=None): + """See Plot documentation for content of events""" + if eventType == 'markerClicked': + assert posPixelCursor is not None + assert posDataCursor is None + + posDataCursor = list(posDataMarker) + if hasattr(posDataCursor[0], "__len__"): + posDataCursor[0] = posDataCursor[0][-1] + if hasattr(posDataCursor[1], "__len__"): + posDataCursor[1] = posDataCursor[1][-1] + + elif eventType == 'markerMoving': + assert posPixelCursor is not None + assert posDataCursor is not None + + elif eventType == 'markerMoved': + assert posPixelCursor is None + assert posDataCursor is None + + posDataCursor = posDataMarker + else: + raise NotImplementedError("Unknown event type {0}".format(eventType)) + + eventDict = {'event': eventType, + 'button': button, + 'label': label, + 'type': type_, + 'x': posDataCursor[0], + 'y': posDataCursor[1], + 'xdata': posDataMarker[0], + 'ydata': posDataMarker[1], + 'draggable': draggable, + 'selectable': selectable} + + if eventType in ('markerMoving', 'markerClicked'): + eventDict['xpixel'] = posPixelCursor[0] + eventDict['ypixel'] = posPixelCursor[1] + + return eventDict + + +def prepareImageSignal(button, label, type_, col, row, + x, y, xPixel, yPixel): + """See Plot documentation for content of events""" + return {'event': 'imageClicked', + 'button': button, + 'label': label, + 'type': type_, + 'col': col, + 'row': row, + 'x': x, + 'y': y, + 'xpixel': xPixel, + 'ypixel': yPixel} + + +def prepareCurveSignal(button, label, type_, xData, yData, + x, y, xPixel, yPixel): + """See Plot documentation for content of events""" + return {'event': 'curveClicked', + 'button': button, + 'label': label, + 'type': type_, + 'xdata': xData, + 'ydata': yData, + 'x': x, + 'y': y, + 'xpixel': xPixel, + 'ypixel': yPixel} + + +def prepareLimitsChangedSignal(sourceObj, xRange, yRange, y2Range): + """See Plot documentation for content of events""" + return {'event': 'limitsChanged', + 'source': id(sourceObj), + 'xdata': xRange, + 'ydata': yRange, + 'y2data': y2Range} diff --git a/silx/gui/plot/PlotInteraction.py b/silx/gui/plot/PlotInteraction.py new file mode 100644 index 0000000..fbc9c1f --- /dev/null +++ b/silx/gui/plot/PlotInteraction.py @@ -0,0 +1,1493 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Implementation of the interaction for the :class:`Plot`.""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "24/01/2017" + + +import math +import numpy +import time +import weakref + +from . import Colors +from . import items +from .Interaction import (ClickOrDrag, LEFT_BTN, RIGHT_BTN, + State, StateMachine) +from .PlotEvents import (prepareCurveSignal, prepareDrawingSignal, + prepareHoverSignal, prepareImageSignal, + prepareMarkerSignal, prepareMouseSignal) + +from .backends.BackendBase import (CURSOR_POINTING, CURSOR_SIZE_HOR, + CURSOR_SIZE_VER, CURSOR_SIZE_ALL) + +from ._utils import (FLOAT32_SAFE_MIN, FLOAT32_MINPOS, FLOAT32_SAFE_MAX, + applyZoomToPlot) + + +# Base class ################################################################## + +class _PlotInteraction(object): + """Base class for interaction handler. + + It provides a weakref to the plot and methods to set/reset overlay. + """ + def __init__(self, plot): + """Init. + + :param plot: The plot to apply modifications to. + """ + self._needReplot = False + self._selectionAreas = set() + self._plot = weakref.ref(plot) # Avoid cyclic-ref + + @property + def plot(self): + plot = self._plot() + assert plot is not None + return plot + + def setSelectionArea(self, points, fill, color, name='', shape='polygon'): + """Set a polygon selection area overlaid on the plot. + Multiple simultaneous areas are supported through the name parameter. + + :param points: The 2D coordinates of the points of the polygon + :type points: An iterable of (x, y) coordinates + :param str fill: The fill mode: 'hatch', 'solid' or 'none' + :param color: RGBA color to use or None to disable display + :type color: list or tuple of 4 float in the range [0, 1] + :param name: The key associated with this selection area + :param str shape: Shape of the area in 'polygon', 'polylines' + """ + assert shape in ('polygon', 'polylines') + + if color is None: + return + + points = numpy.asarray(points) + + # TODO Not very nice, but as is for now + legend = '__SELECTION_AREA__' + name + + fill = fill != 'none' # TODO not very nice either + + self.plot.addItem(points[:, 0], points[:, 1], legend=legend, + replace=False, + shape=shape, color=color, fill=fill, + overlay=True) + self._selectionAreas.add(legend) + + def resetSelectionArea(self): + """Remove all selection areas set by setSelectionArea.""" + for legend in self._selectionAreas: + self.plot.remove(legend, kind='item') + self._selectionAreas = set() + + +# Zoom/Pan #################################################################### + +class _ZoomOnWheel(ClickOrDrag, _PlotInteraction): + """:class:`ClickOrDrag` state machine with zooming on mouse wheel. + + Base class for :class:`Pan` and :class:`Zoom` + """ + class ZoomIdle(ClickOrDrag.Idle): + def onWheel(self, x, y, angle): + scaleF = 1.1 if angle > 0 else 1. / 1.1 + applyZoomToPlot(self.machine.plot, scaleF, (x, y)) + + def __init__(self, plot): + """Init. + + :param plot: The plot to apply modifications to. + """ + _PlotInteraction.__init__(self, plot) + + states = { + 'idle': _ZoomOnWheel.ZoomIdle, + 'rightClick': ClickOrDrag.RightClick, + 'clickOrDrag': ClickOrDrag.ClickOrDrag, + 'drag': ClickOrDrag.Drag + } + StateMachine.__init__(self, states, 'idle') + + +# Pan ######################################################################### + +class Pan(_ZoomOnWheel): + """Pan plot content and zoom on wheel state machine.""" + + def _pixelToData(self, x, y): + xData, yData = self.plot.pixelToData(x, y) + _, y2Data = self.plot.pixelToData(x, y, axis='right') + return xData, yData, y2Data + + def beginDrag(self, x, y): + self._previousDataPos = self._pixelToData(x, y) + + def drag(self, x, y): + xData, yData, y2Data = self._pixelToData(x, y) + lastX, lastY, lastY2 = self._previousDataPos + + xMin, xMax = self.plot.getGraphXLimits() + yMin, yMax = self.plot.getGraphYLimits(axis='left') + y2Min, y2Max = self.plot.getGraphYLimits(axis='right') + + if self.plot.isXAxisLogarithmic(): + try: + dx = math.log10(xData) - math.log10(lastX) + newXMin = pow(10., (math.log10(xMin) - dx)) + newXMax = pow(10., (math.log10(xMax) - dx)) + except (ValueError, OverflowError): + newXMin, newXMax = xMin, xMax + + # Makes sure both values stays in positive float32 range + if newXMin < FLOAT32_MINPOS or newXMax > FLOAT32_SAFE_MAX: + newXMin, newXMax = xMin, xMax + else: + dx = xData - lastX + newXMin, newXMax = xMin - dx, xMax - dx + + # Makes sure both values stays in float32 range + if newXMin < FLOAT32_SAFE_MIN or newXMax > FLOAT32_SAFE_MAX: + newXMin, newXMax = xMin, xMax + + if self.plot.isYAxisLogarithmic(): + try: + dy = math.log10(yData) - math.log10(lastY) + newYMin = pow(10., math.log10(yMin) - dy) + newYMax = pow(10., math.log10(yMax) - dy) + + dy2 = math.log10(y2Data) - math.log10(lastY2) + newY2Min = pow(10., math.log10(y2Min) - dy2) + newY2Max = pow(10., math.log10(y2Max) - dy2) + except (ValueError, OverflowError): + newYMin, newYMax = yMin, yMax + newY2Min, newY2Max = y2Min, y2Max + + # Makes sure y and y2 stays in positive float32 range + if (newYMin < FLOAT32_MINPOS or newYMax > FLOAT32_SAFE_MAX or + newY2Min < FLOAT32_MINPOS or newY2Max > FLOAT32_SAFE_MAX): + newYMin, newYMax = yMin, yMax + newY2Min, newY2Max = y2Min, y2Max + else: + dy = yData - lastY + dy2 = y2Data - lastY2 + newYMin, newYMax = yMin - dy, yMax - dy + newY2Min, newY2Max = y2Min - dy2, y2Max - dy2 + + # Makes sure y and y2 stays in float32 range + if (newYMin < FLOAT32_SAFE_MIN or + newYMax > FLOAT32_SAFE_MAX or + newY2Min < FLOAT32_SAFE_MIN or + newY2Max > FLOAT32_SAFE_MAX): + newYMin, newYMax = yMin, yMax + newY2Min, newY2Max = y2Min, y2Max + + self.plot.setLimits(newXMin, newXMax, + newYMin, newYMax, + newY2Min, newY2Max) + + self._previousDataPos = self._pixelToData(x, y) + + def endDrag(self, startPos, endPos): + del self._previousDataPos + + def cancel(self): + pass + + +# Zoom ######################################################################## + +class Zoom(_ZoomOnWheel): + """Zoom-in/out state machine. + + Zoom-in on selected area, zoom-out on right click, + and zoom on mouse wheel. + """ + _DOUBLE_CLICK_TIMEOUT = 0.4 + + def __init__(self, plot, color): + self.color = color + self.zoomStack = [] + self._lastClick = 0., None + + super(Zoom, self).__init__(plot) + + def _areaWithAspectRatio(self, x0, y0, x1, y1): + _plotLeft, _plotTop, plotW, plotH = self.plot.getPlotBoundsInPixels() + + areaX0, areaY0, areaX1, areaY1 = x0, y0, x1, y1 + + if plotH != 0.: + plotRatio = plotW / float(plotH) + width, height = math.fabs(x1 - x0), math.fabs(y1 - y0) + + if height != 0. and width != 0.: + if width / height > plotRatio: + areaHeight = width / plotRatio + areaX0, areaX1 = x0, x1 + center = 0.5 * (y0 + y1) + areaY0 = center - numpy.sign(y1 - y0) * 0.5 * areaHeight + areaY1 = center + numpy.sign(y1 - y0) * 0.5 * areaHeight + else: + areaWidth = height * plotRatio + areaY0, areaY1 = y0, y1 + center = 0.5 * (x0 + x1) + areaX0 = center - numpy.sign(x1 - x0) * 0.5 * areaWidth + areaX1 = center + numpy.sign(x1 - x0) * 0.5 * areaWidth + + return areaX0, areaY0, areaX1, areaY1 + + def click(self, x, y, btn): + if btn == LEFT_BTN: + lastClickTime, lastClickPos = self._lastClick + + # Signal mouse double clicked event first + if (time.time() - lastClickTime) <= self._DOUBLE_CLICK_TIMEOUT: + # Use position of first click + eventDict = prepareMouseSignal('mouseDoubleClicked', 'left', + *lastClickPos) + self.plot.notify(**eventDict) + + self._lastClick = 0., None + else: + # Signal mouse clicked event + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + eventDict = prepareMouseSignal('mouseClicked', 'left', + dataPos[0], dataPos[1], + x, y) + self.plot.notify(**eventDict) + + self._lastClick = time.time(), (dataPos[0], dataPos[1], x, y) + + # Zoom-in centered on mouse cursor + # xMin, xMax = self.plot.getGraphXLimits() + # yMin, yMax = self.plot.getGraphYLimits() + # y2Min, y2Max = self.plot.getGraphYLimits(axis="right") + # self.zoomStack.append((xMin, xMax, yMin, yMax, y2Min, y2Max)) + # self._zoom(x, y, 2) + elif btn == RIGHT_BTN: + try: + xMin, xMax, yMin, yMax, y2Min, y2Max = self.zoomStack.pop() + except IndexError: + # Signal mouse clicked event + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + eventDict = prepareMouseSignal('mouseClicked', 'right', + dataPos[0], dataPos[1], + x, y) + self.plot.notify(**eventDict) + else: + self.plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max) + + def beginDrag(self, x, y): + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + self.x0, self.y0 = x, y + + def drag(self, x1, y1): + if self.color is None: + return # Do not draw zoom area + + dataPos = self.plot.pixelToData(x1, y1) + assert dataPos is not None + + if self.plot.isKeepDataAspectRatio(): + area = self._areaWithAspectRatio(self.x0, self.y0, x1, y1) + areaX0, areaY0, areaX1, areaY1 = area + areaPoints = ((areaX0, areaY0), + (areaX1, areaY0), + (areaX1, areaY1), + (areaX0, areaY1)) + areaPoints = numpy.array([self.plot.pixelToData( + x, y, check=False) for (x, y) in areaPoints]) + + if self.color != 'video inverted': + areaColor = list(self.color) + areaColor[3] *= 0.25 + else: + areaColor = [1., 1., 1., 1.] + + self.setSelectionArea(areaPoints, + fill='none', + color=areaColor, + name="zoomedArea") + + corners = ((self.x0, self.y0), + (self.x0, y1), + (x1, y1), + (x1, self.y0)) + corners = numpy.array([self.plot.pixelToData(x, y, check=False) + for (x, y) in corners]) + + self.setSelectionArea(corners, fill='none', color=self.color) + + def endDrag(self, startPos, endPos): + x0, y0 = startPos + x1, y1 = endPos + + if x0 != x1 or y0 != y1: # Avoid empty zoom area + # Store current zoom state in stack + xMin, xMax = self.plot.getGraphXLimits() + yMin, yMax = self.plot.getGraphYLimits() + y2Min, y2Max = self.plot.getGraphYLimits(axis="right") + self.zoomStack.append((xMin, xMax, yMin, yMax, y2Min, y2Max)) + + if self.plot.isKeepDataAspectRatio(): + x0, y0, x1, y1 = self._areaWithAspectRatio(x0, y0, x1, y1) + + # Convert to data space and set limits + x0, y0 = self.plot.pixelToData(x0, y0, check=False) + + dataPos = self.plot.pixelToData( + startPos[0], startPos[1], axis="right", check=False) + y2_0 = dataPos[1] + + x1, y1 = self.plot.pixelToData(x1, y1, check=False) + + dataPos = self.plot.pixelToData( + endPos[0], endPos[1], axis="right", check=False) + y2_1 = dataPos[1] + + xMin, xMax = min(x0, x1), max(x0, x1) + yMin, yMax = min(y0, y1), max(y0, y1) + y2Min, y2Max = min(y2_0, y2_1), max(y2_0, y2_1) + + self.plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max) + + self.resetSelectionArea() + + def cancel(self): + if isinstance(self.state, self.states['drag']): + self.resetSelectionArea() + + +# Select ###################################################################### + +class Select(StateMachine, _PlotInteraction): + """Base class for drawing selection areas.""" + + def __init__(self, plot, parameters, states, state): + """Init a state machine. + + :param plot: The plot to apply changes to. + :param dict parameters: A dict of parameters such as color. + :param dict states: The states of the state machine. + :param str state: The name of the initial state. + """ + _PlotInteraction.__init__(self, plot) + self.parameters = parameters + StateMachine.__init__(self, states, state) + + def onWheel(self, x, y, angle): + scaleF = 1.1 if angle > 0 else 1. / 1.1 + applyZoomToPlot(self.plot, scaleF, (x, y)) + + @property + def color(self): + return self.parameters.get('color', None) + + +class SelectPolygon(Select): + """Drawing selection polygon area state machine.""" + + DRAG_THRESHOLD_DIST = 4 + + class Idle(State): + def onPress(self, x, y, btn): + if btn == LEFT_BTN: + self.goto('select', x, y) + return True + + class Select(State): + def enterState(self, x, y): + dataPos = self.machine.plot.pixelToData(x, y) + assert dataPos is not None + self._firstPos = dataPos + self.points = [dataPos, dataPos] + + self.updateFirstPoint() + + def updateFirstPoint(self): + """Update drawing first point, using self._firstPos""" + x, y = self.machine.plot.dataToPixel(*self._firstPos, check=False) + + offset = self.machine.DRAG_THRESHOLD_DIST + points = [(x - offset, y - offset), + (x - offset, y + offset), + (x + offset, y + offset), + (x + offset, y - offset)] + points = [self.machine.plot.pixelToData(xpix, ypix, check=False) + for xpix, ypix in points] + self.machine.setSelectionArea(points, fill=None, + color=self.machine.color, + name='first_point') + + def updateSelectionArea(self): + """Update drawing selection area using self.points""" + self.machine.setSelectionArea(self.points, + fill='hatch', + color=self.machine.color) + eventDict = prepareDrawingSignal('drawingProgress', + 'polygon', + self.points, + self.machine.parameters) + self.machine.plot.notify(**eventDict) + + def onWheel(self, x, y, angle): + self.machine.onWheel(x, y, angle) + self.updateFirstPoint() + + def onRelease(self, x, y, btn): + if btn == LEFT_BTN: + # checking if the position is close to the first point + # if yes : closing the "loop" + firstPos = self.machine.plot.dataToPixel(*self._firstPos, + check=False) + dx, dy = abs(firstPos[0] - x), abs(firstPos[1] - y) + + # Only allow to close polygon after first point + if (len(self.points) > 2 and + dx < self.machine.DRAG_THRESHOLD_DIST and + dy < self.machine.DRAG_THRESHOLD_DIST): + self.machine.resetSelectionArea() + + self.points[-1] = self.points[0] + + eventDict = prepareDrawingSignal('drawingFinished', + 'polygon', + self.points, + self.machine.parameters) + self.machine.plot.notify(**eventDict) + self.goto('idle') + return False + + # Update polygon last point not too close to previous one + dataPos = self.machine.plot.pixelToData(x, y) + assert dataPos is not None + self.updateSelectionArea() + + # checking that the new points isnt the same (within range) + # of the previous one + # This has to be done because sometimes the mouse release event + # is caught right after entering the Select state (i.e : press + # in Idle state, but with a slightly different position that + # the mouse press. So we had the two first vertices that were + # almost identical. + previousPos = self.machine.plot.dataToPixel(*self.points[-2], + check=False) + dx, dy = abs(previousPos[0] - x), abs(previousPos[1] - y) + if(dx >= self.machine.DRAG_THRESHOLD_DIST or + dy >= self.machine.DRAG_THRESHOLD_DIST): + self.points.append(dataPos) + else: + self.points[-1] = dataPos + + return True + + elif btn == RIGHT_BTN: + self.machine.resetSelectionArea() + + firstPos = self.machine.plot.dataToPixel(*self._firstPos, + check=False) + dx, dy = abs(firstPos[0] - x), abs(firstPos[1] - y) + + if (dx < self.machine.DRAG_THRESHOLD_DIST and + dy < self.machine.DRAG_THRESHOLD_DIST): + self.points[-1] = self.points[0] + else: + dataPos = self.machine.plot.pixelToData(x, y) + assert dataPos is not None + self.points[-1] = dataPos + if self.points[-2] == self.points[-1]: + self.points.pop() + self.points.append(self.points[0]) + + eventDict = prepareDrawingSignal('drawingFinished', + 'polygon', + self.points, + self.machine.parameters) + self.machine.plot.notify(**eventDict) + self.goto('idle') + return False + + return False + + def onMove(self, x, y): + firstPos = self.machine.plot.dataToPixel(*self._firstPos, + check=False) + dx, dy = abs(firstPos[0] - x), abs(firstPos[1] - y) + if (dx < self.machine.DRAG_THRESHOLD_DIST and + dy < self.machine.DRAG_THRESHOLD_DIST): + x, y = firstPos # Snap to first point + + dataPos = self.machine.plot.pixelToData(x, y) + assert dataPos is not None + self.points[-1] = dataPos + self.updateSelectionArea() + + def __init__(self, plot, parameters): + states = { + 'idle': SelectPolygon.Idle, + 'select': SelectPolygon.Select + } + super(SelectPolygon, self).__init__(plot, parameters, + states, 'idle') + + def cancel(self): + if isinstance(self.state, self.states['select']): + self.resetSelectionArea() + + +class Select2Points(Select): + """Base class for drawing selection based on 2 input points.""" + class Idle(State): + def onPress(self, x, y, btn): + if btn == LEFT_BTN: + self.goto('start', x, y) + return True + + class Start(State): + def enterState(self, x, y): + self.machine.beginSelect(x, y) + + def onMove(self, x, y): + self.goto('select', x, y) + + def onRelease(self, x, y, btn): + if btn == LEFT_BTN: + self.goto('select', x, y) + return True + + class Select(State): + def enterState(self, x, y): + self.onMove(x, y) + + def onMove(self, x, y): + self.machine.select(x, y) + + def onRelease(self, x, y, btn): + if btn == LEFT_BTN: + self.machine.endSelect(x, y) + self.goto('idle') + + def __init__(self, plot, parameters): + states = { + 'idle': Select2Points.Idle, + 'start': Select2Points.Start, + 'select': Select2Points.Select + } + super(Select2Points, self).__init__(plot, parameters, + states, 'idle') + + def beginSelect(self, x, y): + pass + + def select(self, x, y): + pass + + def endSelect(self, x, y): + pass + + def cancelSelect(self): + pass + + def cancel(self): + if isinstance(self.state, self.states['select']): + self.cancelSelect() + + +class SelectRectangle(Select2Points): + """Drawing rectangle selection area state machine.""" + def beginSelect(self, x, y): + self.startPt = self.plot.pixelToData(x, y) + assert self.startPt is not None + + def select(self, x, y): + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + + self.setSelectionArea((self.startPt, + (self.startPt[0], dataPos[1]), + dataPos, + (dataPos[0], self.startPt[1])), + fill='hatch', + color=self.color) + + eventDict = prepareDrawingSignal('drawingProgress', + 'rectangle', + (self.startPt, dataPos), + self.parameters) + self.plot.notify(**eventDict) + + def endSelect(self, x, y): + self.resetSelectionArea() + + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + + eventDict = prepareDrawingSignal('drawingFinished', + 'rectangle', + (self.startPt, dataPos), + self.parameters) + self.plot.notify(**eventDict) + + def cancelSelect(self): + self.resetSelectionArea() + + +class SelectLine(Select2Points): + """Drawing line selection area state machine.""" + def beginSelect(self, x, y): + self.startPt = self.plot.pixelToData(x, y) + assert self.startPt is not None + + def select(self, x, y): + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + + self.setSelectionArea((self.startPt, dataPos), + fill='hatch', + color=self.color) + + eventDict = prepareDrawingSignal('drawingProgress', + 'line', + (self.startPt, dataPos), + self.parameters) + self.plot.notify(**eventDict) + + def endSelect(self, x, y): + self.resetSelectionArea() + + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + + eventDict = prepareDrawingSignal('drawingFinished', + 'line', + (self.startPt, dataPos), + self.parameters) + self.plot.notify(**eventDict) + + def cancelSelect(self): + self.resetSelectionArea() + + +class Select1Point(Select): + """Base class for drawing selection area based on one input point.""" + class Idle(State): + def onPress(self, x, y, btn): + if btn == LEFT_BTN: + self.goto('select', x, y) + return True + + class Select(State): + def enterState(self, x, y): + self.onMove(x, y) + + def onMove(self, x, y): + self.machine.select(x, y) + + def onRelease(self, x, y, btn): + if btn == LEFT_BTN: + self.machine.endSelect(x, y) + self.goto('idle') + + def onWheel(self, x, y, angle): + self.machine.onWheel(x, y, angle) # Call select default wheel + self.machine.select(x, y) + + def __init__(self, plot, parameters): + states = { + 'idle': Select1Point.Idle, + 'select': Select1Point.Select + } + super(Select1Point, self).__init__(plot, parameters, states, 'idle') + + def select(self, x, y): + pass + + def endSelect(self, x, y): + pass + + def cancelSelect(self): + pass + + def cancel(self): + if isinstance(self.state, self.states['select']): + self.cancelSelect() + + +class SelectHLine(Select1Point): + """Drawing a horizontal line selection area state machine.""" + def _hLine(self, y): + """Return points in data coords of the segment visible in the plot. + + Supports non-orthogonal axes. + """ + left, _top, width, _height = self.plot.getPlotBoundsInPixels() + + dataPos1 = self.plot.pixelToData(left, y, check=False) + dataPos2 = self.plot.pixelToData(left + width, y, check=False) + return dataPos1, dataPos2 + + def select(self, x, y): + points = self._hLine(y) + self.setSelectionArea(points, fill='hatch', color=self.color) + + eventDict = prepareDrawingSignal('drawingProgress', + 'hline', + points, + self.parameters) + self.plot.notify(**eventDict) + + def endSelect(self, x, y): + self.resetSelectionArea() + + eventDict = prepareDrawingSignal('drawingFinished', + 'hline', + self._hLine(y), + self.parameters) + self.plot.notify(**eventDict) + + def cancelSelect(self): + self.resetSelectionArea() + + +class SelectVLine(Select1Point): + """Drawing a vertical line selection area state machine.""" + def _vLine(self, x): + """Return points in data coords of the segment visible in the plot. + + Supports non-orthogonal axes. + """ + _left, top, _width, height = self.plot.getPlotBoundsInPixels() + + dataPos1 = self.plot.pixelToData(x, top, check=False) + dataPos2 = self.plot.pixelToData(x, top + height, check=False) + return dataPos1, dataPos2 + + def select(self, x, y): + points = self._vLine(x) + self.setSelectionArea(points, fill='hatch', color=self.color) + + eventDict = prepareDrawingSignal('drawingProgress', + 'vline', + points, + self.parameters) + self.plot.notify(**eventDict) + + def endSelect(self, x, y): + self.resetSelectionArea() + + eventDict = prepareDrawingSignal('drawingFinished', + 'vline', + self._vLine(x), + self.parameters) + self.plot.notify(**eventDict) + + def cancelSelect(self): + self.resetSelectionArea() + + +class DrawFreeHand(Select): + """Interaction for drawing pencil. It display the preview of the pencil + before pressing the mouse. + """ + + class Idle(State): + def onPress(self, x, y, btn): + if btn == LEFT_BTN: + self.goto('select', x, y) + return True + + def onMove(self, x, y): + self.machine.updatePencilShape(x, y) + + def onLeave(self): + self.machine.cancel() + + class Select(State): + def enterState(self, x, y): + self.__isOut = False + self.machine.setFirstPoint(x, y) + + def onMove(self, x, y): + self.machine.updatePencilShape(x, y) + self.machine.select(x, y) + + def onRelease(self, x, y, btn): + if btn == LEFT_BTN: + if self.__isOut: + self.machine.resetSelectionArea() + self.machine.endSelect(x, y) + self.goto('idle') + + def onEnter(self): + self.__isOut = False + + def onLeave(self): + self.__isOut = True + + def __init__(self, plot, parameters): + # Circle used for pencil preview + angle = numpy.arange(13.) * numpy.pi * 2.0 / 13. + size = parameters.get('width', 1.) * 0.5 + self._circle = size * numpy.array((numpy.cos(angle), + numpy.sin(angle))).T + + states = { + 'idle': DrawFreeHand.Idle, + 'select': DrawFreeHand.Select + } + super(DrawFreeHand, self).__init__(plot, parameters, states, 'idle') + + @property + def width(self): + return self.parameters.get('width', None) + + def setFirstPoint(self, x, y): + self._points = [] + self.select(x, y) + + def updatePencilShape(self, x, y): + center = self.plot.pixelToData(x, y, check=False) + assert center is not None + + polygon = center + self._circle + + self.setSelectionArea(polygon, fill='none', color=self.color) + + def select(self, x, y): + pos = self.plot.pixelToData(x, y, check=False) + if len(self._points) > 0: + if self._points[-1] == pos: + # Skip same points + return + self._points.append(pos) + eventDict = prepareDrawingSignal('drawingProgress', + 'polylines', + self._points, + self.parameters) + self.plot.notify(**eventDict) + + def endSelect(self, x, y): + pos = self.plot.pixelToData(x, y, check=False) + if len(self._points) > 0: + if self._points[-1] != pos: + # Append if different + self._points.append(pos) + + eventDict = prepareDrawingSignal('drawingFinished', + 'polylines', + self._points, + self.parameters) + self.plot.notify(**eventDict) + self._points = None + + def cancelSelect(self): + self.resetSelectionArea() + + def cancel(self): + self.resetSelectionArea() + + +class SelectFreeLine(ClickOrDrag, _PlotInteraction): + """Base class for drawing free lines with tools such as pencil.""" + + def __init__(self, plot, parameters): + """Init a state machine. + + :param plot: The plot to apply changes to. + :param dict parameters: A dict of parameters such as color. + """ + # self.DRAG_THRESHOLD_SQUARE_DIST = 1 # Disable first move threshold + self._points = [] + ClickOrDrag.__init__(self) + _PlotInteraction.__init__(self, plot) + self.parameters = parameters + + def onWheel(self, x, y, angle): + scaleF = 1.1 if angle > 0 else 1. / 1.1 + applyZoomToPlot(self.plot, scaleF, (x, y)) + + @property + def color(self): + return self.parameters.get('color', None) + + def click(self, x, y, btn): + if btn == LEFT_BTN: + self._processEvent(x, y, isLast=True) + + def beginDrag(self, x, y): + self._processEvent(x, y, isLast=False) + + def drag(self, x, y): + self._processEvent(x, y, isLast=False) + + def endDrag(self, startPos, endPos): + x, y = endPos + self._processEvent(x, y, isLast=True) + + def cancel(self): + self.resetSelectionArea() + self._points = [] + + def _processEvent(self, x, y, isLast): + dataPos = self.plot.pixelToData(x, y, check=False) + isNewPoint = not self._points or dataPos != self._points[-1] + + if isNewPoint: + self._points.append(dataPos) + + if isNewPoint or isLast: + eventDict = prepareDrawingSignal( + 'drawingFinished' if isLast else 'drawingProgress', + 'polylines', + self._points, + self.parameters) + self.plot.notify(**eventDict) + + if not isLast: + self.setSelectionArea(self._points, fill='none', color=self.color, + shape='polylines') + else: + self.cancel() + + +# ItemInteraction ############################################################# + +class ItemsInteraction(ClickOrDrag, _PlotInteraction): + """Interaction with items (markers, curves and images). + + This class provides selection and dragging of plot primitives + that support those interaction. + It is also meant to be combined with the zoom interaction. + """ + + class Idle(ClickOrDrag.Idle): + def __init__(self, *args, **kw): + super(ItemsInteraction.Idle, self).__init__(*args, **kw) + self._hoverMarker = None + + def onWheel(self, x, y, angle): + scaleF = 1.1 if angle > 0 else 1. / 1.1 + applyZoomToPlot(self.machine.plot, scaleF, (x, y)) + + def onMove(self, x, y): + marker = self.machine.plot._pickMarker(x, y) + if marker is not None: + dataPos = self.machine.plot.pixelToData(x, y) + assert dataPos is not None + eventDict = prepareHoverSignal( + marker.getLegend(), 'marker', + dataPos, (x, y), + marker.isDraggable(), + marker.isSelectable()) + self.machine.plot.notify(**eventDict) + + if marker != self._hoverMarker: + self._hoverMarker = marker + + if marker is None: + self.machine.plot.setGraphCursorShape() + + elif marker.isDraggable(): + if isinstance(marker, items.YMarker): + self.machine.plot.setGraphCursorShape(CURSOR_SIZE_VER) + elif isinstance(marker, items.XMarker): + self.machine.plot.setGraphCursorShape(CURSOR_SIZE_HOR) + else: + self.machine.plot.setGraphCursorShape(CURSOR_SIZE_ALL) + + elif marker.isSelectable(): + self.machine.plot.setGraphCursorShape(CURSOR_POINTING) + + return True + + def __init__(self, plot): + _PlotInteraction.__init__(self, plot) + + states = { + 'idle': ItemsInteraction.Idle, + 'rightClick': ClickOrDrag.RightClick, + 'clickOrDrag': ClickOrDrag.ClickOrDrag, + 'drag': ClickOrDrag.Drag + } + StateMachine.__init__(self, states, 'idle') + + def click(self, x, y, btn): + """Handle mouse click + + :param x: X position of the mouse in pixels + :param y: Y position of the mouse in pixels + :param btn: Pressed button id + :return: True if click is catched by an item, False otherwise + """ + # Signal mouse clicked event + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + eventDict = prepareMouseSignal('mouseClicked', btn, + dataPos[0], dataPos[1], + x, y) + self.plot.notify(**eventDict) + + eventDict = self._handleClick(x, y, btn) + if eventDict is not None: + self.plot.notify(**eventDict) + + def _handleClick(self, x, y, btn): + """Perform picking and prepare event if click is handled here + + :param x: X position of the mouse in pixels + :param y: Y position of the mouse in pixels + :param btn: Pressed button id + :return: event description to send of None if not handling event. + :rtype: dict or None + """ + + if btn == LEFT_BTN: + marker = self.plot._pickMarker( + x, y, lambda m: m.isSelectable()) + if marker is not None: + xData, yData = marker.getPosition() + if xData is None: + xData = [0, 1] + if yData is None: + yData = [0, 1] + + eventDict = prepareMarkerSignal('markerClicked', + 'left', + marker.getLegend(), + 'marker', + marker.isDraggable(), + marker.isSelectable(), + (xData, yData), + (x, y), None) + return eventDict + + else: + picked = self.plot._pickImageOrCurve( + x, y, lambda item: item.isSelectable()) + + if picked is None: + pass + + elif picked[0] == 'curve': + curve = picked[1] + + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + + eventDict = prepareCurveSignal('left', + curve.getLegend(), + 'curve', + picked[2], picked[3], + dataPos[0], dataPos[1], + x, y) + return eventDict + + elif picked[0] == 'image': + image = picked[1] + + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + + # Get corresponding coordinate in image + origin = image.getOrigin() + scale = image.getScale() + column = int((dataPos[0] - origin[0]) / float(scale[0])) + row = int((dataPos[1] - origin[1]) / float(scale[1])) + + eventDict = prepareImageSignal('left', + image.getLegend(), + 'image', + column, row, + dataPos[0], dataPos[1], + x, y) + return eventDict + + return None + + def _signalMarkerMovingEvent(self, eventType, marker, x, y): + assert marker is not None + + xData, yData = marker.getPosition() + if xData is None: + xData = [0, 1] + if yData is None: + yData = [0, 1] + + posDataCursor = self.plot.pixelToData(x, y) + assert posDataCursor is not None + + eventDict = prepareMarkerSignal(eventType, + 'left', + marker.getLegend(), + 'marker', + marker.isDraggable(), + marker.isSelectable(), + (xData, yData), + (x, y), + posDataCursor) + self.plot.notify(**eventDict) + + def beginDrag(self, x, y): + """Handle begining of drag interaction + + :param x: X position of the mouse in pixels + :param y: Y position of the mouse in pixels + :return: True if drag is catched by an item, False otherwise + """ + self._lastPos = self.plot.pixelToData(x, y) + assert self._lastPos is not None + + self.imageLegend = None + self.markerLegend = None + marker = self.plot._pickMarker( + x, y, lambda m: m.isDraggable()) + + if marker is not None: + self.markerLegend = marker.getLegend() + self._signalMarkerMovingEvent('markerMoving', marker, x, y) + else: + picked = self.plot._pickImageOrCurve( + x, + y, + lambda item: + hasattr(item, 'isDraggable') and item.isDraggable()) + if picked is None: + self.imageLegend = None + self.plot.setGraphCursorShape() + return False + else: + assert picked[0] == 'image' # For now only drag images + self.imageLegend = picked[1].getLegend() + return True + + def drag(self, x, y): + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + xData, yData = dataPos + + if self.markerLegend is not None: + marker = self.plot._getMarker(self.markerLegend) + if marker is not None: + marker.setPosition(xData, yData) + + self._signalMarkerMovingEvent( + 'markerMoving', marker, x, y) + + if self.imageLegend is not None: + image = self.plot.getImage(self.imageLegend) + origin = image.getOrigin() + xImage = origin[0] + xData - self._lastPos[0] + yImage = origin[1] + yData - self._lastPos[1] + image.setOrigin((xImage, yImage)) + + self._lastPos = xData, yData + + def endDrag(self, startPos, endPos): + if self.markerLegend is not None: + marker = self.plot._getMarker(self.markerLegend) + posData = list(marker.getPosition()) + if posData[0] is None: + posData[0] = [0, 1] + if posData[1] is None: + posData[1] = [0, 1] + + eventDict = prepareMarkerSignal( + 'markerMoved', + 'left', + marker.getLegend(), + 'marker', + marker.isDraggable(), + marker.isSelectable(), + posData) + self.plot.notify(**eventDict) + + self.plot.setGraphCursorShape() + + del self.markerLegend + del self.imageLegend + del self._lastPos + + def cancel(self): + self.plot.setGraphCursorShape() + + +# FocusManager ################################################################ + +class FocusManager(StateMachine): + """Manages focus across multiple event handlers + + On press an event handler can acquire focus. + By default it looses focus when all buttons are released. + """ + class Idle(State): + def onPress(self, x, y, btn): + for eventHandler in self.machine.eventHandlers: + requestFocus = eventHandler.handleEvent('press', x, y, btn) + if requestFocus: + self.goto('focus', eventHandler, btn) + break + + def _processEvent(self, *args): + for eventHandler in self.machine.eventHandlers: + consumeEvent = eventHandler.handleEvent(*args) + if consumeEvent: + break + + def onMove(self, x, y): + self._processEvent('move', x, y) + + def onRelease(self, x, y, btn): + self._processEvent('release', x, y, btn) + + def onWheel(self, x, y, angle): + self._processEvent('wheel', x, y, angle) + + class Focus(State): + def enterState(self, eventHandler, btn): + self.eventHandler = eventHandler + self.focusBtns = {btn} + + def onPress(self, x, y, btn): + self.focusBtns.add(btn) + self.eventHandler.handleEvent('press', x, y, btn) + + def onMove(self, x, y): + self.eventHandler.handleEvent('move', x, y) + + def onRelease(self, x, y, btn): + self.focusBtns.discard(btn) + requestFocus = self.eventHandler.handleEvent('release', x, y, btn) + if len(self.focusBtns) == 0 and not requestFocus: + self.goto('idle') + + def onWheel(self, x, y, angleInDegrees): + self.eventHandler.handleEvent('wheel', x, y, angleInDegrees) + + def __init__(self, eventHandlers=()): + self.eventHandlers = list(eventHandlers) + + states = { + 'idle': FocusManager.Idle, + 'focus': FocusManager.Focus + } + super(FocusManager, self).__init__(states, 'idle') + + def cancel(self): + for handler in self.eventHandlers: + handler.cancel() + + +class ZoomAndSelect(ItemsInteraction): + """Combine Zoom and ItemInteraction state machine. + + :param plot: The Plot to which this interaction is attached + :param color: The color to use for the zoom area bounding box + """ + + def __init__(self, plot, color): + super(ZoomAndSelect, self).__init__(plot) + self._zoom = Zoom(plot, color) + self._doZoom = False + + @property + def color(self): + """Color of the zoom area""" + return self._zoom.color + + def click(self, x, y, btn): + """Handle mouse click + + :param x: X position of the mouse in pixels + :param y: Y position of the mouse in pixels + :param btn: Pressed button id + :return: True if click is catched by an item, False otherwise + """ + eventDict = self._handleClick(x, y, btn) + + if eventDict is not None: + # Signal mouse clicked event + dataPos = self.plot.pixelToData(x, y) + assert dataPos is not None + clickedEventDict = prepareMouseSignal('mouseClicked', btn, + dataPos[0], dataPos[1], + x, y) + self.plot.notify(**clickedEventDict) + + self.plot.notify(**eventDict) + + else: + self._zoom.click(x, y, btn) + + def beginDrag(self, x, y): + """Handle start drag and switching between zoom and item drag. + + :param x: X position in pixels + :param y: Y position in pixels + """ + self._doZoom = not super(ZoomAndSelect, self).beginDrag(x, y) + if self._doZoom: + self._zoom.beginDrag(x, y) + + def drag(self, x, y): + """Handle drag, eventually forwarding to zoom. + + :param x: X position in pixels + :param y: Y position in pixels + """ + if self._doZoom: + return self._zoom.drag(x, y) + else: + return super(ZoomAndSelect, self).drag(x, y) + + def endDrag(self, startPos, endPos): + """Handle end of drag, eventually forwarding to zoom. + + :param startPos: (x, y) position at the beginning of the drag + :param endPos: (x, y) position at the end of the drag + """ + if self._doZoom: + return self._zoom.endDrag(startPos, endPos) + else: + return super(ZoomAndSelect, self).endDrag(startPos, endPos) + + +# Interaction mode control #################################################### + +class PlotInteraction(object): + """Proxy to currently use state machine for interaction. + + This allows to switch interactive mode. + + :param plot: The :class:`Plot` to apply interaction to + """ + + _DRAW_MODES = { + 'polygon': SelectPolygon, + 'rectangle': SelectRectangle, + 'line': SelectLine, + 'vline': SelectVLine, + 'hline': SelectHLine, + 'polylines': SelectFreeLine, + 'pencil': DrawFreeHand, + } + + def __init__(self, plot): + self._plot = weakref.ref(plot) # Avoid cyclic-ref + + self.zoomOnWheel = True + """True to enable zoom on wheel, False otherwise.""" + + # Default event handler + self._eventHandler = ItemsInteraction(plot) + + def getInteractiveMode(self): + """Returns the current interactive mode as a dict. + + The returned dict contains at least the key 'mode'. + Mode can be: 'draw', 'pan', 'select', 'zoom'. + It can also contains extra keys (e.g., 'color') specific to a mode + as provided to :meth:`setInteractiveMode`. + """ + if isinstance(self._eventHandler, ZoomAndSelect): + return {'mode': 'zoom', 'color': self._eventHandler.color} + + elif isinstance(self._eventHandler, Select): + result = self._eventHandler.parameters.copy() + result['mode'] = 'draw' + return result + + elif isinstance(self._eventHandler, Pan): + return {'mode': 'pan'} + + else: + return {'mode': 'select'} + + def setInteractiveMode(self, mode, color='black', + shape='polygon', label=None, width=None): + """Switch the interactive mode. + + :param str mode: The name of the interactive mode. + In 'draw', 'pan', 'select', 'zoom'. + :param color: Only for 'draw' and 'zoom' modes. + Color to use for drawing selection area. Default black. + If None, selection area is not drawn. + :type color: Color description: The name as a str or + a tuple of 4 floats or None. + :param str shape: Only for 'draw' mode. The kind of shape to draw. + In 'polygon', 'rectangle', 'line', 'vline', 'hline', + 'polylines'. + Default is 'polygon'. + :param str label: Only for 'draw' mode. + :param float width: Width of the pencil. Only for draw pencil mode. + """ + assert mode in ('draw', 'pan', 'select', 'zoom') + + plot = self._plot() + assert plot is not None + + if color not in (None, 'video inverted'): + color = Colors.rgba(color) + + if mode == 'draw': + assert shape in self._DRAW_MODES + eventHandlerClass = self._DRAW_MODES[shape] + parameters = { + 'shape': shape, + 'label': label, + 'color': color, + 'width': width, + } + + self._eventHandler.cancel() + self._eventHandler = eventHandlerClass(plot, parameters) + + elif mode == 'pan': + # Ignores color, shape and label + self._eventHandler.cancel() + self._eventHandler = Pan(plot) + + elif mode == 'zoom': + # Ignores shape and label + self._eventHandler.cancel() + self._eventHandler = ZoomAndSelect(plot, color) + + else: # Default mode: interaction with plot objects + # Ignores color, shape and label + self._eventHandler.cancel() + self._eventHandler = ItemsInteraction(plot) + + def handleEvent(self, event, *args, **kwargs): + """Forward event to current interactive mode state machine.""" + if not self.zoomOnWheel and event == 'wheel': + return # Discard wheel events + self._eventHandler.handleEvent(event, *args, **kwargs) diff --git a/silx/gui/plot/PlotToolButtons.py b/silx/gui/plot/PlotToolButtons.py new file mode 100644 index 0000000..8042391 --- /dev/null +++ b/silx/gui/plot/PlotToolButtons.py @@ -0,0 +1,280 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides a set of QToolButton to use with :class:`.PlotWidget`. + +The following QToolButton are available: + +- :class:`AspectToolButton` +- :class:`YAxisOriginToolButton` +- :class:`ProfileToolButton` + +""" + +__authors__ = ["V. Valls", "H. Payno"] +__license__ = "MIT" +__date__ = "26/01/2017" + + +import logging +from .. import icons +from .. import qt + + +_logger = logging.getLogger(__name__) + + +class PlotToolButton(qt.QToolButton): + """A QToolButton connected to a :class:`.PlotWidget`. + """ + + def __init__(self, parent=None, plot=None): + super(PlotToolButton, self).__init__(parent) + self._plot = None + if plot is not None: + self.setPlot(plot) + + def plot(self): + """ + Returns the plot connected to the widget. + """ + return self._plot + + def setPlot(self, plot): + """ + Set the plot connected to the widget + + :param plot: :class:`.PlotWidget` instance on which to operate. + """ + if self._plot is plot: + return + if self._plot is not None: + self._disconnectPlot(self._plot) + self._plot = plot + if self._plot is not None: + self._connectPlot(self._plot) + + def _connectPlot(self, plot): + """ + Called when the plot is connected to the widget + + :param plot: :class:`.PlotWidget` instance + """ + pass + + def _disconnectPlot(self, plot): + """ + Called when the plot is disconnected from the widget + + :param plot: :class:`.PlotWidget` instance + """ + pass + + +class AspectToolButton(PlotToolButton): + + STATE = None + """Lazy loaded states used to feed AspectToolButton""" + + def __init__(self, parent=None, plot=None): + if self.STATE is None: + self.STATE = {} + # dont keep ratio + self.STATE[False, "icon"] = icons.getQIcon('shape-ellipse-solid') + self.STATE[False, "state"] = "Aspect ratio is not kept" + self.STATE[False, "action"] = "Do no keep data aspect ratio" + # keep ratio + self.STATE[True, "icon"] = icons.getQIcon('shape-circle-solid') + self.STATE[True, "state"] = "Aspect ratio is kept" + self.STATE[True, "action"] = "Keep data aspect ratio" + + super(AspectToolButton, self).__init__(parent=parent, plot=plot) + + keepAction = self._createAction(True) + keepAction.triggered.connect(self.keepDataAspectRatio) + keepAction.setIconVisibleInMenu(True) + + dontKeepAction = self._createAction(False) + dontKeepAction.triggered.connect(self.dontKeepDataAspectRatio) + dontKeepAction.setIconVisibleInMenu(True) + + menu = qt.QMenu(self) + menu.addAction(keepAction) + menu.addAction(dontKeepAction) + self.setMenu(menu) + self.setPopupMode(qt.QToolButton.InstantPopup) + + def _createAction(self, keepAspectRatio): + icon = self.STATE[keepAspectRatio, "icon"] + text = self.STATE[keepAspectRatio, "action"] + return qt.QAction(icon, text, self) + + def _connectPlot(self, plot): + plot.sigSetKeepDataAspectRatio.connect(self._keepDataAspectRatioChanged) + self._keepDataAspectRatioChanged(plot.isKeepDataAspectRatio()) + + def _disconnectPlot(self, plot): + plot.sigSetKeepDataAspectRatio.disconnect(self._keepDataAspectRatioChanged) + + def keepDataAspectRatio(self): + """Configure the plot to keep the aspect ratio""" + plot = self.plot() + if plot is not None: + # This will trigger _keepDataAspectRatioChanged + plot.setKeepDataAspectRatio(True) + + def dontKeepDataAspectRatio(self): + """Configure the plot to not keep the aspect ratio""" + plot = self.plot() + if plot is not None: + # This will trigger _keepDataAspectRatioChanged + plot.setKeepDataAspectRatio(False) + + def _keepDataAspectRatioChanged(self, aspectRatio): + """Handle Plot set keep aspect ratio signal""" + icon, toolTip = self.STATE[aspectRatio, "icon"], self.STATE[aspectRatio, "state"] + self.setIcon(icon) + self.setToolTip(toolTip) + + +class YAxisOriginToolButton(PlotToolButton): + + STATE = None + """Lazy loaded states used to feed YAxisOriginToolButton""" + + def __init__(self, parent=None, plot=None): + if self.STATE is None: + self.STATE = {} + # is down + self.STATE[False, "icon"] = icons.getQIcon('plot-ydown') + self.STATE[False, "state"] = "Y-axis is oriented downward" + self.STATE[False, "action"] = "Orient Y-axis downward" + # keep ration + self.STATE[True, "icon"] = icons.getQIcon('plot-yup') + self.STATE[True, "state"] = "Y-axis is oriented upward" + self.STATE[True, "action"] = "Orient Y-axis upward" + + super(YAxisOriginToolButton, self).__init__(parent=parent, plot=plot) + + upwardAction = self._createAction(True) + upwardAction.triggered.connect(self.setYAxisUpward) + upwardAction.setIconVisibleInMenu(True) + + downwardAction = self._createAction(False) + downwardAction.triggered.connect(self.setYAxisDownward) + downwardAction.setIconVisibleInMenu(True) + + menu = qt.QMenu(self) + menu.addAction(upwardAction) + menu.addAction(downwardAction) + self.setMenu(menu) + self.setPopupMode(qt.QToolButton.InstantPopup) + + def _createAction(self, isUpward): + icon = self.STATE[isUpward, "icon"] + text = self.STATE[isUpward, "action"] + return qt.QAction(icon, text, self) + + def _connectPlot(self, plot): + plot.sigSetYAxisInverted.connect(self._yAxisInvertedChanged) + self._yAxisInvertedChanged(plot.isYAxisInverted()) + + def _disconnectPlot(self, plot): + plot.sigSetYAxisInverted.disconnect(self._yAxisInvertedChanged) + + def setYAxisUpward(self): + """Configure the plot to use y-axis upward""" + plot = self.plot() + if plot is not None: + # This will trigger _yAxisInvertedChanged + plot.setYAxisInverted(False) + + def setYAxisDownward(self): + """Configure the plot to use y-axis downward""" + plot = self.plot() + if plot is not None: + # This will trigger _yAxisInvertedChanged + plot.setYAxisInverted(True) + + def _yAxisInvertedChanged(self, inverted): + """Handle Plot set y axis inverted signal""" + isUpward = not inverted + icon, toolTip = self.STATE[isUpward, "icon"], self.STATE[isUpward, "state"] + self.setIcon(icon) + self.setToolTip(toolTip) + + +class ProfileToolButton(PlotToolButton): + """Button used in Profile3DToolbar to switch between 2D profile + and 1D profile.""" + STATE = None + """Lazy loaded states used to feed ProfileToolButton""" + + sigDimensionChanged = qt.Signal(int) + + def __init__(self, parent=None, plot=None): + if self.STATE is None: + self.STATE = { + (1, "icon"): icons.getQIcon('profile1D'), + (1, "state"): "1D profile is computed on visible image", + (1, "action"): "1D profile on visible image", + (2, "icon"): icons.getQIcon('profile2D'), + (2, "state"): "2D profile is computed, one 1D profile for each image in the stack", + (2, "action"): "2D profile on image stack"} + # Compute 1D profile + # Compute 2D profile + + super(ProfileToolButton, self).__init__(parent=parent, plot=plot) + + profile1DAction = self._createAction(1) + profile1DAction.triggered.connect(self.computeProfileIn1D) + profile1DAction.setIconVisibleInMenu(True) + + profile2DAction = self._createAction(2) + profile2DAction.triggered.connect(self.computeProfileIn2D) + profile2DAction.setIconVisibleInMenu(True) + + menu = qt.QMenu(self) + menu.addAction(profile1DAction) + menu.addAction(profile2DAction) + self.setMenu(menu) + self.setPopupMode(qt.QToolButton.InstantPopup) + menu.setTitle('Select profile dimension') + + def _createAction(self, profileDimension): + icon = self.STATE[profileDimension, "icon"] + text = self.STATE[profileDimension, "action"] + return qt.QAction(icon, text, self) + + def _profileDimensionChanged(self, profileDimension): + """Update icon in toolbar, emit number of dimensions for profile""" + self.setIcon(self.STATE[profileDimension, "icon"]) + self.setToolTip(self.STATE[profileDimension, "state"]) + self.sigDimensionChanged.emit(profileDimension) + + def computeProfileIn1D(self): + self._profileDimensionChanged(1) + + def computeProfileIn2D(self): + self._profileDimensionChanged(2) diff --git a/silx/gui/plot/PlotTools.py b/silx/gui/plot/PlotTools.py new file mode 100644 index 0000000..7158d0e --- /dev/null +++ b/silx/gui/plot/PlotTools.py @@ -0,0 +1,313 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Set of widgets to associate with a :class:'PlotWidget'. +""" + +from __future__ import division + +__authors__ = ["V.A. Sole", "T. Vincent"] +__license__ = "MIT" +__date__ = "03/03/2017" + + +import logging +import numbers +import traceback +import weakref + +import numpy + +from .. import qt + +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.DEBUG) + + +# PositionInfo ################################################################ + +class PositionInfo(qt.QWidget): + """QWidget displaying coords converted from data coords of the mouse. + + Provide this widget with a list of couple: + + - A name to display before the data + - A function that takes (x, y) as arguments and returns something that + gets converted to a string. + If the result is a float it is converted with '%.7g' format. + + To run the following sample code, a QApplication must be initialized. + First, create a PlotWindow and add a QToolBar where to place the + PositionInfo widget. + + >>> from silx.gui.plot import PlotWindow + >>> from silx.gui import qt + + >>> plot = PlotWindow() # Create a PlotWindow to add the widget to + >>> toolBar = qt.QToolBar() # Create a toolbar to place the widget in + >>> plot.addToolBar(qt.Qt.BottomToolBarArea, toolBar) # Add it to plot + + Then, create the PositionInfo widget and add it to the toolbar. + The PositionInfo widget is created with a list of converters, here + to display polar coordinates of the mouse position. + + >>> import numpy + >>> from silx.gui.plot.PlotTools import PositionInfo + + >>> position = PositionInfo(plot=plot, converters=[ + ... ('Radius', lambda x, y: numpy.sqrt(x*x + y*y)), + ... ('Angle', lambda x, y: numpy.degrees(numpy.arctan2(y, x)))]) + >>> toolBar.addWidget(position) # Add the widget to the toolbar + <...> + >>> plot.show() # To display the PlotWindow with the position widget + + :param plot: The PlotWidget this widget is displaying data coords from. + :param converters: List of name to display and conversion function from + (x, y) in data coords to displayed value. + If None, the default, it displays X and Y. + :type converters: Iterable of 2-tuple (str, function) + :param parent: Parent widget + """ + + def __init__(self, parent=None, plot=None, converters=None): + assert plot is not None + self._plotRef = weakref.ref(plot) + + super(PositionInfo, self).__init__(parent) + + if converters is None: + converters = (('X', lambda x, y: x), ('Y', lambda x, y: y)) + + self.autoSnapToActiveCurve = False + """Toggle snapping use position to active curve. + + - True to snap used coordinates to the active curve if the active curve + is displayed with symbols and mouse is close enough. + If the mouse is not close to a point of the curve, values are + displayed in red. + - False (the default) to always use mouse coordinates. + + """ + + self._fields = [] # To store (QLineEdit, name, function (x, y)->v) + + # Create a new layout with new widgets + layout = qt.QHBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + # layout.setSpacing(0) + + # Create all QLabel and store them with the corresponding converter + for name, func in converters: + layout.addWidget(qt.QLabel('<b>' + name + ':</b>')) + + contentWidget = qt.QLabel() + contentWidget.setText('------') + contentWidget.setTextInteractionFlags(qt.Qt.TextSelectableByMouse) + contentWidget.setFixedWidth( + contentWidget.fontMetrics().width('##############')) + layout.addWidget(contentWidget) + self._fields.append((contentWidget, name, func)) + + layout.addStretch(1) + self.setLayout(layout) + + # Connect to Plot events + plot.sigPlotSignal.connect(self._plotEvent) + + @property + def plot(self): + """The :class:`.PlotWindow` this widget is attached to.""" + return self._plotRef() + + def getConverters(self): + """Return the list of converters as 2-tuple (name, function).""" + return [(name, func) for _label, name, func in self._fields] + + def _plotEvent(self, event): + """Handle events from the Plot. + + :param dict event: Plot event + """ + if event['event'] == 'mouseMoved': + x, y = event['x'], event['y'] # Position in data + styleSheet = "color: rgb(0, 0, 0);" # Default style + + if self.autoSnapToActiveCurve and self.plot.getGraphCursor(): + # Check if near active curve with symbols. + + styleSheet = "color: rgb(255, 0, 0);" # Style far from curve + + activeCurve = self.plot.getActiveCurve() + if activeCurve: + xData = activeCurve.getXData(copy=False) + yData = activeCurve.getYData(copy=False) + if activeCurve.getSymbol(): # Only handled if symbols on curve + closestIndex = numpy.argmin( + pow(xData - x, 2) + pow(yData - y, 2)) + + xClosest = xData[closestIndex] + yClosest = yData[closestIndex] + + closestInPixels = self.plot.dataToPixel( + xClosest, yClosest, axis=activeCurve.getYAxis()) + if closestInPixels is not None: + xPixel, yPixel = event['xpixel'], event['ypixel'] + + if (abs(closestInPixels[0] - xPixel) < 5 and + abs(closestInPixels[1] - yPixel) < 5): + # Update label style sheet + styleSheet = "color: rgb(0, 0, 0);" + + # if close enough, wrap to data point coords + x, y = xClosest, yClosest + + for label, name, func in self._fields: + label.setStyleSheet(styleSheet) + + try: + value = func(x, y) + except: + label.setText('Error') + _logger.error( + "Error while converting coordinates (%f, %f)" + "with converter '%s'" % (x, y, name)) + _logger.error(traceback.format_exc()) + else: + if isinstance(value, numbers.Real): + value = '%.7g' % value # Use this for floats and int + else: + value = str(value) # Fallback for other types + label.setText(value) + + +# LimitsToolBar ############################################################## + +class LimitsToolBar(qt.QToolBar): + """QToolBar displaying and controlling the limits of a :class:`PlotWidget`. + + To run the following sample code, a QApplication must be initialized. + First, create a PlotWindow: + + >>> from silx.gui.plot import PlotWindow + >>> plot = PlotWindow() # Create a PlotWindow to add the toolbar to + + Then, create the LimitsToolBar and add it to the PlotWindow. + + >>> from silx.gui import qt + >>> from silx.gui.plot.PlotTools import LimitsToolBar + + >>> toolbar = LimitsToolBar(plot=plot) # Create the toolbar + >>> plot.addToolBar(qt.Qt.BottomToolBarArea, toolbar) # Add it to the plot + >>> plot.show() # To display the PlotWindow with the limits toolbar + + :param parent: See :class:`QToolBar`. + :param plot: :class:`PlotWidget` instance on which to operate. + :param str title: See :class:`QToolBar`. + """ + + class _FloatEdit(qt.QLineEdit): + """Field to edit a float value.""" + def __init__(self, value=None, *args, **kwargs): + qt.QLineEdit.__init__(self, *args, **kwargs) + self.setValidator(qt.QDoubleValidator()) + self.setFixedWidth(100) + self.setAlignment(qt.Qt.AlignLeft) + if value is not None: + self.setValue(value) + + def value(self): + return float(self.text()) + + def setValue(self, value): + self.setText('%g' % value) + + def __init__(self, parent=None, plot=None, title='Limits'): + super(LimitsToolBar, self).__init__(title, parent) + assert plot is not None + self._plot = plot + self._plot.sigPlotSignal.connect(self._plotWidgetSlot) + + self._initWidgets() + + @property + def plot(self): + """The :class:`PlotWidget` the toolbar is attached to.""" + return self._plot + + def _initWidgets(self): + """Create and init Toolbar widgets.""" + xMin, xMax = self.plot.getGraphXLimits() + yMin, yMax = self.plot.getGraphYLimits() + + self.addWidget(qt.QLabel('Limits: ')) + self.addWidget(qt.QLabel(' X: ')) + self._xMinFloatEdit = self._FloatEdit(xMin) + self._xMinFloatEdit.editingFinished[()].connect( + self._xFloatEditChanged) + self.addWidget(self._xMinFloatEdit) + + self._xMaxFloatEdit = self._FloatEdit(xMax) + self._xMaxFloatEdit.editingFinished[()].connect( + self._xFloatEditChanged) + self.addWidget(self._xMaxFloatEdit) + + self.addWidget(qt.QLabel(' Y: ')) + self._yMinFloatEdit = self._FloatEdit(yMin) + self._yMinFloatEdit.editingFinished[()].connect( + self._yFloatEditChanged) + self.addWidget(self._yMinFloatEdit) + + self._yMaxFloatEdit = self._FloatEdit(yMax) + self._yMaxFloatEdit.editingFinished[()].connect( + self._yFloatEditChanged) + self.addWidget(self._yMaxFloatEdit) + + def _plotWidgetSlot(self, event): + """Listen to :class:`PlotWidget` events.""" + if event['event'] not in ('limitsChanged',): + return + + xMin, xMax = self.plot.getGraphXLimits() + yMin, yMax = self.plot.getGraphYLimits() + + self._xMinFloatEdit.setValue(xMin) + self._xMaxFloatEdit.setValue(xMax) + self._yMinFloatEdit.setValue(yMin) + self._yMaxFloatEdit.setValue(yMax) + + def _xFloatEditChanged(self): + """Handle X limits changed from the GUI.""" + xMin, xMax = self._xMinFloatEdit.value(), self._xMaxFloatEdit.value() + if xMax < xMin: + xMin, xMax = xMax, xMin + + self.plot.setGraphXLimits(xMin, xMax) + + def _yFloatEditChanged(self): + """Handle Y limits changed from the GUI.""" + yMin, yMax = self._yMinFloatEdit.value(), self._yMaxFloatEdit.value() + if yMax < yMin: + yMin, yMax = yMax, yMin + + self.plot.setGraphYLimits(yMin, yMax) diff --git a/silx/gui/plot/PlotWidget.py b/silx/gui/plot/PlotWidget.py new file mode 100644 index 0000000..5666d56 --- /dev/null +++ b/silx/gui/plot/PlotWidget.py @@ -0,0 +1,267 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Qt widget providing Plot API for 1D and 2D data. + +This provides the plot API of :class:`silx.gui.plot.Plot.Plot` as a +Qt widget. +""" + +__authors__ = ["V.A. Sole", "T. Vincent"] +__license__ = "MIT" +__date__ = "22/02/2016" + + +import logging + +from . import Plot + +from .. import qt + + +_logger = logging.getLogger(__name__) + + +class PlotWidget(qt.QMainWindow, Plot.Plot): + """Qt Widget providing a 1D/2D plot. + + This widget is a QMainWindow. + It provides Qt signals for the Plot and add supports for panning + with arrow keys. + + :param parent: The parent of this widget or None. + :param backend: The backend to use for the plot (default: matplotlib). + See :class:`.Plot` for the list of supported backend. + :type backend: str or :class:`BackendBase.BackendBase` + """ + + sigPlotSignal = qt.Signal(object) + """Signal for all events of the plot. + + The signal information is provided as a dict. + See :class:`.Plot` for documentation of the content of the dict. + """ + + sigSetYAxisInverted = qt.Signal(bool) + """Signal emitted when Y axis orientation has changed""" + + sigSetXAxisLogarithmic = qt.Signal(bool) + """Signal emitted when X axis scale has changed""" + + sigSetYAxisLogarithmic = qt.Signal(bool) + """Signal emitted when Y axis scale has changed""" + + sigSetXAxisAutoScale = qt.Signal(bool) + """Signal emitted when X axis autoscale has changed""" + + sigSetYAxisAutoScale = qt.Signal(bool) + """Signal emitted when Y axis autoscale has changed""" + + sigSetKeepDataAspectRatio = qt.Signal(bool) + """Signal emitted when plot keep aspect ratio has changed""" + + sigSetGraphGrid = qt.Signal(str) + """Signal emitted when plot grid has changed""" + + sigSetGraphCursor = qt.Signal(bool) + """Signal emitted when plot crosshair cursor has changed""" + + sigSetPanWithArrowKeys = qt.Signal(bool) + """Signal emitted when pan with arrow keys has changed""" + + sigContentChanged = qt.Signal(str, str, str) + """Signal emitted when the content of the plot is changed. + + It provides 3 informations: + + - action: The change of the plot: 'add' or 'remove' + - kind: The kind of primitive changed: + 'curve', 'image', 'scatter', 'histogram', 'item' or 'marker' + - legend: The legend of the primitive changed. + """ + + sigActiveCurveChanged = qt.Signal(object, object) + """Signal emitted when the active curve has changed. + + It provides 2 informations: + + - previous: The legend of the previous active curve or None + - legend: The legend of the new active curve or None if no curve is active + """ + + sigActiveImageChanged = qt.Signal(object, object) + """Signal emitted when the active image has changed. + + It provides 2 informations: + + - previous: The legend of the previous active image or None + - legend: The legend of the new active image or None if no image is active + """ + + sigActiveScatterChanged = qt.Signal(object, object) + """Signal emitted when the active Scatter has changed. + + It provides following information: + + - previous: The legend of the previous active scatter or None + - legend: The legend of the new active image or None if no image is active + """ + + sigInteractiveModeChanged = qt.Signal(object) + """Signal emitted when the interactive mode has changed + + It provides the source as passed to :meth:`setInteractiveMode`. + """ + + def __init__(self, parent=None, backend=None, + legends=False, callback=None, **kw): + + if kw: + _logger.warning( + 'deprecated: __init__ extra arguments: %s', str(kw)) + if legends: + _logger.warning('deprecated: __init__ legend argument') + if callback: + _logger.warning('deprecated: __init__ callback argument') + + self._panWithArrowKeys = True + + qt.QMainWindow.__init__(self, parent) + if parent is not None: + # behave as a widget + self.setWindowFlags(qt.Qt.Widget) + else: + self.setWindowTitle('PlotWidget') + + Plot.Plot.__init__(self, parent, backend=backend) + + widget = self.getWidgetHandle() + if widget is not None: + self.setCentralWidget(widget) + else: + _logger.warning("Plot backend does not support widget") + + self.setFocusPolicy(qt.Qt.StrongFocus) + self.setFocus(qt.Qt.OtherFocusReason) + + def notify(self, event, **kwargs): + """Override :meth:`Plot.notify` to send Qt signals.""" + eventDict = kwargs.copy() + eventDict['event'] = event + self.sigPlotSignal.emit(eventDict) + + if event == 'setYAxisInverted': + self.sigSetYAxisInverted.emit(kwargs['state']) + elif event == 'setXAxisLogarithmic': + self.sigSetXAxisLogarithmic.emit(kwargs['state']) + elif event == 'setYAxisLogarithmic': + self.sigSetYAxisLogarithmic.emit(kwargs['state']) + elif event == 'setXAxisAutoScale': + self.sigSetXAxisAutoScale.emit(kwargs['state']) + elif event == 'setYAxisAutoScale': + self.sigSetYAxisAutoScale.emit(kwargs['state']) + elif event == 'setKeepDataAspectRatio': + self.sigSetKeepDataAspectRatio.emit(kwargs['state']) + elif event == 'setGraphGrid': + self.sigSetGraphGrid.emit(kwargs['which']) + elif event == 'setGraphCursor': + self.sigSetGraphCursor.emit(kwargs['state']) + elif event == 'contentChanged': + self.sigContentChanged.emit( + kwargs['action'], kwargs['kind'], kwargs['legend']) + elif event == 'activeCurveChanged': + self.sigActiveCurveChanged.emit( + kwargs['previous'], kwargs['legend']) + elif event == 'activeImageChanged': + self.sigActiveImageChanged.emit( + kwargs['previous'], kwargs['legend']) + elif event == 'activeScatterChanged': + self.sigActiveScatterChanged.emit( + kwargs['previous'], kwargs['legend']) + elif event == 'interactiveModeChanged': + self.sigInteractiveModeChanged.emit(kwargs['source']) + Plot.Plot.notify(self, event, **kwargs) + + # Panning with arrow keys + + def isPanWithArrowKeys(self): + """Returns whether or not panning the graph with arrow keys is enable. + + See :meth:`setPanWithArrowKeys`. + """ + return self._panWithArrowKeys + + def setPanWithArrowKeys(self, pan=False): + """Enable/Disable panning the graph with arrow keys. + + This grabs the keyboard. + + :param bool pan: True to enable panning, False to disable. + """ + pan = bool(pan) + panHasChanged = self._panWithArrowKeys != pan + + self._panWithArrowKeys = pan + if not self._panWithArrowKeys: + self.setFocusPolicy(qt.Qt.NoFocus) + else: + self.setFocusPolicy(qt.Qt.StrongFocus) + self.setFocus(qt.Qt.OtherFocusReason) + + if panHasChanged: + self.sigSetPanWithArrowKeys.emit(pan) + + # Dict to convert Qt arrow key code to direction str. + _ARROWS_TO_PAN_DIRECTION = { + qt.Qt.Key_Left: 'left', + qt.Qt.Key_Right: 'right', + qt.Qt.Key_Up: 'up', + qt.Qt.Key_Down: 'down' + } + + def keyPressEvent(self, event): + """Key event handler handling panning on arrow keys. + + Overrides base class implementation. + """ + key = event.key() + if self._panWithArrowKeys and key in self._ARROWS_TO_PAN_DIRECTION: + self.pan(self._ARROWS_TO_PAN_DIRECTION[key], factor=0.1) + + # Send a mouse move event to the plot widget to take into account + # that even if mouse didn't move on the screen, it moved relative + # to the plotted data. + qapp = qt.QApplication.instance() + event = qt.QMouseEvent( + qt.QEvent.MouseMove, + self.getWidgetHandle().mapFromGlobal(qt.QCursor.pos()), + qt.Qt.NoButton, + qapp.mouseButtons(), + qapp.keyboardModifiers()) + qapp.sendEvent(self.getWidgetHandle(), event) + + else: + # Only call base class implementation when key is not handled. + # See QWidget.keyPressEvent for details. + super(PlotWidget, self).keyPressEvent(event) diff --git a/silx/gui/plot/PlotWindow.py b/silx/gui/plot/PlotWindow.py new file mode 100644 index 0000000..ae25cfd --- /dev/null +++ b/silx/gui/plot/PlotWindow.py @@ -0,0 +1,766 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""A :class:`.PlotWidget` with additional toolbars. + +The :class:`PlotWindow` is a subclass of :class:`.PlotWidget`. +It provides the plot API fully defined in :class:`.Plot`. +""" + +__authors__ = ["V.A. Sole", "T. Vincent"] +__license__ = "MIT" +__date__ = "27/04/2017" + +import collections +import logging + +from silx.utils.decorators import deprecated + +from . import PlotWidget +from . import PlotActions +from . import PlotToolButtons +from .PlotTools import PositionInfo +from .Profile import ProfileToolBar +from .LegendSelector import LegendsDockWidget +from .CurvesROIWidget import CurvesROIDockWidget +from .MaskToolsWidget import MaskToolsDockWidget +try: + from ..console import IPythonDockWidget +except ImportError: + IPythonDockWidget = None + +from .. import qt + + +_logger = logging.getLogger(__name__) + + +class PlotWindow(PlotWidget): + """Qt Widget providing a 1D/2D plot area and additional tools. + + This widgets inherits from :class:`.PlotWidget` and provides its plot API. + + Initialiser parameters: + + :param parent: The parent of this widget or None. + :param backend: The backend to use for the plot (default: matplotlib). + See :class:`.Plot` for the list of supported backend. + :type backend: str or :class:`BackendBase.BackendBase` + :param bool resetzoom: Toggle visibility of reset zoom action. + :param bool autoScale: Toggle visibility of axes autoscale actions. + :param bool logScale: Toggle visibility of axes log scale actions. + :param bool grid: Toggle visibility of grid mode action. + :param bool curveStyle: Toggle visibility of curve style action. + :param bool colormap: Toggle visibility of colormap action. + :param bool aspectRatio: Toggle visibility of aspect ratio button. + :param bool yInverted: Toggle visibility of Y axis direction button. + :param bool copy: Toggle visibility of copy action. + :param bool save: Toggle visibility of save action. + :param bool print_: Toggle visibility of print action. + :param bool control: True to display an Options button with a sub-menu + to show legends, toggle crosshair and pan with arrows. + (Default: False) + :param position: True to display widget with (x, y) mouse position + (Default: False). + It also supports a list of (name, funct(x, y)->value) + to customize the displayed values. + See :class:`silx.gui.plot.PlotTools.PositionInfo`. + :param bool roi: Toggle visibilty of ROI action. + :param bool mask: Toggle visibilty of mask action. + :param bool fit: Toggle visibilty of fit action. + """ + + def __init__(self, parent=None, backend=None, + resetzoom=True, autoScale=True, logScale=True, grid=True, + curveStyle=True, colormap=True, + aspectRatio=True, yInverted=True, + copy=True, save=True, print_=True, + control=False, position=False, + roi=True, mask=True, fit=False): + super(PlotWindow, self).__init__(parent=parent, backend=backend) + if parent is None: + self.setWindowTitle('PlotWindow') + + self._dockWidgets = [] + + # lazy loaded dock widgets + self._legendsDockWidget = None + self._curvesROIDockWidget = None + self._maskToolsDockWidget = None + + # Init actions + self.group = qt.QActionGroup(self) + self.group.setExclusive(False) + + self.resetZoomAction = self.group.addAction(PlotActions.ResetZoomAction(self)) + self.resetZoomAction.setVisible(resetzoom) + self.addAction(self.resetZoomAction) + + self.zoomInAction = PlotActions.ZoomInAction(self) + self.addAction(self.zoomInAction) + + self.zoomOutAction = PlotActions.ZoomOutAction(self) + self.addAction(self.zoomOutAction) + + self.xAxisAutoScaleAction = self.group.addAction( + PlotActions.XAxisAutoScaleAction(self)) + self.xAxisAutoScaleAction.setVisible(autoScale) + self.addAction(self.xAxisAutoScaleAction) + + self.yAxisAutoScaleAction = self.group.addAction( + PlotActions.YAxisAutoScaleAction(self)) + self.yAxisAutoScaleAction.setVisible(autoScale) + self.addAction(self.yAxisAutoScaleAction) + + self.xAxisLogarithmicAction = self.group.addAction( + PlotActions.XAxisLogarithmicAction(self)) + self.xAxisLogarithmicAction.setVisible(logScale) + self.addAction(self.xAxisLogarithmicAction) + + self.yAxisLogarithmicAction = self.group.addAction( + PlotActions.YAxisLogarithmicAction(self)) + self.yAxisLogarithmicAction.setVisible(logScale) + self.addAction(self.yAxisLogarithmicAction) + + self.gridAction = self.group.addAction( + PlotActions.GridAction(self, gridMode='both')) + self.gridAction.setVisible(grid) + self.addAction(self.gridAction) + + self.curveStyleAction = self.group.addAction(PlotActions.CurveStyleAction(self)) + self.curveStyleAction.setVisible(curveStyle) + self.addAction(self.curveStyleAction) + + self.colormapAction = self.group.addAction(PlotActions.ColormapAction(self)) + self.colormapAction.setVisible(colormap) + self.addAction(self.colormapAction) + + self.keepDataAspectRatioButton = PlotToolButtons.AspectToolButton( + parent=self, plot=self) + self.keepDataAspectRatioButton.setVisible(aspectRatio) + + self.yAxisInvertedButton = PlotToolButtons.YAxisOriginToolButton( + parent=self, plot=self) + self.yAxisInvertedButton.setVisible(yInverted) + + self.group.addAction(self.getRoiAction()) + self.getRoiAction().setVisible(roi) + + self.group.addAction(self.getMaskAction()) + self.getMaskAction().setVisible(mask) + + self._intensityHistoAction = self.group.addAction( + PlotActions.PixelIntensitiesHistoAction(self)) + self._intensityHistoAction.setVisible(False) + + self._medianFilter2DAction = self.group.addAction( + PlotActions.MedianFilter2DAction(self)) + self._medianFilter2DAction.setVisible(False) + + self._medianFilter1DAction = self.group.addAction( + PlotActions.MedianFilter1DAction(self)) + self._medianFilter1DAction.setVisible(False) + + self._separator = qt.QAction('separator', self) + self._separator.setSeparator(True) + self.group.addAction(self._separator) + + self.copyAction = self.group.addAction(PlotActions.CopyAction(self)) + self.copyAction.setVisible(copy) + self.addAction(self.copyAction) + + self.saveAction = self.group.addAction(PlotActions.SaveAction(self)) + self.saveAction.setVisible(save) + self.addAction(self.saveAction) + + self.printAction = self.group.addAction(PlotActions.PrintAction(self)) + self.printAction.setVisible(print_) + self.addAction(self.printAction) + + self.fitAction = self.group.addAction(PlotActions.FitAction(self)) + self.fitAction.setVisible(fit) + self.addAction(self.fitAction) + + # lazy loaded actions needed by the controlButton menu + self._consoleAction = None + self._panWithArrowKeysAction = None + self._crosshairAction = None + + if control or position: + hbox = qt.QHBoxLayout() + hbox.setContentsMargins(0, 0, 0, 0) + + if control: + self.controlButton = qt.QToolButton() + self.controlButton.setText("Options") + self.controlButton.setToolButtonStyle(qt.Qt.ToolButtonTextBesideIcon) + self.controlButton.setAutoRaise(True) + self.controlButton.setPopupMode(qt.QToolButton.InstantPopup) + menu = qt.QMenu(self) + menu.aboutToShow.connect(self._customControlButtonMenu) + self.controlButton.setMenu(menu) + + hbox.addWidget(self.controlButton) + + if position: # Add PositionInfo widget to the bottom of the plot + if isinstance(position, collections.Iterable): + # Use position as a set of converters + converters = position + else: + converters = None + self.positionWidget = PositionInfo( + plot=self, converters=converters) + self.positionWidget.autoSnapToActiveCurve = True + + hbox.addWidget(self.positionWidget) + + hbox.addStretch(1) + bottomBar = qt.QWidget() + bottomBar.setLayout(hbox) + + layout = qt.QVBoxLayout() + layout.setSpacing(0) + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self.getWidgetHandle()) + layout.addWidget(bottomBar) + layout.setStretch(0, 1) + + centralWidget = qt.QWidget() + centralWidget.setLayout(layout) + self.setCentralWidget(centralWidget) + + # Creating the toolbar also create actions for toolbuttons + self._toolbar = self._createToolBar(title='Plot', parent=None) + self.addToolBar(self._toolbar) + + def getSelectionMask(self): + """Return the current mask handled by :attr:`maskToolsDockWidget`. + + :return: The array of the mask with dimension of the 'active' image. + If there is no active image, an empty array is returned. + :rtype: 2D numpy.ndarray of uint8 + """ + return self.getMaskToolsDockWidget().getSelectionMask() + + def setSelectionMask(self, mask): + """Set the mask handled by :attr:`maskToolsDockWidget`. + + If the provided mask has not the same dimension as the 'active' + image, it will by cropped or padded. + + :param mask: The array to use for the mask. + :type mask: numpy.ndarray of uint8 of dimension 2, C-contiguous. + Array of other types are converted. + :return: True if success, False if failed + """ + return bool(self.getMaskToolsDockWidget().setSelectionMask(mask)) + + def _toggleConsoleVisibility(self, is_checked=False): + """Create IPythonDockWidget if needed, + show it or hide it.""" + # create widget if needed (first call) + if not hasattr(self, '_consoleDockWidget'): + available_vars = {"plt": self} + banner = "The variable 'plt' is available. Use the 'whos' " + banner += "and 'help(plt)' commands for more information.\n\n" + self._consoleDockWidget = IPythonDockWidget( + available_vars=available_vars, + custom_banner=banner, + parent=self) + self.addTabbedDockWidget(self._consoleDockWidget) + self._consoleDockWidget.visibilityChanged.connect( + self.getConsoleAction().setChecked) + + self._consoleDockWidget.setVisible(is_checked) + + def _createToolBar(self, title, parent): + """Create a QToolBar from the QAction of the PlotWindow. + + :param str title: The title of the QMenu + :param qt.QWidget parent: See :class:`QToolBar` + """ + toolbar = qt.QToolBar(title, parent) + + # Order widgets with actions + objects = self.group.actions() + + # Add push buttons to list + index = objects.index(self.colormapAction) + objects.insert(index + 1, self.keepDataAspectRatioButton) + objects.insert(index + 2, self.yAxisInvertedButton) + + for obj in objects: + if isinstance(obj, qt.QAction): + toolbar.addAction(obj) + else: + # Add action for toolbutton in order to allow changing + # visibility (see doc QToolBar.addWidget doc) + if obj is self.keepDataAspectRatioButton: + self.keepDataAspectRatioAction = toolbar.addWidget(obj) + elif obj is self.yAxisInvertedButton: + self.yAxisInvertedAction = toolbar.addWidget(obj) + else: + raise RuntimeError() + return toolbar + + def toolBar(self): + """Return a QToolBar from the QAction of the PlotWindow. + """ + return self._toolbar + + def menu(self, title='Plot', parent=None): + """Return a QMenu from the QAction of the PlotWindow. + + :param str title: The title of the QMenu + :param parent: See :class:`QMenu` + """ + menu = qt.QMenu(title, parent) + for action in self.group.actions(): + menu.addAction(action) + return menu + + def _customControlButtonMenu(self): + """Display Options button sub-menu.""" + controlMenu = self.controlButton.menu() + controlMenu.clear() + controlMenu.addAction(self.getLegendsDockWidget().toggleViewAction()) + controlMenu.addAction(self.getRoiAction()) + controlMenu.addAction(self.getMaskAction()) + controlMenu.addAction(self.getConsoleAction()) + + controlMenu.addSeparator() + controlMenu.addAction(self.getCrosshairAction()) + controlMenu.addAction(self.getPanWithArrowKeysAction()) + + def addTabbedDockWidget(self, dock_widget): + """Add a dock widget as a new tab if there are already dock widgets + in the plot. When the first tab is added, the area is chosen + depending on the plot geometry: + it the window is much wider than it is high, the right dock area + is used, else the bottom dock area is used. + + :param dock_widget: Instance of :class:`QDockWidget` to be added. + """ + if dock_widget not in self._dockWidgets: + self._dockWidgets.append(dock_widget) + if len(self._dockWidgets) == 1: + # The first created dock widget must be added to a Widget area + width = self.centralWidget().width() + height = self.centralWidget().height() + if width > (2.0 * height) and width > 1000: + area = qt.Qt.RightDockWidgetArea + else: + area = qt.Qt.BottomDockWidgetArea + self.addDockWidget(area, dock_widget) + else: + # Other dock widgets are added as tabs to the same widget area + self.tabifyDockWidget(self._dockWidgets[0], + dock_widget) + + # getters for dock widgets + @property + @deprecated(replacement="getLegendsDockWidget()", since_version="0.4.0") + def legendsDockWidget(self): + return self.getLegendsDockWidget() + + def getLegendsDockWidget(self): + """DockWidget with Legend panel""" + if self._legendsDockWidget is None: + self._legendsDockWidget = LegendsDockWidget(plot=self) + self._legendsDockWidget.hide() + self.addTabbedDockWidget(self._legendsDockWidget) + return self._legendsDockWidget + + @property + @deprecated(replacement="getCurvesRoiDockWidget()", since_version="0.4.0") + def curvesROIDockWidget(self): + return self.getCurvesRoiDockWidget() + + def getCurvesRoiDockWidget(self): + """DockWidget with curves' ROI panel (lazy-loaded). + + The widget returned is a :class:`CurvesROIDockWidget`. + Its central widget is a :class:`CurvesROIWidget` + accessible as :attr:`CurvesROIDockWidget.roiWidget`. + + :class:`silx.gui.plot.CurvesROIWidget.CurvesROIWidget` offers a getter + and a setter for the ROI data: + + - :meth:`CurvesROIWidget.getRois` + - :meth:`CurvesROIWidget.setRois` + """ + if self._curvesROIDockWidget is None: + self._curvesROIDockWidget = CurvesROIDockWidget( + plot=self, name='Regions Of Interest') + self._curvesROIDockWidget.hide() + self.addTabbedDockWidget(self._curvesROIDockWidget) + return self._curvesROIDockWidget + + @property + @deprecated(replacement="getMaskToolsDockWidget()", since_version="0.4.0") + def maskToolsDockWidget(self): + return self.getMaskToolsDockWidget() + + def getMaskToolsDockWidget(self): + """DockWidget with image mask panel (lazy-loaded).""" + if self._maskToolsDockWidget is None: + self._maskToolsDockWidget = MaskToolsDockWidget( + plot=self, name='Mask') + self._maskToolsDockWidget.hide() + self.addTabbedDockWidget(self._maskToolsDockWidget) + return self._maskToolsDockWidget + + # getters for actions + @property + @deprecated(replacement="getConsoleAction()", since_version="0.4.0") + def consoleAction(self): + return self.getConsoleAction() + + def getConsoleAction(self): + """QAction handling the IPython console activation. + + By default, it is connected to a method that initializes the + console widget the first time the user clicks the "Console" menu + button. The following clicks, after initialization is done, + will toggle the visibility of the console widget. + + :rtype: QAction + """ + if self._consoleAction is None: + self._consoleAction = qt.QAction('Console', self) + self._consoleAction.setCheckable(True) + if IPythonDockWidget is not None: + self._consoleAction.toggled.connect(self._toggleConsoleVisibility) + else: + self._consoleAction.setEnabled(False) + return self._consoleAction + + @property + @deprecated(replacement="getCrosshairAction()", since_version="0.4.0") + def crosshairAction(self): + return self.getCrosshairAction() + + def getCrosshairAction(self): + """Action toggling crosshair cursor mode. + + :rtype: PlotActions.PlotAction + """ + if self._crosshairAction is None: + self._crosshairAction = PlotActions.CrosshairAction(self, color='red') + return self._crosshairAction + + @property + @deprecated(replacement="getMaskAction()", since_version="0.4.0") + def maskAction(self): + return self.getMaskAction() + + def getMaskAction(self): + """QAction toggling image mask dock widget + + :rtype: QAction + """ + return self.getMaskToolsDockWidget().toggleViewAction() + + @property + @deprecated(replacement="getPanWithArrowKeysAction()", + since_version="0.4.0") + def panWithArrowKeysAction(self): + return self.getPanWithArrowKeysAction() + + def getPanWithArrowKeysAction(self): + """Action toggling pan with arrow keys. + + :rtype: PlotActions.PlotAction + """ + if self._panWithArrowKeysAction is None: + self._panWithArrowKeysAction = PlotActions.PanWithArrowKeysAction(self) + return self._panWithArrowKeysAction + + @property + @deprecated(replacement="getRoiAction()", since_version="0.4.0") + def roiAction(self): + return self.getRoiAction() + + def getRoiAction(self): + """QAction toggling curve ROI dock widget + + :rtype: QAction + """ + return self.getCurvesRoiDockWidget().toggleViewAction() + + def getResetZoomAction(self): + """Action resetting the zoom + + :rtype: PlotActions.PlotAction + """ + return self.resetZoomAction + + def getZoomInAction(self): + """Action to zoom in + + :rtype: PlotActions.PlotAction + """ + return self.zoomInAction + + def getZoomOutAction(self): + """Action to zoom out + + :rtype: PlotActions.PlotAction + """ + return self.zoomOutAction + + def getXAxisAutoScaleAction(self): + """Action to toggle the X axis autoscale on zoom reset + + :rtype: PlotActions.PlotAction + """ + return self.xAxisAutoScaleAction + + def getYAxisAutoScaleAction(self): + """Action to toggle the Y axis autoscale on zoom reset + + :rtype: PlotActions.PlotAction + """ + return self.yAxisAutoScaleAction + + def getXAxisLogarithmicAction(self): + """Action to toggle logarithmic X axis + + :rtype: PlotActions.PlotAction + """ + return self.xAxisLogarithmicAction + + def getYAxisLogarithmicAction(self): + """Action to toggle logarithmic Y axis + + :rtype: PlotActions.PlotAction + """ + return self.yAxisLogarithmicAction + + def getGridAction(self): + """Action to toggle the grid visibility in the plot + + :rtype: PlotActions.PlotAction + """ + return self.gridAction + + def getCurveStyleAction(self): + """Action to change curve line and markers styles + + :rtype: PlotActions.PlotAction + """ + return self.curveStyleAction + + def getColormapAction(self): + """Action open a colormap dialog to change active image + and default colormap. + + :rtype: PlotActions.PlotAction + """ + return self.colormapAction + + def getKeepDataAspectRatioButton(self): + """Button to toggle aspect ratio preservation + + :rtype: PlotToolButtons.AspectToolButton + """ + return self.keepDataAspectRatioButton + + def getKeepDataAspectRatioAction(self): + """Action associated to keepDataAspectRatioButton. + Use this to change the visibility of keepDataAspectRatioButton in the + toolbar (See :meth:`QToolBar.addWidget` documentation). + + :rtype: PlotActions.PlotAction + """ + return self.keepDataAspectRatioButton + + def getYAxisInvertedButton(self): + """Button to switch the Y axis orientation + + :rtype: PlotToolButtons.YAxisOriginToolButton + """ + return self.yAxisInvertedButton + + def getYAxisInvertedAction(self): + """Action associated to yAxisInvertedButton. + Use this to change the visibility yAxisInvertedButton in the toolbar. + (See :meth:`QToolBar.addWidget` documentation). + + :rtype: PlotActions.PlotAction + """ + return self.yAxisInvertedAction + + def getIntensityHistogramAction(self): + """Action toggling the histogram intensity Plot widget + + :rtype: PlotActions.PlotAction + """ + return self._intensityHistoAction + + def getCopyAction(self): + """Action to copy plot snapshot to clipboard + + :rtype: PlotActions.PlotAction + """ + return self.copyAction + + def getSaveAction(self): + """Action to save plot + + :rtype: PlotActions.PlotAction + """ + return self.saveAction + + def getPrintAction(self): + """Action to print plot + + :rtype: PlotActions.PlotAction + """ + return self.printAction + + def getFitAction(self): + """Action to fit selected curve + + :rtype: PlotActions.PlotAction + """ + return self.fitAction + + def getMedianFilter1DAction(self): + """Action toggling the 1D median filter + + :rtype: PlotActions.PlotAction + """ + return self._medianFilter1DAction + + def getMedianFilter2DAction(self): + """Action toggling the 2D median filter + + :rtype: PlotActions.PlotAction + """ + return self._medianFilter2DAction + + +class Plot1D(PlotWindow): + """PlotWindow with tools specific for curves. + + This widgets provides the plot API of :class:`.PlotWidget`. + + :param parent: The parent of this widget + :param backend: The backend to use for the plot (default: matplotlib). + See :class:`.Plot` for the list of supported backend. + :type backend: str or :class:`BackendBase.BackendBase` + """ + + def __init__(self, parent=None, backend=None): + super(Plot1D, self).__init__(parent=parent, backend=backend, + resetzoom=True, autoScale=True, + logScale=True, grid=True, + curveStyle=True, colormap=False, + aspectRatio=False, yInverted=False, + copy=True, save=True, print_=True, + control=True, position=True, + roi=True, mask=False, fit=True) + if parent is None: + self.setWindowTitle('Plot1D') + self.setGraphXLabel('X') + self.setGraphYLabel('Y') + + +class Plot2D(PlotWindow): + """PlotWindow with a toolbar specific for images. + + This widgets provides the plot API of :~:`.PlotWidget`. + + :param parent: The parent of this widget + :param backend: The backend to use for the plot (default: matplotlib). + See :class:`.Plot` for the list of supported backend. + :type backend: str or :class:`BackendBase.BackendBase` + """ + + def __init__(self, parent=None, backend=None): + # List of information to display at the bottom of the plot + posInfo = [ + ('X', lambda x, y: x), + ('Y', lambda x, y: y), + ('Data', self._getImageValue)] + + super(Plot2D, self).__init__(parent=parent, backend=backend, + resetzoom=True, autoScale=False, + logScale=False, grid=False, + curveStyle=False, colormap=True, + aspectRatio=True, yInverted=True, + copy=True, save=True, print_=True, + control=False, position=posInfo, + roi=False, mask=True) + if parent is None: + self.setWindowTitle('Plot2D') + self.setGraphXLabel('Columns') + self.setGraphYLabel('Rows') + + self.profile = ProfileToolBar(plot=self) + + self.addToolBar(self.profile) + + def _getImageValue(self, x, y): + """Get value of top most image at position (x, y) + + :param float x: X position in plot coordinates + :param float y: Y position in plot coordinates + :return: The value at that point or '-' + """ + value = '-' + valueZ = - float('inf') + + for image in self.getAllImages(): + data = image.getData(copy=False) + if image.getZValue() >= valueZ: # This image is over the previous one + ox, oy = image.getOrigin() + sx, sy = image.getScale() + row, col = (y - oy) / sy, (x - ox) / sx + if row >= 0 and col >= 0: + # Test positive before cast otherwise issue with int(-0.5) = 0 + row, col = int(row), int(col) + if (row < data.shape[0] and col < data.shape[1]): + value = data[row, col] + valueZ = image.getZValue() + return value + + def getProfileToolbar(self): + """Profile tools attached to this plot + + See :class:`silx.gui.plot.Profile.ProfileToolBar` + """ + return self.profile + + @deprecated(replacement="getProfilePlot", since_version="0.5.0") + def getProfileWindow(self): + return self.getProfilePlot() + + def getProfilePlot(self): + """Return plot window used to display profile curve. + + :return: :class:`Plot1D` + """ + return self.profile.getProfilePlot() diff --git a/silx/gui/plot/Profile.py b/silx/gui/plot/Profile.py new file mode 100644 index 0000000..a11b3f0 --- /dev/null +++ b/silx/gui/plot/Profile.py @@ -0,0 +1,741 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Utility functions, toolbars and actions to create profile on images +and stacks of images""" + + +__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel", "H. Payno"] +__license__ = "MIT" +__date__ = "24/04/2017" + + +import numpy + +from silx.image.bilinear import BilinearImage + +from .. import icons +from .. import qt +from . import items +from .Colors import cursorColorForColormap +from .PlotActions import PlotAction +from .PlotToolButtons import ProfileToolButton +from .ProfileMainWindow import ProfileMainWindow + +from silx.utils.decorators import deprecated + + +def _alignedFullProfile(data, origin, scale, position, roiWidth, axis): + """Get a profile along one axis on a stack of images + + :param numpy.ndarray data: 3D volume (stack of 2D images) + The first dimension is the image index. + :param origin: Origin of image in plot (ox, oy) + :param scale: Scale of image in plot (sx, sy) + :param float position: Position of profile line in plot coords + on the axis orthogonal to the profile direction. + :param int roiWidth: Width of the profile in image pixels. + :param int axis: 0 for horizontal profile, 1 for vertical. + :return: profile image + effective ROI area corners in plot coords + """ + assert axis in (0, 1) + assert len(data.shape) == 3 + + # Convert from plot to image coords + imgPos = int((position - origin[1 - axis]) / scale[1 - axis]) + + if axis == 1: # Vertical profile + # Transpose image to always do a horizontal profile + data = numpy.transpose(data, (0, 2, 1)) + + nimages, height, width = data.shape + + roiWidth = min(height, roiWidth) # Clip roi width to image size + + # Get [start, end[ coords of the roi in the data + start = int(int(imgPos) + 0.5 - roiWidth / 2.) + start = min(max(0, start), height - roiWidth) + end = start + roiWidth + + if start < height and end > 0: + profile = data[:, max(0, start):min(end, height), :].mean( + axis=1, dtype=numpy.float32) + else: + profile = numpy.zeros((nimages, width), dtype=numpy.float32) + + # Compute effective ROI in plot coords + profileBounds = numpy.array( + (0, width, width, 0), + dtype=numpy.float32) * scale[axis] + origin[axis] + roiBounds = numpy.array( + (start, start, end, end), + dtype=numpy.float32) * scale[1 - axis] + origin[1 - axis] + + if axis == 0: # Horizontal profile + area = profileBounds, roiBounds + else: # vertical profile + area = roiBounds, profileBounds + + return profile, area + + +def _alignedPartialProfile(data, rowRange, colRange, axis): + """Mean of a rectangular region (ROI) of a stack of images + along a given axis. + + Returned values and all parameters are in image coordinates. + + :param numpy.ndarray data: 3D volume (stack of 2D images) + The first dimension is the image index. + :param rowRange: [min, max[ of ROI rows (upper bound excluded). + :type rowRange: 2-tuple of int (min, max) with min < max + :param colRange: [min, max[ of ROI columns (upper bound excluded). + :type colRange: 2-tuple of int (min, max) with min < max + :param int axis: The axis along which to take the profile of the ROI. + 0: Sum rows along columns. + 1: Sum columns along rows. + :return: Profile image along the ROI as the mean of the intersection + of the ROI and the image. + """ + assert axis in (0, 1) + assert len(data.shape) == 3 + assert rowRange[0] < rowRange[1] + assert colRange[0] < colRange[1] + + nimages, height, width = data.shape + + # Range aligned with the integration direction + profileRange = colRange if axis == 0 else rowRange + + profileLength = abs(profileRange[1] - profileRange[0]) + + # Subset of the image to use as intersection of ROI and image + rowStart = min(max(0, rowRange[0]), height) + rowEnd = min(max(0, rowRange[1]), height) + colStart = min(max(0, colRange[0]), width) + colEnd = min(max(0, colRange[1]), width) + + imgProfile = numpy.mean(data[:, rowStart:rowEnd, colStart:colEnd], + axis=axis + 1, dtype=numpy.float32) + + # Profile including out of bound area + profile = numpy.zeros((nimages, profileLength), dtype=numpy.float32) + + # Place imgProfile in full profile + offset = - min(0, profileRange[0]) + profile[:, offset:offset + imgProfile.shape[1]] = imgProfile + + return profile + + +def createProfile(roiInfo, currentData, origin, scale, lineWidth): + """Create the profile line for the the given image. + + :param roiInfo: information about the ROI: start point, end point and + type ("X", "Y", "D") + :param numpy.ndarray currentData: the 2D image or the 3D stack of images + on which we compute the profile. + :param origin: (ox, oy) the offset from origin + :type origin: 2-tuple of float + :param scale: (sx, sy) the scale to use + :type scale: 2-tuple of float + :param int lineWidth: width of the profile line + :return: `profile, area, profileName, xLabel`, where: + - profile is a 2D array of the profiles of the stack of images. + For a single image, the profile is a curve, so this parameter + has a shape *(1, len(curve))* + - area is a tuple of two 1D arrays with 4 values each. They represent + the effective ROI area corners in plot coords. + - profileName is a string describing the ROI, meant to be used as + title of the profile plot + - xLabel is a string describing the meaning of the X axis on the + profile plot ("rows", "columns", "distance") + + :rtype: tuple(ndarray, (ndarray, ndarray), str, str) + """ + if currentData is None or roiInfo is None or lineWidth is None: + raise ValueError("createProfile called with invalide arguments") + + # force 3D data (stack of images) + if len(currentData.shape) == 2: + currentData3D = currentData.reshape((1,) + currentData.shape) + elif len(currentData.shape) == 3: + currentData3D = currentData + + roiWidth = max(1, lineWidth) + roiStart, roiEnd, lineProjectionMode = roiInfo + + if lineProjectionMode == 'X': # Horizontal profile on the whole image + profile, area = _alignedFullProfile(currentData3D, + origin, scale, + roiStart[1], roiWidth, + axis=0) + + yMin, yMax = min(area[1]), max(area[1]) - 1 + if roiWidth <= 1: + profileName = 'Y = %g' % yMin + else: + profileName = 'Y = [%g, %g]' % (yMin, yMax) + xLabel = 'Columns' + + elif lineProjectionMode == 'Y': # Vertical profile on the whole image + profile, area = _alignedFullProfile(currentData3D, + origin, scale, + roiStart[0], roiWidth, + axis=1) + + xMin, xMax = min(area[0]), max(area[0]) - 1 + if roiWidth <= 1: + profileName = 'X = %g' % xMin + else: + profileName = 'X = [%g, %g]' % (xMin, xMax) + xLabel = 'Rows' + + else: # Free line profile + + # Convert start and end points in image coords as (row, col) + startPt = ((roiStart[1] - origin[1]) / scale[1], + (roiStart[0] - origin[0]) / scale[0]) + endPt = ((roiEnd[1] - origin[1]) / scale[1], + (roiEnd[0] - origin[0]) / scale[0]) + + if (int(startPt[0]) == int(endPt[0]) or + int(startPt[1]) == int(endPt[1])): + # Profile is aligned with one of the axes + + # Convert to int + startPt = int(startPt[0]), int(startPt[1]) + endPt = int(endPt[0]), int(endPt[1]) + + # Ensure startPt <= endPt + if startPt[0] > endPt[0] or startPt[1] > endPt[1]: + startPt, endPt = endPt, startPt + + if startPt[0] == endPt[0]: # Row aligned + rowRange = (int(startPt[0] + 0.5 - 0.5 * roiWidth), + int(startPt[0] + 0.5 + 0.5 * roiWidth)) + colRange = startPt[1], endPt[1] + 1 + profile = _alignedPartialProfile(currentData3D, + rowRange, colRange, + axis=0) + + else: # Column aligned + rowRange = startPt[0], endPt[0] + 1 + colRange = (int(startPt[1] + 0.5 - 0.5 * roiWidth), + int(startPt[1] + 0.5 + 0.5 * roiWidth)) + profile = _alignedPartialProfile(currentData3D, + rowRange, colRange, + axis=1) + + # Convert ranges to plot coords to draw ROI area + area = ( + numpy.array( + (colRange[0], colRange[1], colRange[1], colRange[0]), + dtype=numpy.float32) * scale[0] + origin[0], + numpy.array( + (rowRange[0], rowRange[0], rowRange[1], rowRange[1]), + dtype=numpy.float32) * scale[1] + origin[1]) + + else: # General case: use bilinear interpolation + + # Ensure startPt <= endPt + if (startPt[1] > endPt[1] or ( + startPt[1] == endPt[1] and startPt[0] > endPt[0])): + startPt, endPt = endPt, startPt + + profile = [] + for slice_idx in range(currentData3D.shape[0]): + bilinear = BilinearImage(currentData3D[slice_idx, :, :]) + + profile.append(bilinear.profile_line( + (startPt[0] - 0.5, startPt[1] - 0.5), + (endPt[0] - 0.5, endPt[1] - 0.5), + roiWidth)) + profile = numpy.array(profile) + + # Extend ROI with half a pixel on each end, and + # Convert back to plot coords (x, y) + length = numpy.sqrt((endPt[0] - startPt[0]) ** 2 + + (endPt[1] - startPt[1]) ** 2) + dRow = (endPt[0] - startPt[0]) / length + dCol = (endPt[1] - startPt[1]) / length + + # Extend ROI with half a pixel on each end + startPt = startPt[0] - 0.5 * dRow, startPt[1] - 0.5 * dCol + endPt = endPt[0] + 0.5 * dRow, endPt[1] + 0.5 * dCol + + # Rotate deltas by 90 degrees to apply line width + dRow, dCol = dCol, -dRow + + area = ( + numpy.array((startPt[1] - 0.5 * roiWidth * dCol, + startPt[1] + 0.5 * roiWidth * dCol, + endPt[1] + 0.5 * roiWidth * dCol, + endPt[1] - 0.5 * roiWidth * dCol), + dtype=numpy.float32) * scale[0] + origin[0], + numpy.array((startPt[0] - 0.5 * roiWidth * dRow, + startPt[0] + 0.5 * roiWidth * dRow, + endPt[0] + 0.5 * roiWidth * dRow, + endPt[0] - 0.5 * roiWidth * dRow), + dtype=numpy.float32) * scale[1] + origin[1]) + + y0, x0 = startPt + y1, x1 = endPt + if x1 == x0 or y1 == y0: + profileName = 'From (%g, %g) to (%g, %g)' % (x0, y0, x1, y1) + else: + m = (y1 - y0) / (x1 - x0) + b = y0 - m * x0 + profileName = 'y = %g * x %+g ; width=%d' % (m, b, roiWidth) + xLabel = 'Distance' + + return profile, area, profileName, xLabel + + +# ProfileToolBar ############################################################## + +class ProfileToolBar(qt.QToolBar): + """QToolBar providing profile tools operating on a :class:`PlotWindow`. + + Attributes: + + - plot: Associated :class:`PlotWindow` on which the profile line is drawn. + - actionGroup: :class:`QActionGroup` of available actions. + + To run the following sample code, a QApplication must be initialized. + First, create a PlotWindow and add a :class:`ProfileToolBar`. + + >>> from silx.gui.plot import PlotWindow + >>> from silx.gui.plot.Profile import ProfileToolBar + + >>> plot = PlotWindow() # Create a PlotWindow + >>> toolBar = ProfileToolBar(plot=plot) # Create a profile toolbar + >>> plot.addToolBar(toolBar) # Add it to plot + >>> plot.show() # To display the PlotWindow with the profile toolbar + + :param plot: :class:`PlotWindow` instance on which to operate. + :param profileWindow: Plot widget instance where to + display the profile curve or None to create one. + :param str title: See :class:`QToolBar`. + :param parent: See :class:`QToolBar`. + """ + # TODO Make it a QActionGroup instead of a QToolBar + + _POLYGON_LEGEND = '__ProfileToolBar_ROI_Polygon' + + def __init__(self, parent=None, plot=None, profileWindow=None, + title='Profile Selection'): + super(ProfileToolBar, self).__init__(title, parent) + assert plot is not None + self.plot = plot + + self._overlayColor = None + self._defaultOverlayColor = 'red' # update when active image change + + self._roiInfo = None # Store start and end points and type of ROI + + self._profileWindow = profileWindow + """User provided plot widget in which the profile curve is plotted. + None if no custom profile plot was provided.""" + + self._profileMainWindow = None + """Main window providing 2 profile plot widgets for 1D or 2D profiles. + The window provides two public methods + - :meth:`setProfileDimensions` + - :meth:`getPlot`: return handle on the actual plot widget + currently being used + None if the user specified a custom profile plot window. + """ + + if self._profileWindow is None: + self._profileMainWindow = ProfileMainWindow(self) + + # Actions + self.browseAction = qt.QAction( + icons.getQIcon('normal'), + 'Browsing Mode', None) + self.browseAction.setToolTip( + 'Enables zooming interaction mode') + self.browseAction.setCheckable(True) + self.browseAction.triggered[bool].connect(self._browseActionTriggered) + + self.hLineAction = qt.QAction( + icons.getQIcon('shape-horizontal'), + 'Horizontal Profile Mode', None) + self.hLineAction.setToolTip( + 'Enables horizontal profile selection mode') + self.hLineAction.setCheckable(True) + self.hLineAction.toggled[bool].connect(self._hLineActionToggled) + + self.vLineAction = qt.QAction( + icons.getQIcon('shape-vertical'), + 'Vertical Profile Mode', None) + self.vLineAction.setToolTip( + 'Enables vertical profile selection mode') + self.vLineAction.setCheckable(True) + self.vLineAction.toggled[bool].connect(self._vLineActionToggled) + + self.lineAction = qt.QAction( + icons.getQIcon('shape-diagonal'), + 'Free Line Profile Mode', None) + self.lineAction.setToolTip( + 'Enables line profile selection mode') + self.lineAction.setCheckable(True) + self.lineAction.toggled[bool].connect(self._lineActionToggled) + + self.clearAction = qt.QAction( + icons.getQIcon('profile-clear'), + 'Clear Profile', None) + self.clearAction.setToolTip( + 'Clear the profile Region of interest') + self.clearAction.setCheckable(False) + self.clearAction.triggered.connect(self.clearProfile) + + # ActionGroup + self.actionGroup = qt.QActionGroup(self) + self.actionGroup.addAction(self.browseAction) + self.actionGroup.addAction(self.hLineAction) + self.actionGroup.addAction(self.vLineAction) + self.actionGroup.addAction(self.lineAction) + + self.browseAction.setChecked(True) + + # Add actions to ToolBar + self.addAction(self.browseAction) + self.addAction(self.hLineAction) + self.addAction(self.vLineAction) + self.addAction(self.lineAction) + self.addAction(self.clearAction) + + # Add width spin box to toolbar + self.addWidget(qt.QLabel('W:')) + self.lineWidthSpinBox = qt.QSpinBox(self) + self.lineWidthSpinBox.setRange(0, 1000) + self.lineWidthSpinBox.setValue(1) + self.lineWidthSpinBox.valueChanged[int].connect( + self._lineWidthSpinBoxValueChangedSlot) + self.addWidget(self.lineWidthSpinBox) + + self.plot.sigInteractiveModeChanged.connect( + self._interactiveModeChanged) + + # Enable toolbar only if there is an active image + self.setEnabled(self.plot.getActiveImage(just_legend=True) is not None) + self.plot.sigActiveImageChanged.connect( + self._activeImageChanged) + + # listen to the profile window signals to clear profile polygon on close + if self.getProfileMainWindow() is not None: + self.getProfileMainWindow().sigClose.connect(self.clearProfile) + + @property + @deprecated(replacement="getProfilePlot", since_version="0.5.0") + def profileWindow(self): + return self.getProfilePlot() + + def getProfilePlot(self): + """Return plot widget in which the profile curve or the + profile image is plotted. + """ + if self.getProfileMainWindow() is not None: + return self.getProfileMainWindow().getPlot() + + # in case the user provided a custom plot for profiles + return self._profileWindow + + def getProfileMainWindow(self): + """Return window containing the profile curve widget. + This can return *None* if a custom profile plot window was + specified in the constructor. + """ + return self._profileMainWindow + + def _activeImageChanged(self, previous, legend): + """Handle active image change: toggle enabled toolbar, update curve""" + self.setEnabled(legend is not None) + if legend is not None: + # Update default profile color + activeImage = self.plot.getActiveImage() + if isinstance(activeImage, items.ColormapMixIn): + self._defaultOverlayColor = cursorColorForColormap( + activeImage.getColormap()['name']) + else: + self._defaultOverlayColor = 'black' + + self.updateProfile() + + def _lineWidthSpinBoxValueChangedSlot(self, value): + """Listen to ROI width widget to refresh ROI and profile""" + self.updateProfile() + + def _interactiveModeChanged(self, source): + """Handle plot interactive mode changed: + + If changed from elsewhere, disable drawing tool + """ + if source is not self: + self.browseAction.setChecked(True) + + def _hLineActionToggled(self, checked): + """Handle horizontal line profile action toggle""" + if checked: + self.plot.setInteractiveMode('draw', shape='hline', + color=None, source=self) + self.plot.sigPlotSignal.connect(self._plotWindowSlot) + else: + self.plot.sigPlotSignal.disconnect(self._plotWindowSlot) + + def _vLineActionToggled(self, checked): + """Handle vertical line profile action toggle""" + if checked: + self.plot.setInteractiveMode('draw', shape='vline', + color=None, source=self) + self.plot.sigPlotSignal.connect(self._plotWindowSlot) + else: + self.plot.sigPlotSignal.disconnect(self._plotWindowSlot) + + def _lineActionToggled(self, checked): + """Handle line profile action toggle""" + if checked: + self.plot.setInteractiveMode('draw', shape='line', + color=None, source=self) + self.plot.sigPlotSignal.connect(self._plotWindowSlot) + else: + self.plot.sigPlotSignal.disconnect(self._plotWindowSlot) + + def _browseActionTriggered(self, checked): + """Handle browse action mode triggered by user.""" + if checked: + self.clearProfile() + self.plot.setInteractiveMode('zoom', source=self) + if self.getProfileMainWindow() is not None: + self.getProfileMainWindow().hide() + + def _plotWindowSlot(self, event): + """Listen to Plot to handle drawing events to refresh ROI and profile. + """ + if event['event'] not in ('drawingProgress', 'drawingFinished'): + return + + checkedAction = self.actionGroup.checkedAction() + if checkedAction == self.hLineAction: + lineProjectionMode = 'X' + elif checkedAction == self.vLineAction: + lineProjectionMode = 'Y' + elif checkedAction == self.lineAction: + lineProjectionMode = 'D' + else: + return + + roiStart, roiEnd = event['points'][0], event['points'][1] + + self._roiInfo = roiStart, roiEnd, lineProjectionMode + self.updateProfile() + + @property + def overlayColor(self): + """The color to use for the ROI. + + If set to None (the default), the overlay color is adapted to the + active image colormap and changes if the active image colormap changes. + """ + return self._overlayColor or self._defaultOverlayColor + + @overlayColor.setter + def overlayColor(self, color): + self._overlayColor = color + self.updateProfile() + + def clearProfile(self): + """Remove profile curve and profile area.""" + self._roiInfo = None + self.updateProfile() + + def updateProfile(self): + """Update the displayed profile and profile ROI. + + This uses the current active image of the plot and the current ROI. + """ + image = self.plot.getActiveImage() + if image is None: + return + + # Clean previous profile area, and previous curve + self.plot.remove(self._POLYGON_LEGEND, kind='item') + self.getProfilePlot().clear() + self.getProfilePlot().setGraphTitle('') + self.getProfilePlot().setGraphXLabel('X') + self.getProfilePlot().setGraphYLabel('Y') + + self._createProfile(currentData=image.getData(copy=False), + origin=image.getOrigin(), + scale=image.getScale(), + colormap=None, # Not used for 2D data + z=image.getZValue()) + + def _createProfile(self, currentData, origin, scale, colormap, z): + """Create the profile line for the the given image. + + :param numpy.ndarray currentData: the image or the stack of images + on which we compute the profile + :param origin: (ox, oy) the offset from origin + :type origin: 2-tuple of float + :param scale: (sx, sy) the scale to use + :type scale: 2-tuple of float + :param dict colormap: The colormap to use + :param int z: The z layer of the image + """ + if self._roiInfo is None: + return + + profile, area, profileName, xLabel = createProfile( + roiInfo=self._roiInfo, + currentData=currentData, + origin=origin, + scale=scale, + lineWidth=self.lineWidthSpinBox.value()) + + self.getProfilePlot().setGraphTitle(profileName) + + dataIs3D = len(currentData.shape) > 2 + if dataIs3D: + self.getProfilePlot().addImage(profile, + legend=profileName, + xlabel=xLabel, + ylabel="Frame index (depth)", + colormap=colormap) + else: + coords = numpy.arange(len(profile[0]), dtype=numpy.float32) + self.getProfilePlot().addCurve(coords, + profile[0], + legend=profileName, + xlabel=xLabel, + color=self.overlayColor) + + self.plot.addItem(area[0], area[1], + legend=self._POLYGON_LEGEND, + color=self.overlayColor, + shape='polygon', fill=True, + replace=False, z=z + 1) + + self._showProfileMainWindow() + + def _showProfileMainWindow(self): + """If profile window was created by this toolbar, + try to avoid overlapping with the toolbar's parent window. + """ + profileMainWindow = self.getProfileMainWindow() + if profileMainWindow is not None: + winGeom = self.window().frameGeometry() + qapp = qt.QApplication.instance() + screenGeom = qapp.desktop().availableGeometry(self) + + spaceOnLeftSide = winGeom.left() + spaceOnRightSide = screenGeom.width() - winGeom.right() + + profileWindowWidth = profileMainWindow.frameGeometry().width() + if (profileWindowWidth < spaceOnRightSide or + spaceOnRightSide > spaceOnLeftSide): + # Place profile on the right + profileMainWindow.move(winGeom.right(), winGeom.top()) + else: + # Not enough place on the right, place profile on the left + profileMainWindow.move( + max(0, winGeom.left() - profileWindowWidth), winGeom.top()) + + profileMainWindow.show() + else: + self.getProfilePlot().show() + + def hideProfileWindow(self): + """Hide profile window. + """ + # this method is currently only used by StackView when the perspective + # is changed + if self.getProfileMainWindow() is not None: + self.getProfileMainWindow().hide() + + +class Profile3DToolBar(ProfileToolBar): + def __init__(self, parent=None, plot=None, title='Profile Selection'): + """QToolBar providing profile tools for an image or a stack of images. + + :param parent: the parent QWidget + :param plot: :class:`PlotWindow` instance on which to operate. + :param str title: See :class:`QToolBar`. + :param parent: See :class:`QToolBar`. + """ + # TODO: add param profileWindow (specify the plot used for profiles) + super(Profile3DToolBar, self).__init__(parent=parent, plot=plot, + title=title) + + self.profile3dAction = ProfileToolButton( + parent=self, plot=self.plot) + self.profile3dAction.computeProfileIn2D() + self.profile3dAction.setVisible(True) + self.addWidget(self.profile3dAction) + self.profile3dAction.sigDimensionChanged.connect(self._setProfileType) + + # create the 3D toolbar + self._profileType = None + self._setProfileType(2) + + def _setProfileType(self, dimensions): + """Set the profile type: "1D" for a curve (profile on a single image) + or "2D" for an image (profile on a stack of images). + + :param int dimensions: 1 for a "1D" profile or 2 for a "2D" profile + """ + # fixme this assumes that we created _profileMainWindow + self._profileType = "1D" if dimensions == 1 else "2D" + self.getProfileMainWindow().setProfileType(self._profileType) + self.updateProfile() + + def updateProfile(self): + """Method overloaded from :class:`ProfileToolBar`, + to pass the stack of images instead of just the active image. + + In 1D profile mode, use the regular parent method. + """ + if self._profileType == "1D": + super(Profile3DToolBar, self).updateProfile() + elif self._profileType == "2D": + stackData = self.plot.getCurrentView(copy=False, + returnNumpyArray=True) + if stackData is None: + return + self.plot.remove(self._POLYGON_LEGEND, kind='item') + self.getProfilePlot().clear() + self.getProfilePlot().setGraphTitle('') + self.getProfilePlot().setGraphXLabel('X') + self.getProfilePlot().setGraphYLabel('Y') + + self._createProfile(currentData=stackData[0], + origin=stackData[1]['origin'], + scale=stackData[1]['scale'], + colormap=stackData[1]['colormap'], + z=stackData[1]['z']) + else: + raise ValueError( + "Profile type must be 1D or 2D, not %s" % self._profileType) diff --git a/silx/gui/plot/ProfileMainWindow.py b/silx/gui/plot/ProfileMainWindow.py new file mode 100644 index 0000000..835de2c --- /dev/null +++ b/silx/gui/plot/ProfileMainWindow.py @@ -0,0 +1,99 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module contains a QMainWindow class used to display profile plots. +""" +from silx.gui import qt + + +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "21/02/2017" + + +class ProfileMainWindow(qt.QMainWindow): + """QMainWindow providing 2 plot widgets specialized in + 1D and 2D plotting, with different toolbars. + Only one of the plots is visible at any given time. + """ + sigProfileDimensionsChanged = qt.Signal(int) + """This signal is emitted when :meth:`setProfileDimensions` is called. + It carries the number of dimensions for the profile data (1 or 2). + It can be used to be notified that the profile plot widget has changed. + """ + + sigClose = qt.Signal() + """Emitted by :meth:`closeEvent` (e.g. when the window is closed + through the window manager's close icon).""" + + def __init__(self, parent=None): + qt.QMainWindow.__init__(self, parent=parent) + + self.setWindowTitle('Profile window') + # plots are created on demand, in self.setProfileDimensions() + self._plot1D = None + self._plot2D = None + # by default, profile is assumed to be a 1D curve + self._profileType = None + self.setProfileType("1D") + + def setProfileType(self, profileType): + """Set which profile plot widget (1D or 2D) is to be used + + :param str profileType: Type of profile data, + "1D" for a curve or "2D" for an image + """ + # import here to avoid circular import + from .PlotWindow import Plot1D, Plot2D # noqa + self._profileType = profileType + + if self._profileType == "1D": + if self._plot2D is not None: + self._plot2D.setParent(None) # necessary to avoid widget destruction + if self._plot1D is None: + self._plot1D = Plot1D() + self.setCentralWidget(self._plot1D) + elif self._profileType == "2D": + if self._plot1D is not None: + self._plot1D.setParent(None) # necessary to avoid widget destruction + if self._plot2D is None: + self._plot2D = Plot2D() + self.setCentralWidget(self._plot2D) + else: + raise ValueError("Profile type must be '1D' or '2D'") + + self.sigProfileDimensionsChanged.emit(profileType) + + def getPlot(self): + """Return the profile plot widget which is currently in use. + This can be the 2D profile plot or the 1D profile plot. + """ + if self._profileType == "2D": + return self._plot2D + else: + return self._plot1D + + def closeEvent(self, qCloseEvent): + self.sigClose.emit() + qCloseEvent.accept() diff --git a/silx/gui/plot/ScatterMaskToolsWidget.py b/silx/gui/plot/ScatterMaskToolsWidget.py new file mode 100644 index 0000000..793719d --- /dev/null +++ b/silx/gui/plot/ScatterMaskToolsWidget.py @@ -0,0 +1,529 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Widget providing a set of tools to draw masks on a PlotWidget. + +This widget is meant to work with a modified :class:`silx.gui.plot.PlotWidget` + +- :class:`ScatterMask`: Handle scatter mask update and history +- :class:`ScatterMaskToolsWidget`: GUI for :class:`ScatterMask` +- :class:`ScatterMaskToolsDockWidget`: DockWidget to integrate in :class:`PlotWindow` +""" + +from __future__ import division + +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "07/04/2017" + + +import math +import logging +import os +import numpy +import sys + +from .. import qt +from ...image import shapes + +from ._BaseMaskToolsWidget import BaseMask, BaseMaskToolsWidget, BaseMaskToolsDockWidget +from .Colors import cursorColorForColormap, rgba + + +_logger = logging.getLogger(__name__) + + +class ScatterMask(BaseMask): + """A 1D mask for scatter data. + """ + def __init__(self, scatter=None): + """ + + :param scatter: :class:`silx.gui.plot.items.Scatter` instance + """ + BaseMask.__init__(self, scatter) + + def _getXY(self): + x = self._dataItem.getXData(copy=False) + y = self._dataItem.getYData(copy=False) + return x, y + + def getDataValues(self): + """Return scatter data values as a 1D array. + + :rtype: 1D numpy.ndarray + """ + return self._dataItem.getValueData(copy=False) + + def save(self, filename, kind): + if kind == 'npy': + try: + numpy.save(filename, self.getMask(copy=False)) + except IOError: + raise RuntimeError("Mask file can't be written") + elif kind in ["csv", "txt"]: + try: + numpy.savetxt(filename, self.getMask(copy=False)) + except IOError: + raise RuntimeError("Mask file can't be written") + + def updatePoints(self, level, indices, mask=True): + """Mask/Unmask points with given indices. + + :param int level: Mask level to update. + :param indices: Sequence or 1D array of indices of points to be + updated + :param bool mask: True to mask (default), False to unmask. + """ + if mask: + self._mask[indices] = level + else: + # unmask only where mask level is the specified value + indices_stencil = numpy.zeros_like(self._mask, dtype=numpy.bool) + indices_stencil[indices] = True + self._mask[numpy.logical_and(self._mask == level, indices_stencil)] = 0 + self._notify() + + # update shapes + def updatePolygon(self, level, vertices, mask=True): + """Mask/Unmask a polygon of the given mask level. + + :param int level: Mask level to update. + :param vertices: Nx2 array of polygon corners as (y, x) or (row, col) + :param bool mask: True to mask (default), False to unmask. + """ + polygon = shapes.Polygon(vertices) + x, y = self._getXY() + + # TODO: this could be optimized if necessary + indices_in_polygon = [idx for idx in range(len(x)) if + polygon.is_inside(y[idx], x[idx])] + + self.updatePoints(level, indices_in_polygon, mask) + + def updateRectangle(self, level, y, x, height, width, mask=True): + """Mask/Unmask data inside a rectangle + + :param int level: Mask level to update. + :param float y: Y coordinate of bottom left corner of the rectangle + :param float x: X coordinate of bottom left corner of the rectangle + :param float height: + :param float width: + :param bool mask: True to mask (default), False to unmask. + """ + vertices = [(y, x), + (y + height, x), + (y + height, x + width), + (y, x + width)] + self.updatePolygon(level, vertices, mask) + + def updateDisk(self, level, cy, cx, radius, mask=True): + """Mask/Unmask a disk of the given mask level. + + :param int level: Mask level to update. + :param float cy: Disk center (y). + :param float cx: Disk center (x). + :param float radius: Radius of the disk in mask array unit + :param bool mask: True to mask (default), False to unmask. + """ + x, y = self._getXY() + stencil = (y - cy)**2 + (x - cx)**2 < radius**2 + self.updateStencil(level, stencil, mask) + + def updateLine(self, level, y0, x0, y1, x1, width, mask=True): + """Mask/Unmask points inside a rectangle defined by a line (two + end points) and a width. + + :param int level: Mask level to update. + :param float y0: Row of the starting point. + :param float x0: Column of the starting point. + :param float row1: Row of the end point. + :param float col1: Column of the end point. + :param float width: Width of the line. + :param bool mask: True to mask (default), False to unmask. + """ + # theta is the angle between the horizontal and the line + theta = math.atan((y1 - y0) / (x1 - x0)) if x1 - x0 else 0 + w_over_2_sin_theta = width / 2. * math.sin(theta) + w_over_2_cos_theta = width / 2. * math.cos(theta) + + vertices = [(y0 - w_over_2_cos_theta, x0 + w_over_2_sin_theta), + (y0 + w_over_2_cos_theta, x0 - w_over_2_sin_theta), + (y1 + w_over_2_cos_theta, x1 - w_over_2_sin_theta), + (y1 - w_over_2_cos_theta, x1 + w_over_2_sin_theta)] + + self.updatePolygon(level, vertices, mask) + + +class ScatterMaskToolsWidget(BaseMaskToolsWidget): + """Widget with tools for masking data points on a scatter in a + :class:`PlotWidget`.""" + + def __init__(self, parent=None, plot=None): + self._z = 2 # Mask layer in plot + self._data_scatter = None + """plot Scatter item for data""" + self._mask_scatter = None + """plot Scatter item for representing the mask""" + + self._mask = ScatterMask() + + super(ScatterMaskToolsWidget, self).__init__(parent, plot) + + self._initWidgets() + + def setSelectionMask(self, mask, copy=True): + """Set the mask to a new array. + + :param numpy.ndarray mask: The array to use for the mask. + :type mask: numpy.ndarray of uint8, C-contiguous. + Array of other types are converted. + :param bool copy: True (the default) to copy the array, + False to use it as is if possible. + :return: None if failed, shape of mask as 1-tuple if successful. + The mask can be cropped or padded to fit active scatter, + the returned shape is that of the scatter data. + """ + mask = numpy.array(mask, copy=False, dtype=numpy.uint8) + + if self._data_scatter.getXData(copy=False).shape == (0,) \ + or mask.shape == self._data_scatter.getXData(copy=False).shape: + self._mask.setMask(mask, copy=copy) + self._mask.commit() + return mask.shape + else: + raise ValueError("Mask does not have the same shape as the data") + + # Handle mask refresh on the plot + + def _updatePlotMask(self): + """Update mask image in plot""" + mask = self.getSelectionMask(copy=False) + if len(mask): + self.plot.addScatter(self._data_scatter.getXData(), + self._data_scatter.getYData(), + mask, + legend=self._maskName, + colormap=self._colormap, + z=self._z) + self._mask_scatter = self.plot._getItem(kind="scatter", + legend=self._maskName) + self._mask_scatter.setSymbolSize( + self._data_scatter.getSymbolSize() * 4.0 + ) + elif self.plot._getItem(kind="scatter", + legend=self._maskName) is not None: + self.plot.remove(self._maskName, kind='scatter') + + # track widget visibility and plot active image changes + + def showEvent(self, event): + try: + self.plot.sigActiveScatterChanged.disconnect( + self._activeScatterChangedAfterCare) + except (RuntimeError, TypeError): + pass + self._activeScatterChanged(None, None) # Init mask + enable/disable widget + self.plot.sigActiveScatterChanged.connect(self._activeScatterChanged) + + def hideEvent(self, event): + self.plot.sigActiveScatterChanged.disconnect(self._activeScatterChanged) + if not self.browseAction.isChecked(): + self.browseAction.trigger() # Disable drawing tool + + if len(self.getSelectionMask(copy=False)): + self.plot.sigActiveScatterChanged.connect( + self._activeScatterChangedAfterCare) + + def _activeScatterChangedAfterCare(self, previous, next): + """Check synchro of active scatter and mask when mask widget is hidden. + + If active image has no more the same size as the mask, the mask is + removed, otherwise it is adjusted to z. + """ + # check that content changed was the active scatter + activeScatter = self.plot._getActiveItem(kind="scatter") + + if activeScatter is None or activeScatter.getLegend() == self._maskName: + # No active scatter or active scatter is the mask... + self.plot.sigActiveScatterChanged.disconnect( + self._activeScatterChangedAfterCare) + else: + colormap = activeScatter.getColormap() + self._defaultOverlayColor = rgba(cursorColorForColormap(colormap['name'])) + self._setMaskColors(self.levelSpinBox.value(), + self.transparencySlider.value() / + self.transparencySlider.maximum()) + + self._z = activeScatter.getZValue() + 1 + self._data_scatter = activeScatter + if self._data_scatter.getXData(copy=False).shape != self.getSelectionMask(copy=False).shape: + # scatter has not the same size, remove mask and stop listening + if self.plot._getItem(kind="scatter", legend=self._maskName): + self.plot.remove(self._maskName, kind='scatter') + + self.plot.sigActiveScatterChanged.disconnect( + self._activeScatterChangedAfterCare) + else: + # Refresh in case z changed + self._mask.setDataItem(self._data_scatter) + self._updatePlotMask() + + def _activeScatterChanged(self, previous, next): + """Update widget and mask according to active scatter changes""" + activeScatter = self.plot._getActiveItem(kind="scatter") + + if activeScatter is None or activeScatter.getLegend() == self._maskName: + # No active scatter or active scatter is the mask... + self.setEnabled(False) + + self._data_scatter = None + self._mask.reset() + self._mask.commit() + + else: # There is an active scatter + self.setEnabled(True) + + colormap = activeScatter.getColormap() + self._defaultOverlayColor = rgba(cursorColorForColormap(colormap['name'])) + self._setMaskColors(self.levelSpinBox.value(), + self.transparencySlider.value() / + self.transparencySlider.maximum()) + + self._z = activeScatter.getZValue() + 1 + self._data_scatter = activeScatter + self._mask.setDataItem(self._data_scatter) + if self._data_scatter.getXData(copy=False).shape != self.getSelectionMask(copy=False).shape: + self._mask.reset(self._data_scatter.getXData(copy=False).shape) + self._mask.commit() + else: + # Refresh in case z changed + self._updatePlotMask() + + self._updateInteractiveMode() + + # Handle whole mask operations + + def load(self, filename): + """Load a mask from an image file. + + :param str filename: File name from which to load the mask + :raise Exception: An exception in case of failure + :raise RuntimeWarning: In case the mask was applied but with some + import changes to notice + """ + _, extension = os.path.splitext(filename) + extension = extension.lower()[1:] + if extension == "npy": + try: + mask = numpy.load(filename) + except IOError: + _logger.error("Can't load filename '%s'", filename) + _logger.debug("Backtrace", exc_info=True) + raise RuntimeError('File "%s" is not a numpy file.', + filename) + elif extension in ["txt", "csv"]: + try: + mask = numpy.loadtxt(filename) + except IOError: + _logger.error("Can't load filename '%s'", filename) + _logger.debug("Backtrace", exc_info=True) + raise RuntimeError('File "%s" is not a numpy txt file.', + filename) + else: + msg = "Extension '%s' is not supported." + raise RuntimeError(msg % extension) + + self.setSelectionMask(mask, copy=False) + + def _loadMask(self): + """Open load mask dialog""" + dialog = qt.QFileDialog(self) + dialog.setWindowTitle("Load Mask") + dialog.setModal(1) + filters = [ + 'NumPy binary file (*.npy)', + 'CSV text file (*.csv)', + ] + dialog.setNameFilters(filters) + dialog.setFileMode(qt.QFileDialog.ExistingFile) + dialog.setDirectory(self.maskFileDir) + if not dialog.exec_(): + dialog.close() + return + + filename = dialog.selectedFiles()[0] + dialog.close() + + self.maskFileDir = os.path.dirname(filename) + try: + self.load(filename) + # except RuntimeWarning as e: + # message = e.args[0] + # msg = qt.QMessageBox(self) + # msg.setIcon(qt.QMessageBox.Warning) + # msg.setText("Mask loaded but an operation was applied.\n" + message) + # msg.exec_() + except Exception as e: + message = e.args[0] + msg = qt.QMessageBox(self) + msg.setIcon(qt.QMessageBox.Critical) + msg.setText("Cannot load mask from file. " + message) + msg.exec_() + + def _saveMask(self): + """Open Save mask dialog""" + dialog = qt.QFileDialog(self) + dialog.setWindowTitle("Save Mask") + dialog.setModal(1) + filters = [ + 'NumPy binary file (*.npy)', + 'CSV text file (*.csv)', + ] + dialog.setNameFilters(filters) + dialog.setFileMode(qt.QFileDialog.AnyFile) + dialog.setAcceptMode(qt.QFileDialog.AcceptSave) + dialog.setDirectory(self.maskFileDir) + if not dialog.exec_(): + dialog.close() + return + + # convert filter name to extension name with the . + extension = dialog.selectedNameFilter().split()[-1][2:-1] + filename = dialog.selectedFiles()[0] + dialog.close() + + if not filename.lower().endswith(extension): + filename += extension + + if os.path.exists(filename): + try: + os.remove(filename) + except IOError: + msg = qt.QMessageBox(self) + msg.setIcon(qt.QMessageBox.Critical) + msg.setText("Cannot save.\n" + "Input Output Error: %s" % (sys.exc_info()[1])) + msg.exec_() + return + + self.maskFileDir = os.path.dirname(filename) + try: + self.save(filename, extension[1:]) + except Exception as e: + msg = qt.QMessageBox(self) + msg.setIcon(qt.QMessageBox.Critical) + msg.setText("Cannot save file %s\n%s" % (filename, e.args[0])) + msg.exec_() + + def resetSelectionMask(self): + """Reset the mask""" + self._mask.reset( + shape=self._data_scatter.getXData(copy=False).shape) + self._mask.commit() + + def _plotDrawEvent(self, event): + """Handle draw events from the plot""" + if (self._drawingMode is None or + event['event'] not in ('drawingProgress', 'drawingFinished')): + return + + if not len(self._data_scatter.getXData(copy=False)): + return + + level = self.levelSpinBox.value() + + if (self._drawingMode == 'rectangle' and + event['event'] == 'drawingFinished'): + doMask = self._isMasking() + + self._mask.updateRectangle( + level, + y=event['y'], + x=event['x'], + height=abs(event['height']), + width=abs(event['width']), + mask=doMask) + self._mask.commit() + + elif (self._drawingMode == 'polygon' and + event['event'] == 'drawingFinished'): + doMask = self._isMasking() + vertices = event['points'] + vertices = vertices.astype(numpy.int)[:, (1, 0)] # (y, x) + self._mask.updatePolygon(level, vertices, doMask) + self._mask.commit() + + elif self._drawingMode == 'pencil': + doMask = self._isMasking() + # convert from plot to array coords + x, y = event['points'][-1] + brushSize = self.pencilSpinBox.value() + + if self._lastPencilPos != (y, x): + if self._lastPencilPos is not None: + # Draw the line + self._mask.updateLine( + level, + self._lastPencilPos[0], self._lastPencilPos[1], + y, x, + brushSize, + doMask) + + # Draw the very first, or last point + self._mask.updateDisk(level, y, x, brushSize / 2., doMask) + + if event['event'] == 'drawingFinished': + self._mask.commit() + self._lastPencilPos = None + else: + self._lastPencilPos = y, x + + def _loadRangeFromColormapTriggered(self): + """Set range from active scatter colormap range""" + if self._data_scatter is not None: + # Update thresholds according to colormap + colormap = self._data_scatter.getColormap() + if colormap['autoscale']: + min_ = numpy.nanmin(self._data_scatter.getValueData(copy=False)) + max_ = numpy.nanmax(self._data_scatter.getValueData(copy=False)) + else: + min_, max_ = colormap['vmin'], colormap['vmax'] + self.minLineEdit.setText(str(min_)) + self.maxLineEdit.setText(str(max_)) + + +class ScatterMaskToolsDockWidget(BaseMaskToolsDockWidget): + """:class:`ScatterMaskToolsWidget` embedded in a QDockWidget. + + For integration in a :class:`PlotWindow`. + + :param parent: See :class:`QDockWidget` + :param plot: The PlotWidget this widget is operating on + :paran str name: The title of this widget + """ + def __init__(self, parent=None, plot=None, name='Mask'): + super(ScatterMaskToolsDockWidget, self).__init__(parent, name) + self.setWidget(ScatterMaskToolsWidget(plot=plot)) + self.widget().sigMaskChanged.connect(self._emitSigMaskChanged) diff --git a/silx/gui/plot/StackView.py b/silx/gui/plot/StackView.py new file mode 100644 index 0000000..9bb0cf0 --- /dev/null +++ b/silx/gui/plot/StackView.py @@ -0,0 +1,1033 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""QWidget displaying a 3D volume as a stack of 2D images. + +The :class:`StackView` class implements this widget. + +Basic usage of :class:`StackView` is through the following methods: + +- :meth:`StackView.getColormap`, :meth:`StackView.setColormap` to update the + default colormap to use and update the currently displayed image. +- :meth:`StackView.setStack` to update the displayed image. + +The :class:`StackView` uses :class:`PlotWindow` and also +exposes a subset of the :class:`silx.gui.plot.Plot` API for further control +(plot title, axes labels, ...). + +The :class:`StackViewMainWindow` class implements a widget that adds a status +bar displaying the 3D index and the value under the mouse cursor. + +Example:: + + import numpy + import sys + from silx.gui import qt + from silx.gui.plot.StackView import StackViewMainWindow + + + app = qt.QApplication(sys.argv[1:]) + + # synthetic data, stack of 100 images of size 200x300 + mystack = numpy.fromfunction( + lambda i, j, k: numpy.sin(i/15.) + numpy.cos(j/4.) + 2 * numpy.sin(k/6.), + (100, 200, 300) + ) + + + sv = StackViewMainWindow() + sv.setColormap("jet", autoscale=True) + sv.setStack(mystack) + sv.setLabels(["1st dim (0-99)", "2nd dim (0-199)", + "3rd dim (0-299)"]) + sv.show() + + app.exec_() + +""" + +__authors__ = ["P. Knobel", "H. Payno"] +__license__ = "MIT" +__date__ = "20/01/2017" + +import numpy + +try: + import h5py +except ImportError: + h5py = None + +from silx.gui import qt +from .. import icons +from . import items, PlotWindow, PlotActions +from .Colors import cursorColorForColormap +from .PlotTools import LimitsToolBar +from .Profile import Profile3DToolBar +from ..widgets.FrameBrowser import HorizontalSliderWithBrowser + +from silx.utils.array_like import DatasetView, ListOfImages +from silx.math import calibration + + +class StackView(qt.QMainWindow): + """Stack view widget, to display and browse through stack of + images. + + The profile tool can be switched to "3D" mode, to compute the profile + on each image of the stack (not only the active image currently displayed) + and display the result as a slice. + + :param QWidget parent: the Qt parent, or None + :param backend: The backend to use for the plot (default: matplotlib). + See :class:`.Plot` for the list of supported backend. + :type backend: str or :class:`BackendBase.BackendBase` + :param bool resetzoom: Toggle visibility of reset zoom action. + :param bool autoScale: Toggle visibility of axes autoscale actions. + :param bool logScale: Toggle visibility of axes log scale actions. + :param bool grid: Toggle visibility of grid mode action. + :param bool colormap: Toggle visibility of colormap action. + :param bool aspectRatio: Toggle visibility of aspect ratio button. + :param bool yInverted: Toggle visibility of Y axis direction button. + :param bool copy: Toggle visibility of copy action. + :param bool save: Toggle visibility of save action. + :param bool print_: Toggle visibility of print action. + :param bool control: True to display an Options button with a sub-menu + to show legends, toggle crosshair and pan with arrows. + (Default: False) + :param position: True to display widget with (x, y) mouse position + (Default: False). + It also supports a list of (name, funct(x, y)->value) + to customize the displayed values. + See :class:`silx.gui.plot.PlotTools.PositionInfo`. + :param bool mask: Toggle visibilty of mask action. + """ + # Qt signals + valueChanged = qt.Signal(object, object, object) + """Signals that the data value under the cursor has changed. + + It provides: row, column, data value. + """ + + sigPlaneSelectionChanged = qt.Signal(int) + """Signal emitted when there is a change is perspective/displayed axes. + + It provides the perspective as an integer, with the following meaning: + + - 0: axis Y is the 2nd dimension, axis X is the 3rd dimension + - 1: axis Y is the 1st dimension, axis X is the 3rd dimension + - 2: axis Y is the 1st dimension, axis X is the 2nd dimension + """ + + sigStackChanged = qt.Signal(int) + """Signal emitted when the stack is changed. + This happens when a new volume is loaded, or when the current volume + is transposed (change in perspective). + + The signal provides the size (number of pixels) of the stack. + This will be 0 if the stack is cleared, else it will be a positive + integer. + """ + + def __init__(self, parent=None, resetzoom=True, backend=None, + autoScale=False, logScale=False, grid=False, + colormap=True, aspectRatio=True, yinverted=True, + copy=True, save=True, print_=True, control=False, + position=None, mask=True): + qt.QMainWindow.__init__(self, parent) + if parent is not None: + # behave as a widget + self.setWindowFlags(qt.Qt.Widget) + else: + self.setWindowTitle('StackView') + + self._stack = None + """Loaded stack, as a 3D array, a 3D dataset or a list of 2D arrays.""" + self.__transposed_view = None + """View on :attr:`_stack` with the axes sorted, to have + the orthogonal dimension first""" + self._perspective = 0 + """Orthogonal dimension (depth) in :attr:`_stack`""" + + self.__imageLegend = '__StackView__image' + str(id(self)) + self.__autoscaleCmap = False + """Flag to disable/enable colormap auto-scaling + based on the min/max values of the entire 3D volume""" + self.__dimensionsLabels = ["Dimension 0", "Dimension 1", + "Dimension 2"] + """These labels are displayed on the X and Y axes. + :meth:`setLabels` updates this attribute.""" + + self._first_stack_dimension = 0 + """Used for dimension labels and combobox""" + + central_widget = qt.QWidget(self) + + self._plot = PlotWindow(parent=central_widget, backend=backend, + resetzoom=resetzoom, autoScale=autoScale, + logScale=logScale, grid=grid, + curveStyle=False, colormap=colormap, + aspectRatio=aspectRatio, yInverted=yinverted, + copy=copy, save=save, print_=print_, + control=control, position=position, + roi=False, mask=mask) + self.sigInteractiveModeChanged = self._plot.sigInteractiveModeChanged + self.sigActiveImageChanged = self._plot.sigActiveImageChanged + self.sigPlotSignal = self._plot.sigPlotSignal + + self._plot.profile = Profile3DToolBar(parent=self._plot, + plot=self) + self._plot.addToolBar(self._plot.profile) + self._plot.setGraphXLabel('Columns') + self._plot.setGraphYLabel('Rows') + self._plot.sigPlotSignal.connect(self._plotCallback) + + self.__planeSelection = PlanesWidget(self._plot) + self.__planeSelection.sigPlaneSelectionChanged.connect(self.__setPerspective) + + self._browser_label = qt.QLabel("Image index (Dim0):") + + self._browser = HorizontalSliderWithBrowser(central_widget) + self._browser.valueChanged[int].connect(self.__updateFrameNumber) + self._browser.setEnabled(False) + + layout = qt.QGridLayout() + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self._plot, 0, 0, 1, 3) + layout.addWidget(self.__planeSelection, 1, 0) + layout.addWidget(self._browser_label, 1, 1) + layout.addWidget(self._browser, 1, 2) + + central_widget.setLayout(layout) + self.setCentralWidget(central_widget) + + # clear profile lines when the perspective changes (plane browsed changed) + self.__planeSelection.sigPlaneSelectionChanged.connect( + self._plot.profile.getProfilePlot().clear) + self.__planeSelection.sigPlaneSelectionChanged.connect( + self._plot.profile.clearProfile) + + def setOptionVisible(self, isVisible): + """ + Set the visibility of the browsing options. + + :param bool isVisible: True to have the options visible, else False + """ + self._browser.setVisible(isVisible) + self.__planeSelection.setVisible(isVisible) + + def _plotCallback(self, eventDict): + """Callback for plot events. + + Emit :attr:`valueChanged` signal, with (x, y, value) tuple of the + cursor location in the plot.""" + if eventDict['event'] == 'mouseMoved': + activeImage = self.getActiveImage() + if activeImage is not None: + data = activeImage.getData() + height, width = data.shape + + # Get corresponding coordinate in image + origin = activeImage.getOrigin() + scale = activeImage.getScale() + x = int((eventDict['x'] - origin[0]) / scale[0]) + y = int((eventDict['y'] - origin[1]) / scale[1]) + + if 0 <= x < width and 0 <= y < height: + self.valueChanged.emit(float(x), float(y), + data[y][x]) + else: + self.valueChanged.emit(float(x), float(y), + None) + + def __setPerspective(self, perspective): + """Function called when the browsed/orthogonal dimension changes. + Updates :attr:`_perspective`, transposes data, updates the plot, + emits :attr:`sigPlaneSelectionChanged` and :attr:`sigStackChanged`. + + :param int perspective: the new browsed dimension + """ + if perspective == self._perspective: + return + else: + if perspective > 2 or perspective < 0: + raise ValueError( + "Perspective must be 0, 1 or 2, not %s" % perspective) + + self._perspective = perspective + self.__createTransposedView() + self.__updateFrameNumber(self._browser.value()) + self._plot.resetZoom() + self.__updatePlotLabels() + self._browser_label.setText("Image index (Dim%d):" % + (self._first_stack_dimension + perspective)) + + self.sigPlaneSelectionChanged.emit(perspective) + self.sigStackChanged.emit(self._stack.size if + self._stack is not None else 0) + + def __updatePlotLabels(self): + """Update plot axes labels depending on perspective""" + y, x = (1, 2) if self._perspective == 0 else \ + (0, 2) if self._perspective == 1 else (0, 1) + self.setGraphXLabel(self.__dimensionsLabels[x]) + self.setGraphYLabel(self.__dimensionsLabels[y]) + + def __createTransposedView(self): + """Create the new view on the stack depending on the perspective + (set orthogonal axis browsed on the viewer as first dimension) + """ + assert self._stack is not None + assert 0 <= self._perspective < 3 + + # ensure we have the stack encapsulated in an array like object + # having a transpose() method + if isinstance(self._stack, numpy.ndarray): + self.__transposed_view = self._stack + + elif h5py is not None and isinstance(self._stack, h5py.Dataset) or \ + isinstance(self._stack, DatasetView): + self.__transposed_view = DatasetView(self._stack) + + elif isinstance(self._stack, ListOfImages): + self.__transposed_view = ListOfImages(self._stack) + + # transpose the array like object if necessary + if self._perspective == 1: + self.__transposed_view = self.__transposed_view.transpose((1, 0, 2)) + elif self._perspective == 2: + self.__transposed_view = self.__transposed_view.transpose((2, 0, 1)) + + self._browser.setRange(0, self.__transposed_view.shape[0] - 1) + self._browser.setValue(0) + + def setFrameNumber(self, number): + """Set the frame selection to a specific value\ + + :param int number: Number of the frame + """ + self._browser.setValue(number) + + def __updateFrameNumber(self, index): + """Update the current image. + + :param index: index of the frame to be displayed + """ + assert self.__transposed_view is not None + self._plot.addImage(self.__transposed_view[index, :, :], + origin=self._getImageOrigin(), + scale=self._getImageScale(), + legend=self.__imageLegend, + resetzoom=False, replace=False) + self._plot.setGraphTitle("Image z=%g" % self._getImageZ(index)) + + def _set3DScaleAndOrigin(self, calibrations): + """Set scale and origin for all 3 axes, to be used when plotting + an image. + + See setStack for parameter documentation + """ + if calibrations is None: + self.calibrations3D = (calibration.NoCalibration(), + calibration.NoCalibration(), + calibration.NoCalibration()) + else: + self.calibrations3D = [] + for calib in calibrations: + if hasattr(calib, "__len__") and len(calib) == 2: + calib = calibration.LinearCalibration(calib[0], calib[1]) + elif calib is None: + calib = calibration.NoCalibration() + elif not isinstance(calib, calibration.AbstractCalibration): + raise TypeError("calibration must be a 2-tuple, None or" + + " an instance of an AbstractCalibration " + + "subclass") + self.calibrations3D.append(calib) + + def _getXYZCalibs(self): + xy_dims = [0, 1, 2] + xy_dims.remove(self._perspective) + + xcalib = self.calibrations3D[max(xy_dims)] + ycalib = self.calibrations3D[min(xy_dims)] + zcalib = self.calibrations3D[self._perspective] + + return xcalib, ycalib, zcalib + + def _getImageScale(self): + """ + :return: 2-tuple (XScale, YScale) for current image view + """ + xcalib, ycalib, _zcalib = self._getXYZCalibs() + return xcalib.get_slope(), ycalib.get_slope() + + def _getImageOrigin(self): + """ + :return: 2-tuple (XOrigin, YOrigin) for current image view + """ + xcalib, ycalib, _zcalib = self._getXYZCalibs() + return xcalib(0), ycalib(0) + + def _getImageZ(self, index): + """ + :param idx: 0-based image index in the stack + :return: calibrated Z value corresponding to the image idx + """ + _xcalib, _ycalib, zcalib = self._getXYZCalibs() + return zcalib(index) + + # public API + def setStack(self, stack, perspective=0, reset=True, calibrations=None): + """Set the 3D stack. + + The perspective parameter is used to define which dimension of the 3D + array is to be used as frame index. The lowest remaining dimension + number is the row index of the displayed image (Y axis), and the highest + remaining dimension is the column index (X axis). + + :param stack: 3D stack, or `None` to clear plot. + :type stack: 3D numpy.ndarray, or 3D h5py.Dataset, or list/tuple of 2D + numpy arrays, or None. + :param int perspective: Dimension for the frame index: 0, 1 or 2. + By default, the dimension for the image index is the first + dimension of the 3D stack (``perspective=0``). + :param bool reset: Whether to reset zoom or not. + :param calibrations: Sequence of 3 calibration objects for each axis. + These objects can be a subclass of :class:`AbstractCalibration`, + or 2-tuples *(a, b)* where *a* is the y-intercept and *b* is the + slope of a linear calibration (:math:`x \mapsto a + b x`) + """ + if stack is None: + self.clear() + self.sigStackChanged.emit(0) + return + + self._set3DScaleAndOrigin(calibrations) + + # stack as list of 2D arrays: must be converted into an array_like + if not isinstance(stack, numpy.ndarray): + if h5py is None or not isinstance(stack, h5py.Dataset): + try: + assert hasattr(stack, "__len__") + for img in stack: + assert hasattr(img, "shape") + assert len(img.shape) == 2 + except AssertionError: + raise ValueError( + "Stack must be a 3D array/dataset or a list of " + + "2D arrays.") + stack = ListOfImages(stack) + + assert len(stack.shape) == 3, "data must be 3D" + + self._stack = stack + self.__createTransposedView() + + if perspective != self._perspective: + self.__setPerspective(perspective) + + # This call to setColormap redefines the meaning of autoscale + # for 3D volume: take global min/max rather than frame min/max + if self.__autoscaleCmap: + self.setColormap(autoscale=True) + + # init plot + self._plot.addImage(self.__transposed_view[0, :, :], + legend=self.__imageLegend, + colormap=self.getColormap(), + origin=self._getImageOrigin(), + scale=self._getImageScale(), + resetzoom=False) + self._plot.setActiveImage(self.__imageLegend) + self._plot.setGraphTitle("Image z=%g" % self._getImageZ(0)) + self.__updatePlotLabels() + + if reset: + self._plot.resetZoom() + + # enable and init browser + self._browser.setEnabled(True) + + if perspective != self._perspective: + self.__planeSelection.setPerspective(perspective) + # this causes self.__setPerspective to be called, which emits + # sigStackChanged and sigPlaneSelectionChanged + + else: + self.sigStackChanged.emit(stack.size) + + def getStack(self, copy=True, returnNumpyArray=False): + """Get the original stack, as a 3D array or dataset. + + The output has the form: [data, params] + where params is a dictionary containing display parameters. + + :param bool copy: If True (default), then the object is copied + and returned as a numpy array. + Else, a reference to original data is returned, if possible. + If the original data is not a numpy array and parameter + returnNumpyArray is True, a copy will be made anyway. + :param bool returnNumpyArray: If True, the returned object is + guaranteed to be a numpy array. + :return: 3D stack and parameters. + :rtype: (numpy.ndarray, dict) + """ + image = self.getActiveImage() + if image is None: + return None + + if isinstance(image, items.ColormapMixIn): + colormap = image.getColormap() + else: + colormap = None + + params = { + 'info': image.getInfo(), + 'origin': image.getOrigin(), + 'scale': image.getScale(), + 'z': image.getZValue(), + 'selectable': image.isSelectable(), + 'draggable': image.isDraggable(), + 'colormap': colormap, + 'xlabel': image.getXLabel(), + 'ylabel': image.getYLabel(), + } + if returnNumpyArray or copy: + return numpy.array(self._stack, copy=copy), params + + # if a list of 2D arrays was cast into a ListOfImages, + # return the original list + if isinstance(self._stack, ListOfImages): + return self._stack.images, params + + return self._stack, params + + def getCurrentView(self, copy=True, returnNumpyArray=False): + """Get the stack, as it is currently displayed. + + The first index of the returned stack is always the frame + index. If the perspective has been changed in the widget since the + data was first loaded, this will be reflected in the order of the + dimensions of the returned object. + + The output has the form: [data, params] + where params is a dictionary containing display parameters. + + :param bool copy: If True (default), then the object is copied + and returned as a numpy array. + Else, a reference to original data is returned, if possible. + If the original data is not a numpy array and parameter + `returnNumpyArray` is `True`, a copy will be made anyway. + :param bool returnNumpyArray: If `True`, the returned object is + guaranteed to be a numpy array. + :return: 3D stack and parameters. + :rtype: (numpy.ndarray, dict) + """ + image = self.getActiveImage() + if image is None: + return None + + if isinstance(image, items.ColormapMixIn): + colormap = image.getColormap() + else: + colormap = None + + params = { + 'info': image.getInfo(), + 'origin': image.getOrigin(), + 'scale': image.getScale(), + 'z': image.getZValue(), + 'selectable': image.isSelectable(), + 'draggable': image.isDraggable(), + 'colormap': colormap, + 'xlabel': image.getXLabel(), + 'ylabel': image.getYLabel(), + } + if returnNumpyArray or copy: + return numpy.array(self.__transposed_view, copy=copy), params + return self.__transposed_view, params + + def getActiveImage(self, just_legend=False): + """Returns the currently active image object. + + It returns None in case of not having an active image. + + :param bool just_legend: True to get the legend of the image, + False (the default) to get the image data and info. + Note: :class:`StackView` uses the same legend for all frames. + :return: legend or image object + :rtype: str or list or None + """ + return self._plot.getActiveImage(just_legend=just_legend) + + def clear(self): + """Clear the widget: + + - clear the plot + - clear the loaded data volume + """ + self._stack = None + self.__transposed_view = None + self._perspective = 0 + self._browser.setEnabled(False) + self._plot.clear() + + def resetZoom(self): + """Reset the plot limits to the bounds of the data and redraw the plot. + """ + self._plot.resetZoom() + + def getGraphTitle(self): + """Return the plot main title as a str.""" + return self._plot.getGraphTitle() + + def setGraphTitle(self, title=""): + """Set the plot main title. + + :param str title: Main title of the plot (default: '') + """ + return self._plot.setGraphTitle(title) + + def setLabels(self, labels=None): + """Set the labels to be displayed on the plot axes. + + You must provide a sequence of 3 strings, corresponding to the 3 + dimensions of the original data volume. + The proper label will automatically be selected for each plot axis + when the volume is rotated (when different axes are selected as the + X and Y axes). + + :param list(str) labels: 3 labels corresponding to the 3 dimensions + of the data volumes. + """ + + default_labels = ["Dimension %d" % self._first_stack_dimension, + "Dimension %d" % (self._first_stack_dimension + 1), + "Dimension %d" % (self._first_stack_dimension + 2)] + if labels is None: + new_labels = default_labels + else: + # filter-out None + new_labels = [] + for i, label in enumerate(labels): + new_labels.append(label or default_labels[i]) + + self.__dimensionsLabels = new_labels + self.__updatePlotLabels() + + def getGraphXLabel(self): + """Return the current horizontal axis label as a str.""" + return self._plot.getGraphXLabel() + + def setGraphXLabel(self, label=None): + """Set the plot horizontal axis label. + + :param str label: The horizontal axis label + """ + if label is None: + label = self.__dimensionsLabels[1 if self._perspective == 2 else 2] + self._plot.setGraphXLabel(label) + + def getGraphYLabel(self, axis='left'): + """Return the current vertical axis label as a str. + + :param str axis: The Y axis for which to get the label (left or right) + """ + return self._plot.getGraphYLabel(axis) + + def setGraphYLabel(self, label=None, axis='left'): + """Set the vertical axis label on the plot. + + :param str label: The Y axis label + :param str axis: The Y axis for which to set the label (left or right) + """ + if label is None: + label = self.__dimensionsLabels[1 if self._perspective == 0 else 0] + self._plot.setGraphYLabel(label, axis) + + def setYAxisInverted(self, flag=True): + """Set the Y axis orientation. + + :param bool flag: True for Y axis going from top to bottom, + False for Y axis going from bottom to top + """ + self._plot.setYAxisInverted(flag) + + def isYAxisInverted(self): + """Return True if Y axis goes from top to bottom, False otherwise.""" + return self._backend.isYAxisInverted() + + def getSupportedColormaps(self): + """Get the supported colormap names as a tuple of str. + + The list should at least contain and start by: + ('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue') + """ + return self._plot.getSupportedColormaps() + + def getColormap(self): + """Get the current colormap description. + + :return: A description of the current colormap. + See :meth:`setColormap` for details. + :rtype: dict + """ + # "default" colormap used by addImage when image is added without + # specifying a special colormap + return self._plot.getDefaultColormap() + + def setColormap(self, colormap=None, normalization=None, + autoscale=None, vmin=None, vmax=None, colors=None): + """Set the colormap and update active image. + + Parameters that are not provided are taken from the current colormap. + + The colormap parameter can also be a dict with the following keys: + + - *name*: string. The colormap to use: + 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'. + - *normalization*: string. The mapping to use for the colormap: + either 'linear' or 'log'. + - *autoscale*: bool. Whether to use autoscale (True) or range + provided by keys + 'vmin' and 'vmax' (False). + - *vmin*: float. The minimum value of the range to use if 'autoscale' + is False. + - *vmax*: float. The maximum value of the range to use if 'autoscale' + is False. + - *colors*: optional. Nx3 or Nx4 array of float in [0, 1] or uint8. + List of RGB or RGBA colors to use (only if name is None) + + :param colormap: Name of the colormap in + 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'. + Or the description of the colormap as a dict. + :type colormap: dict or str. + :param str normalization: Colormap mapping: 'linear' or 'log'. + :param bool autoscale: Whether to use autoscale or [vmin, vmax] range. + Default value of autoscale is True if data is a numpy array, + False if data is a h5py dataset. + :param float vmin: The minimum value of the range to use if + 'autoscale' is False. + :param float vmax: The maximum value of the range to use if + 'autoscale' is False. + :param numpy.ndarray colors: Only used if name is None. + Custom colormap colors as Nx3 or Nx4 RGB or RGBA arrays + """ + cmapDict = self.getColormap() + + if isinstance(colormap, dict): + # Support colormap parameter as a dict + errmsg = "If colormap is provided as a dict, all other parameters" + errmsg += " must not be specified when calling setColormap" + assert normalization is None, errmsg + assert autoscale is None, errmsg + assert vmin is None, errmsg + assert vmax is None, errmsg + assert colors is None, errmsg + cmapDict.update(colormap) + + else: + if colormap is not None: + cmapDict['name'] = colormap + if normalization is not None: + cmapDict['normalization'] = normalization + if colors is not None: + cmapDict['colors'] = colors + + # Default meaning of autoscale is to reset min and max + # each time a new image is added to the plot. + # We want to use min and max of global volume, + # and not change them when browsing slides + cmapDict['autoscale'] = False + + if autoscale is None: + # set default + autoscale = False + # TODO: assess cost of computing min/max for large 3D array + # if isinstance(self._stack, numpy.ndarray): + # autoscale = True + # else: # h5py.Dataset + # autoscale = False + elif autoscale and isinstance(self._stack, h5py.Dataset): + # h5py dataset has no min()/max() methods + raise RuntimeError( + "Cannot auto-scale colormap for a h5py dataset") + else: + autoscale = autoscale + self.__autoscaleCmap = autoscale + if autoscale and (self._stack is not None): + cmapDict['vmin'] = self._stack.min() + cmapDict['vmax'] = self._stack.max() + else: + if vmin is not None: + cmapDict['vmin'] = vmin + if vmax is not None: + cmapDict['vmax'] = vmax + + cursorColor = cursorColorForColormap(cmapDict['name']) + self._plot.setInteractiveMode('zoom', color=cursorColor) + + self._plot.setDefaultColormap(cmapDict) + + # Update active image colormap + activeImage = self._plot.getActiveImage() + if isinstance(activeImage, items.ColormapMixIn): + activeImage.setColormap(self.getColormap()) + + def isKeepDataAspectRatio(self): + """Returns whether the plot is keeping data aspect ratio or not.""" + return self._plot.isKeepDataAspectRatio() + + def setKeepDataAspectRatio(self, flag=True): + """Set whether the plot keeps data aspect ratio or not. + + :param bool flag: True to respect data aspect ratio + """ + self._plot.setKeepDataAspectRatio(flag) + + def getProfileToolbar(self): + """Profile tools attached to this plot + + See :class:`silx.gui.plot.Profile.Profile3DToolBar` + """ + return self._plot.profile + + def getProfileWindow1D(self): + """Plot window used to display 1D profile curve. + + :return: :class:`Plot1D` + """ + return self._plot.profile.getProfileWindow1D() + + def getProfileWindow2D(self): + """Plot window used to display 2D profile image. + + :return: :class:`Plot2D` + """ + return self._plot.profile.getProfileWindow2D() + + # kind of private methods, but needed by Profile + def remove(self, legend=None, + kind=('curve', 'image', 'item', 'marker')): + """See :meth:`Plot.Plot.remove`""" + self._plot.remove(legend, kind) + + def setInteractiveMode(self, *args, **kwargs): + """ + See :meth:`Plot.Plot.setInteractiveMode` + """ + self._plot.setInteractiveMode(*args, **kwargs) + + def addItem(self, *args, **kwargs): + """ + See :meth:`Plot.Plot.addItem` + """ + self._plot.addItem(*args, **kwargs) + + def setFirstStackDimension(self, first_stack_dimension): + """When viewing the last 3 dimensions of an n-D array (n>3), you can + use this method to change the text in the combobox. + + For instance, for a 7-D array, first stack dim is 4, so the default + "Dim1-Dim2" text should be replaced with "Dim5-Dim6" (dimensions + numbers are 0-based). + + :param int first_stack_dim: First stack dimension (n-3) when viewing the + last 3 dimensions of an n-D array. + """ + old_state = self.__planeSelection.blockSignals(True) + self.__planeSelection.setFirstStackDimension(first_stack_dimension) + self.__planeSelection.blockSignals(old_state) + self._first_stack_dimension = first_stack_dimension + self._browser_label.setText("Image index (Dim%d):" % first_stack_dimension) + + +class PlanesWidget(qt.QWidget): + """Widget for the plane/perspective selection + + :param parent: the parent QWidget + """ + sigPlaneSelectionChanged = qt.Signal(int) + + def __init__(self, parent): + super(PlanesWidget, self).__init__(parent) + + self.setSizePolicy(qt.QSizePolicy.Minimum, qt.QSizePolicy.Minimum) + layout0 = qt.QHBoxLayout() + self.setLayout(layout0) + layout0.setContentsMargins(0, 0, 0, 0) + + layout0.addWidget(qt.QLabel("Axes selection:")) + + # By default, the first dimension (dim0) is the frame index/depth/z, + # the second dimension is the image row number/y axis + # and the third dimension is the image column index/x axis + + # 1 + # | 0 + # |/__2 + self.qcbAxisSelection = qt.QComboBox(self) + self._setCBChoices(first_stack_dimension=0) + self.qcbAxisSelection.currentIndexChanged[int].connect( + self.__planeSelectionChanged) + + layout0.addWidget(self.qcbAxisSelection) + + def __planeSelectionChanged(self, idx): + """Callback function when the combobox selection changes + + idx is the dimension number orthogonal to the slice plane, + following the convention: + + - slice plane Dim1-Dim2: perspective 0 + - slice plane Dim0-Dim2: perspective 1 + - slice plane Dim0-Dim1: perspective 2 + """ + self.sigPlaneSelectionChanged.emit(idx) + + def _setCBChoices(self, first_stack_dimension): + self.qcbAxisSelection.clear() + + dim1dim2 = 'Dim%d-Dim%d' % (first_stack_dimension + 1, + first_stack_dimension + 2) + dim0dim2 = 'Dim%d-Dim%d' % (first_stack_dimension, + first_stack_dimension + 2) + dim0dim1 = 'Dim%d-Dim%d' % (first_stack_dimension, + first_stack_dimension + 1) + + self.qcbAxisSelection.addItem(icons.getQIcon("cube-front"), dim1dim2) + self.qcbAxisSelection.addItem(icons.getQIcon("cube-bottom"), dim0dim2) + self.qcbAxisSelection.addItem(icons.getQIcon("cube-left"), dim0dim1) + + def setFirstStackDimension(self, first_stack_dim): + """When viewing the last 3 dimensions of an n-D array (n>3), you can + use this method to change the text in the combobox. + + For instance, for a 7-D array, first stack dim is 4, so the default + "Dim1-Dim2" text should be replaced with "Dim5-Dim6" (dimensions + numbers are 0-based). + + :param int first_stack_dim: First stack dimension (n-3) when viewing the + last 3 dimensions of an n-D array. + """ + self._setCBChoices(first_stack_dim) + + def setPerspective(self, perspective): + """Update the combobox selection. + + - slice plane Dim1-Dim2: perspective 0 + - slice plane Dim0-Dim2: perspective 1 + - slice plane Dim0-Dim1: perspective 2 + + :param perspective: Orthogonal dimension number (0, 1, or 2) + """ + self.qcbAxisSelection.setCurrentIndex(perspective) + + +class StackViewMainWindow(StackView): + """This class is a :class:`StackView` with a menu, an additional toolbar + to set the plot limits, and a status bar to display the value and 3D + index of the data samples hovered by the mouse cursor. + + :param QWidget parent: Parent widget, or None + """ + def __init__(self, parent=None): + self._dataInfo = None + super(StackViewMainWindow, self).__init__(parent) + self.setWindowFlags(qt.Qt.Window) + + # Add toolbars and status bar + self.addToolBar(qt.Qt.BottomToolBarArea, + LimitsToolBar(plot=self._plot)) + + self.statusBar() + + menu = self.menuBar().addMenu('File') + menu.addAction(self._plot.saveAction) + menu.addAction(self._plot.printAction) + menu.addSeparator() + action = menu.addAction('Quit') + action.triggered[bool].connect(qt.QApplication.instance().quit) + + menu = self.menuBar().addMenu('Edit') + menu.addAction(self._plot.copyAction) + menu.addSeparator() + menu.addAction(self._plot.resetZoomAction) + menu.addAction(self._plot.colormapAction) + menu.addAction(PlotActions.KeepAspectRatioAction(self._plot, self)) + menu.addAction(PlotActions.YAxisInvertedAction(self._plot, self)) + + menu = self.menuBar().addMenu('Profile') + menu.addAction(self._plot.profile.browseAction) + menu.addAction(self._plot.profile.hLineAction) + menu.addAction(self._plot.profile.vLineAction) + menu.addAction(self._plot.profile.lineAction) + menu.addSeparator() + menu.addAction(self._plot.profile.clearAction) + self._plot.profile.profile3dAction.computeProfileIn2D() + menu.addMenu(self._plot.profile.profile3dAction.menu()) + + # Connect to StackView's signal + self.valueChanged.connect(self._statusBarSlot) + + def _statusBarSlot(self, x, y, value): + """Update status bar with coordinates/value from plots.""" + # todo (after implementing calibration): + # - use floats for (x, y, z) + # - display both indices (dim0, dim1, dim2) and (x, y, z) + msg = "Cursor out of range" + if x is not None and y is not None: + img_idx = self._browser.value() + + if self._perspective == 0: + dim0, dim1, dim2 = img_idx, int(y), int(x) + elif self._perspective == 1: + dim0, dim1, dim2 = int(y), img_idx, int(x) + elif self._perspective == 2: + dim0, dim1, dim2 = int(y), int(x), img_idx + + msg = 'Position: (%d, %d, %d)' % (dim0, dim1, dim2) + if value is not None: + msg += ', Value: %g' % value + if self._dataInfo is not None: + msg = self._dataInfo + ', ' + msg + + self.statusBar().showMessage(msg) + + def setStack(self, stack, *args, **kwargs): + """Set the displayed stack. + + See :meth:`StackView.setStack` for details. + """ + if hasattr(stack, 'dtype') and hasattr(stack, 'shape'): + assert len(stack.shape) == 3 + nframes, height, width = stack.shape + self._dataInfo = 'Data: %dx%dx%d (%s)' % (nframes, height, width, + str(stack.dtype)) + self.statusBar().showMessage(self._dataInfo) + else: + self._dataInfo = None + + # Set the new stack in StackView widget + super(StackViewMainWindow, self).setStack(stack, *args, **kwargs) + self.setStatusBar(None) diff --git a/silx/gui/plot/_BaseMaskToolsWidget.py b/silx/gui/plot/_BaseMaskToolsWidget.py new file mode 100644 index 0000000..91bbe1c --- /dev/null +++ b/silx/gui/plot/_BaseMaskToolsWidget.py @@ -0,0 +1,1138 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module is a collection of base classes used in modules +:mod:`.MaskToolsWidget` (images) and :mod:`.ScatterMaskToolsWidget` +""" +from __future__ import division + + +__authors__ = ["T. Vincent", "P. Knobel"] +__license__ = "MIT" +__date__ = "20/04/2017" + +import os + +import numpy + +from silx.gui import qt, icons +from silx.gui.plot.Colors import rgba + + +class BaseMask(qt.QObject): + """Base class for :class:`ImageMask` and :class:`ScatterMask` + + A mask field with update operations. + + A mask is an array of the same shape as some underlying data. The mask + array stores integer values in the range 0-255, to allow for 254 levels + of mask (value 0 is reserved for unmasked data). + + The mask is updated using spatial selection methods: data located inside + a selected area is masked with a specified mask level. + + """ + + sigChanged = qt.Signal() + """Signal emitted when the mask has changed""" + + sigUndoable = qt.Signal(bool) + """Signal emitted when undo becomes possible/impossible""" + + sigRedoable = qt.Signal(bool) + """Signal emitted when redo becomes possible/impossible""" + + def __init__(self, dataItem=None): + self.historyDepth = 10 + """Maximum number of operation stored in history list for undo""" + # Init lists for undo/redo + self._history = [] + self._redo = [] + + # Store the mask + self._mask = numpy.array((), dtype=numpy.uint8) + + # Store the plot item to be masked + self._dataItem = None + if dataItem is not None: + self.setDataItem(dataItem) + self.reset(self.getDataValues().shape) + + super(BaseMask, self).__init__() + + def setDataItem(self, item): + """Set a data item + + :param item: A plot item, subclass of :class:`silx.gui.plot.items.Item` + :return: + """ + self._dataItem = item + + def getDataValues(self): + """Return data values, as a numpy array with the same shape + as the mask. + + This method must be implemented in a subclass, as the way of + accessing data depends on the data item passed to :meth:`setDataItem` + + :return: Data values associated with the data item. + :rtype: numpy.ndarray + """ + raise NotImplementedError("To be implemented in subclass") + + def _notify(self): + """Notify of mask change.""" + self.sigChanged.emit() + + def getMask(self, copy=True): + """Get the current mask as a numpy array. + + :param bool copy: True (default) to get a copy of the mask. + If False, the returned array MUST not be modified. + :return: The array of the mask with dimension of the data to be masked. + :rtype: numpy.ndarray of uint8 + """ + return numpy.array(self._mask, copy=copy) + + def setMask(self, mask, copy=True): + """Set the mask to a new array. + + :param numpy.ndarray mask: The array to use for the mask. + :type mask: numpy.ndarray of uint8, C-contiguous. + Array of other types are converted. + :param bool copy: True (the default) to copy the array, + False to use it as is if possible. + """ + self._mask = numpy.array(mask, copy=copy, order='C', dtype=numpy.uint8) + self._notify() + + # History control + def resetHistory(self): + """Reset history""" + self._history = [numpy.array(self._mask, copy=True)] + self._redo = [] + self.sigUndoable.emit(False) + self.sigRedoable.emit(False) + + def commit(self): + """Append the current mask to history if changed""" + if (not self._history or self._redo or + not numpy.all(numpy.equal(self._mask, self._history[-1]))): + if self._redo: + self._redo = [] # Reset redo as a new action as been performed + self.sigRedoable[bool].emit(False) + + while len(self._history) >= self.historyDepth: + self._history.pop(0) + self._history.append(numpy.array(self._mask, copy=True)) + + if len(self._history) == 2: + self.sigUndoable.emit(True) + + def undo(self): + """Restore previous mask if any""" + if len(self._history) > 1: + self._redo.append(self._history.pop()) + self._mask = numpy.array(self._history[-1], copy=True) + self._notify() # Do not store this change in history + + if len(self._redo) == 1: # First redo + self.sigRedoable.emit(True) + if len(self._history) == 1: # Last value in history + self.sigUndoable.emit(False) + + def redo(self): + """Restore previously undone modification if any""" + if self._redo: + self._mask = self._redo.pop() + self._history.append(numpy.array(self._mask, copy=True)) + self._notify() + + if not self._redo: # No more redo + self.sigRedoable.emit(False) + if len(self._history) == 2: # Something to undo + self.sigUndoable.emit(True) + + # Whole mask operations + + def clear(self, level): + """Set all values of the given mask level to 0. + + :param int level: Value of the mask to set to 0. + """ + assert 0 < level < 256 + self._mask[self._mask == level] = 0 + self._notify() + + def invert(self, level): + """Invert mask of the given mask level. + + 0 values become level and level values become 0. + + :param int level: The level to invert. + """ + assert 0 < level < 256 + masked = self._mask == level + self._mask[self._mask == 0] = level + self._mask[masked] = 0 + self._notify() + + def reset(self, shape=None): + """Reset the mask to zero and change its shape. + + :param shape: Shape of the new mask with the correct dimensionality + with regards to the data dimensionality, + or None to have an empty mask + :type shape: tuple of int + """ + if shape is None: + # assume dimensionality never changes + shape = (0, ) * len(self._mask.shape) # empty array + shapeChanged = (shape != self._mask.shape) + self._mask = numpy.zeros(shape, dtype=numpy.uint8) + if shapeChanged: + self.resetHistory() + + self._notify() + + # To be implemented + def save(self, filename, kind): + """Save current mask in a file + + :param str filename: The file where to save to mask + :param str kind: The kind of file to save (e.g 'npy') + :raise Exception: Raised if the file writing fail + """ + raise NotImplementedError("To be implemented in subclass") + + # update thresholds + def updateStencil(self, level, stencil, mask=True): + """Mask/Unmask points from boolean mask: all elements that are True + in the boolean mask are set to ``level`` (if ``mask=True``) or 0 + (if ``mask=False``) + + :param int level: Mask level to update. + :param stencil: Boolean mask. + :type stencil: numpy.array of same dimension as the mask + :param bool mask: True to mask (default), False to unmask. + """ + if mask: + self._mask[stencil] = level + else: + self._mask[numpy.logical_and(self._mask == level, stencil)] = 0 + self._notify() + + def updateBelowThreshold(self, level, threshold, mask=True): + """Mask/unmask all points whose values are below a threshold. + + :param int level: + :param float threshold: Threshold + :param bool mask: True to mask (default), False to unmask. + """ + self.updateStencil(level, + self.getDataValues() < threshold, + mask) + + def updateBetweenThresholds(self, level, min_, max_, mask=True): + """Mask/unmask all points whose values are in a range. + + :param int level: + :param float min_: Lower threshold + :param float max_: Upper threshold + :param bool mask: True to mask (default), False to unmask. + """ + stencil = numpy.logical_and(min_ <= self.getDataValues(), + self.getDataValues() <= max_) + self.updateStencil(level, stencil, mask) + + def updateAboveThreshold(self, level, threshold, mask=True): + """Mask/unmask all points whose values are above a threshold. + + :param int level: Mask level to update. + :param float threshold: Threshold. + :param bool mask: True to mask (default), False to unmask. + """ + self.updateStencil(level, + self.getDataValues() > threshold, + mask) + + def updateNotFinite(self, level, mask=True): + """Mask/unmask all points whose values are not finite. + + :param int level: Mask level to update. + :param bool mask: True to mask (default), False to unmask. + """ + self.updateStencil(level, + numpy.logical_not(numpy.isfinite(self.getDataValues())), + mask) + + # Drawing operations: + def updateRectangle(self, level, row, col, height, width, mask=True): + """Mask/Unmask data inside a rectangle, with the given mask level. + + :param int level: Mask level to update, in range 1-255. + :param row: Starting row/y of the rectangle + :param col: Starting column/x of the rectangle + :param height: + :param width: + :param bool mask: True to mask (default), False to unmask. + """ + raise NotImplementedError("To be implemented in subclass") + + def updatePolygon(self, level, vertices, mask=True): + """Mask/Unmask data inside a polygon, with the given mask level. + + :param int level: Mask level to update. + :param vertices: Nx2 array of polygon corners as (row, col) / (y, x) + :param bool mask: True to mask (default), False to unmask. + """ + raise NotImplementedError("To be implemented in subclass") + + def updatePoints(self, level, rows, cols, mask=True): + """Mask/Unmask points with given coordinates. + + :param int level: Mask level to update. + :param rows: Rows/ordinates (y) of selected points + :type rows: 1D numpy.ndarray + :param cols: Columns/abscissa (x) of selected points + :type cols: 1D numpy.ndarray + :param bool mask: True to mask (default), False to unmask. + """ + raise NotImplementedError("To be implemented in subclass") + + def updateDisk(self, level, crow, ccol, radius, mask=True): + """Mask/Unmask data located inside a disk of the given mask level. + + :param int level: Mask level to update. + :param crow: Disk center row/ordinate (y). + :param ccol: Disk center column/abscissa. + :param float radius: Radius of the disk in mask array unit + :param bool mask: True to mask (default), False to unmask. + """ + raise NotImplementedError("To be implemented in subclass") + + def updateLine(self, level, row0, col0, row1, col1, width, mask=True): + """Mask/Unmask a line of the given mask level. + + :param int level: Mask level to update. + :param row0: Row/y of the starting point. + :param col0: Column/x of the starting point. + :param row1: Row/y of the end point. + :param col1: Column/x of the end point. + :param width: Width of the line in mask array unit. + :param bool mask: True to mask (default), False to unmask. + """ + raise NotImplementedError("To be implemented in subclass") + + +class BaseMaskToolsWidget(qt.QWidget): + """Base class for :class:`MaskToolsWidget` (image mask) and + :class:`scatterMaskToolsWidget`""" + + sigMaskChanged = qt.Signal() + _maxLevelNumber = 255 + + def __init__(self, parent=None, plot=None): + # register if the user as force a color for the corresponding mask level + self._defaultColors = numpy.ones((self._maxLevelNumber + 1), dtype=numpy.bool) + # overlays colors set by the user + self._overlayColors = numpy.zeros((self._maxLevelNumber + 1, 3), dtype=numpy.float32) + + self._plot = plot + self._maskName = '__MASK_TOOLS_%d' % id(self) # Legend of the mask + + self._colormap = { + 'name': None, + 'normalization': 'linear', + 'autoscale': False, + 'vmin': 0, 'vmax': self._maxLevelNumber, + 'colors': None} + self._defaultOverlayColor = rgba('gray') # Color of the mask + self._setMaskColors(1, 0.5) + + self._mask.sigChanged.connect(self._updatePlotMask) + self._mask.sigChanged.connect(self._emitSigMaskChanged) + + self._drawingMode = None # Store current drawing mode + self._lastPencilPos = None + self._multipleMasks = 'exclusive' + + super(BaseMaskToolsWidget, self).__init__(parent) + + self._maskFileDir = qt.QDir.home().absolutePath() + self.plot.sigInteractiveModeChanged.connect( + self._interactiveModeChanged) + + def _emitSigMaskChanged(self): + """Notify mask changes""" + self.sigMaskChanged.emit() + + def getSelectionMask(self, copy=True): + """Get the current mask as a numpy array. + + :param bool copy: True (default) to get a copy of the mask. + If False, the returned array MUST not be modified. + :return: The array of the mask with dimension of the 'active' plot item. + If there is no active image or scatter, an empty array is + returned. + :rtype: numpy.ndarray of uint8 + """ + return self._mask.getMask(copy=copy) + + def multipleMasks(self): + """Return the current mode of multiple masks support. + + See :meth:`setMultipleMasks` + """ + return self._multipleMasks + + def setMultipleMasks(self, mode): + """Set the mode of multiple masks support. + + Available modes: + + - 'single': Edit a single level of mask + - 'exclusive': Supports to 256 levels of non overlapping masks + + :param str mode: The mode to use + """ + assert mode in ('exclusive', 'single') + if mode != self._multipleMasks: + self._multipleMasks = mode + self.levelWidget.setVisible(self._multipleMasks != 'single') + self.clearAllBtn.setVisible(self._multipleMasks != 'single') + + @property + def maskFileDir(self): + """The directory from which to load/save mask from/to files.""" + if not os.path.isdir(self._maskFileDir): + self._maskFileDir = qt.QDir.home().absolutePath() + return self._maskFileDir + + @maskFileDir.setter + def maskFileDir(self, maskFileDir): + self._maskFileDir = str(maskFileDir) + + @property + def plot(self): + """The :class:`.PlotWindow` this widget is attached to.""" + return self._plot + + def setDirection(self, direction=qt.QBoxLayout.LeftToRight): + """Set the direction of the layout of the widget + + :param direction: QBoxLayout direction + """ + self.layout().setDirection(direction) + + def _initWidgets(self): + """Create widgets""" + layout = qt.QBoxLayout(qt.QBoxLayout.LeftToRight) + layout.addWidget(self._initMaskGroupBox()) + layout.addWidget(self._initDrawGroupBox()) + layout.addWidget(self._initThresholdGroupBox()) + layout.addStretch(1) + self.setLayout(layout) + + @staticmethod + def _hboxWidget(*widgets, **kwargs): + """Place widgets in widget with horizontal layout + + :param widgets: Widgets to position horizontally + :param bool stretch: True for trailing stretch (default), + False for no trailing stretch + :return: A QWidget with a QHBoxLayout + """ + stretch = kwargs.get('stretch', True) + + layout = qt.QHBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + for widget in widgets: + layout.addWidget(widget) + if stretch: + layout.addStretch(1) + widget = qt.QWidget() + widget.setLayout(layout) + return widget + + def _initTransparencyWidget(self): + """ Init the mask transparency widget """ + transparencyWidget = qt.QWidget(self) + grid = qt.QGridLayout() + grid.setContentsMargins(0, 0, 0, 0) + self.transparencySlider = qt.QSlider(qt.Qt.Horizontal, parent=transparencyWidget) + self.transparencySlider.setRange(3, 10) + self.transparencySlider.setValue(8) + self.transparencySlider.setToolTip( + 'Set the transparency of the mask display') + self.transparencySlider.valueChanged.connect(self._updateColors) + grid.addWidget(qt.QLabel('Display:', parent=transparencyWidget), 0, 0) + grid.addWidget(self.transparencySlider, 0, 1, 1, 3) + grid.addWidget(qt.QLabel('<small><b>Transparent</b></small>', parent=transparencyWidget), 1, 1) + grid.addWidget(qt.QLabel('<small><b>Opaque</b></small>', parent=transparencyWidget), 1, 3) + transparencyWidget.setLayout(grid) + return transparencyWidget + + def _initMaskGroupBox(self): + """Init general mask operation widgets""" + + # Mask level + self.levelSpinBox = qt.QSpinBox() + self.levelSpinBox.setRange(1, self._maxLevelNumber) + self.levelSpinBox.setToolTip( + 'Choose which mask level is edited.\n' + 'A mask can have up to 255 non-overlapping levels.') + self.levelSpinBox.valueChanged[int].connect(self._updateColors) + self.levelWidget = self._hboxWidget(qt.QLabel('Mask level:'), + self.levelSpinBox) + # Transparency + self.transparencyWidget = self._initTransparencyWidget() + + # Buttons group + invertBtn = qt.QPushButton('Invert') + invertBtn.setShortcut(qt.Qt.CTRL + qt.Qt.Key_I) + invertBtn.setToolTip('Invert current mask <b>%s</b>' % + invertBtn.shortcut().toString()) + invertBtn.clicked.connect(self._handleInvertMask) + + clearBtn = qt.QPushButton('Clear') + clearBtn.setShortcut(qt.QKeySequence.Delete) + clearBtn.setToolTip('Clear current mask level <b>%s</b>' % + clearBtn.shortcut().toString()) + clearBtn.clicked.connect(self._handleClearMask) + + invertClearWidget = self._hboxWidget( + invertBtn, clearBtn, stretch=False) + + undoBtn = qt.QPushButton('Undo') + undoBtn.setShortcut(qt.QKeySequence.Undo) + undoBtn.setToolTip('Undo last mask change <b>%s</b>' % + undoBtn.shortcut().toString()) + self._mask.sigUndoable.connect(undoBtn.setEnabled) + undoBtn.clicked.connect(self._mask.undo) + + redoBtn = qt.QPushButton('Redo') + redoBtn.setShortcut(qt.QKeySequence.Redo) + redoBtn.setToolTip('Redo last undone mask change <b>%s</b>' % + redoBtn.shortcut().toString()) + self._mask.sigRedoable.connect(redoBtn.setEnabled) + redoBtn.clicked.connect(self._mask.redo) + + undoRedoWidget = self._hboxWidget(undoBtn, redoBtn, stretch=False) + + self.clearAllBtn = qt.QPushButton('Clear all') + self.clearAllBtn.setToolTip('Clear all mask levels') + self.clearAllBtn.clicked.connect(self.resetSelectionMask) + + loadBtn = qt.QPushButton('Load...') + loadBtn.clicked.connect(self._loadMask) + + saveBtn = qt.QPushButton('Save...') + saveBtn.clicked.connect(self._saveMask) + + self.loadSaveWidget = self._hboxWidget(loadBtn, saveBtn, stretch=False) + + layout = qt.QVBoxLayout() + layout.addWidget(self.levelWidget) + layout.addWidget(self.transparencyWidget) + layout.addWidget(invertClearWidget) + layout.addWidget(undoRedoWidget) + layout.addWidget(self.clearAllBtn) + layout.addWidget(self.loadSaveWidget) + layout.addStretch(1) + + maskGroup = qt.QGroupBox('Mask') + maskGroup.setLayout(layout) + return maskGroup + + def _initDrawGroupBox(self): + """Init drawing tools widgets""" + layout = qt.QVBoxLayout() + + # Draw tools + self.browseAction = qt.QAction( + icons.getQIcon('normal'), 'Browse', None) + self.browseAction.setShortcut(qt.QKeySequence(qt.Qt.Key_B)) + self.browseAction.setToolTip( + 'Disables drawing tools, enables zooming interaction mode' + ' <b>B</b>') + self.browseAction.setCheckable(True) + self.browseAction.triggered.connect(self._activeBrowseMode) + self.addAction(self.browseAction) + + self.rectAction = qt.QAction( + icons.getQIcon('shape-rectangle'), 'Rectangle selection', None) + self.rectAction.setToolTip( + 'Rectangle selection tool: (Un)Mask a rectangular region <b>R</b>') + self.rectAction.setShortcut(qt.QKeySequence(qt.Qt.Key_R)) + self.rectAction.setCheckable(True) + self.rectAction.triggered.connect(self._activeRectMode) + self.addAction(self.rectAction) + + self.polygonAction = qt.QAction( + icons.getQIcon('shape-polygon'), 'Polygon selection', None) + self.polygonAction.setShortcut(qt.QKeySequence(qt.Qt.Key_S)) + self.polygonAction.setToolTip( + 'Polygon selection tool: (Un)Mask a polygonal region <b>S</b><br>' + 'Left-click to place polygon corners<br>' + 'Right-click to place the last corner') + self.polygonAction.setCheckable(True) + self.polygonAction.triggered.connect(self._activePolygonMode) + self.addAction(self.polygonAction) + + self.pencilAction = qt.QAction( + icons.getQIcon('draw-pencil'), 'Pencil tool', None) + self.pencilAction.setShortcut(qt.QKeySequence(qt.Qt.Key_P)) + self.pencilAction.setToolTip( + 'Pencil tool: (Un)Mask using a pencil <b>P</b>') + self.pencilAction.setCheckable(True) + self.pencilAction.triggered.connect(self._activePencilMode) + self.addAction(self.polygonAction) + + self.drawActionGroup = qt.QActionGroup(self) + self.drawActionGroup.setExclusive(True) + self.drawActionGroup.addAction(self.browseAction) + self.drawActionGroup.addAction(self.rectAction) + self.drawActionGroup.addAction(self.polygonAction) + self.drawActionGroup.addAction(self.pencilAction) + + self.browseAction.setChecked(True) + + self.drawButtons = {} + for action in self.drawActionGroup.actions(): + btn = qt.QToolButton() + btn.setDefaultAction(action) + self.drawButtons[action.text()] = btn + container = self._hboxWidget(*self.drawButtons.values()) + layout.addWidget(container) + + # Mask/Unmask radio buttons + maskRadioBtn = qt.QRadioButton('Mask') + maskRadioBtn.setToolTip( + 'Drawing masks with current level. Press <b>Ctrl</b> to unmask') + maskRadioBtn.setChecked(True) + + unmaskRadioBtn = qt.QRadioButton('Unmask') + unmaskRadioBtn.setToolTip( + 'Drawing unmasks with current level. Press <b>Ctrl</b> to mask') + + self.maskStateGroup = qt.QButtonGroup() + self.maskStateGroup.addButton(maskRadioBtn, 1) + self.maskStateGroup.addButton(unmaskRadioBtn, 0) + + self.maskStateWidget = self._hboxWidget(maskRadioBtn, unmaskRadioBtn) + layout.addWidget(self.maskStateWidget) + + # Connect mask state widget visibility with browse action + self.maskStateWidget.setHidden(self.browseAction.isChecked()) + self.browseAction.toggled[bool].connect( + self.maskStateWidget.setHidden) + + # Pencil settings + self.pencilSetting = self._createPencilSettings(None) + self.pencilSetting.setVisible(False) + layout.addWidget(self.pencilSetting) + + layout.addStretch(1) + + drawGroup = qt.QGroupBox('Draw tools') + drawGroup.setLayout(layout) + return drawGroup + + def _createPencilSettings(self, parent=None): + pencilSetting = qt.QWidget(parent) + + self.pencilSpinBox = qt.QSpinBox(parent=pencilSetting) + self.pencilSpinBox.setRange(1, 1024) + pencilToolTip = """Set pencil drawing tool size in pixels of the image + on which to make the mask.""" + self.pencilSpinBox.setToolTip(pencilToolTip) + + self.pencilSlider = qt.QSlider(qt.Qt.Horizontal, parent=pencilSetting) + self.pencilSlider.setRange(1, 50) + self.pencilSlider.setToolTip(pencilToolTip) + + pencilLabel = qt.QLabel('Pencil size:', parent=pencilSetting) + + layout = qt.QGridLayout() + layout.addWidget(pencilLabel, 0, 0) + layout.addWidget(self.pencilSpinBox, 0, 1) + layout.addWidget(self.pencilSlider, 1, 1) + pencilSetting.setLayout(layout) + + self.pencilSpinBox.valueChanged.connect(self._pencilWidthChanged) + self.pencilSlider.valueChanged.connect(self._pencilWidthChanged) + + return pencilSetting + + def _initThresholdGroupBox(self): + """Init thresholding widgets""" + layout = qt.QVBoxLayout() + + # Thresholing + + self.belowThresholdAction = qt.QAction( + icons.getQIcon('plot-roi-below'), 'Mask below threshold', None) + self.belowThresholdAction.setToolTip( + 'Mask image where values are below given threshold') + self.belowThresholdAction.setCheckable(True) + self.belowThresholdAction.triggered[bool].connect( + self._belowThresholdActionTriggered) + + self.betweenThresholdAction = qt.QAction( + icons.getQIcon('plot-roi-between'), 'Mask within range', None) + self.betweenThresholdAction.setToolTip( + 'Mask image where values are within given range') + self.betweenThresholdAction.setCheckable(True) + self.betweenThresholdAction.triggered[bool].connect( + self._betweenThresholdActionTriggered) + + self.aboveThresholdAction = qt.QAction( + icons.getQIcon('plot-roi-above'), 'Mask above threshold', None) + self.aboveThresholdAction.setToolTip( + 'Mask image where values are above given threshold') + self.aboveThresholdAction.setCheckable(True) + self.aboveThresholdAction.triggered[bool].connect( + self._aboveThresholdActionTriggered) + + self.thresholdActionGroup = qt.QActionGroup(self) + self.thresholdActionGroup.setExclusive(False) + self.thresholdActionGroup.addAction(self.belowThresholdAction) + self.thresholdActionGroup.addAction(self.betweenThresholdAction) + self.thresholdActionGroup.addAction(self.aboveThresholdAction) + self.thresholdActionGroup.triggered.connect( + self._thresholdActionGroupTriggered) + + self.loadColormapRangeAction = qt.QAction( + icons.getQIcon('view-refresh'), 'Set min-max from colormap', None) + self.loadColormapRangeAction.setToolTip( + 'Set min and max values from current colormap range') + self.loadColormapRangeAction.setCheckable(False) + self.loadColormapRangeAction.triggered.connect( + self._loadRangeFromColormapTriggered) + + widgets = [] + for action in self.thresholdActionGroup.actions(): + btn = qt.QToolButton() + btn.setDefaultAction(action) + widgets.append(btn) + + spacer = qt.QWidget() + spacer.setSizePolicy(qt.QSizePolicy.Expanding, + qt.QSizePolicy.Preferred) + widgets.append(spacer) + + loadColormapRangeBtn = qt.QToolButton() + loadColormapRangeBtn.setDefaultAction(self.loadColormapRangeAction) + widgets.append(loadColormapRangeBtn) + + container = self._hboxWidget(*widgets, stretch=False) + layout.addWidget(container) + + form = qt.QFormLayout() + + self.minLineEdit = qt.QLineEdit() + self.minLineEdit.setText('0') + self.minLineEdit.setValidator(qt.QDoubleValidator()) + self.minLineEdit.setEnabled(False) + form.addRow('Min:', self.minLineEdit) + + self.maxLineEdit = qt.QLineEdit() + self.maxLineEdit.setText('0') + self.maxLineEdit.setValidator(qt.QDoubleValidator()) + self.maxLineEdit.setEnabled(False) + form.addRow('Max:', self.maxLineEdit) + + self.applyMaskBtn = qt.QPushButton('Apply mask') + self.applyMaskBtn.clicked.connect(self._maskBtnClicked) + self.applyMaskBtn.setEnabled(False) + form.addRow(self.applyMaskBtn) + + self.maskNanBtn = qt.QPushButton('Mask not finite values') + self.maskNanBtn.setToolTip('Mask Not a Number and infinite values') + self.maskNanBtn.clicked.connect(self._maskNotFiniteBtnClicked) + form.addRow(self.maskNanBtn) + + thresholdWidget = qt.QWidget() + thresholdWidget.setLayout(form) + layout.addWidget(thresholdWidget) + + layout.addStretch(1) + + self.thresholdGroup = qt.QGroupBox('Threshold') + self.thresholdGroup.setLayout(layout) + return self.thresholdGroup + + # track widget visibility and plot active image changes + + def changeEvent(self, event): + """Reset drawing action when disabling widget""" + if (event.type() == qt.QEvent.EnabledChange and + not self.isEnabled() and + not self.browseAction.isChecked()): + self.browseAction.trigger() # Disable drawing tool + + def save(self, filename, kind): + """Save current mask in a file + + :param str filename: The file where to save to mask + :param str kind: The kind of file to save in 'edf', 'tif', 'npy' + :raise Exception: Raised if the process fails + """ + self._mask.save(filename, kind) + + def getCurrentMaskColor(self): + """Returns the color of the current selected level. + + :rtype: A tuple or a python array + """ + currentLevel = self.levelSpinBox.value() + if self._defaultColors[currentLevel]: + return self._defaultOverlayColor + else: + return self._overlayColors[currentLevel].tolist() + + def _setMaskColors(self, level, alpha): + """Set-up the mask colormap to highlight current mask level. + + :param int level: The mask level to highlight + :param float alpha: Alpha level of mask in [0., 1.] + """ + assert 0 < level <= self._maxLevelNumber + + colors = numpy.empty((self._maxLevelNumber + 1, 4), dtype=numpy.float32) + + # Set color + colors[:, :3] = self._defaultOverlayColor[:3] + + # check if some colors has been directly set by the user + mask = numpy.equal(self._defaultColors, False) + colors[mask, :3] = self._overlayColors[mask, :3] + + # Set alpha + colors[:, -1] = alpha / 2. + + # Set highlighted level color + colors[level, 3] = alpha + + # Set no mask level + colors[0] = (0., 0., 0., 0.) + + self._colormap['colors'] = colors + + def resetMaskColors(self, level=None): + """Reset the mask color at the given level to be defaultColors + + :param level: + The index of the mask for which we want to reset the color. + If none we will reset color for all masks. + """ + if level is None: + self._defaultColors[level] = True + else: + self._defaultColors[:] = True + + self._updateColors() + + def setMaskColors(self, rgb, level=None): + """Set the masks color + + :param rgb: The rgb color + :param level: + The index of the mask for which we want to change the color. + If none set this color for all the masks + """ + if level is None: + self._overlayColors[:] = rgb + self._defaultColors[:] = False + else: + self._overlayColors[level] = rgb + self._defaultColors[level] = False + + self._updateColors() + + def getMaskColors(self): + """masks colors getter""" + return self._overlayColors + + def _updateColors(self, *args): + """Rebuild mask colormap when selected level or transparency change""" + self._setMaskColors(self.levelSpinBox.value(), + self.transparencySlider.value() / + self.transparencySlider.maximum()) + self._updatePlotMask() + self._updateInteractiveMode() + + def _pencilWidthChanged(self, width): + + old = self.pencilSpinBox.blockSignals(True) + try: + self.pencilSpinBox.setValue(width) + finally: + self.pencilSpinBox.blockSignals(old) + + old = self.pencilSlider.blockSignals(True) + try: + self.pencilSlider.setValue(width) + finally: + self.pencilSlider.blockSignals(old) + self._updateInteractiveMode() + + def _updateInteractiveMode(self): + """Update the current mode to the same if some cached data have to be + updated. It is the case for the color for example. + """ + if self._drawingMode == 'rectangle': + self._activeRectMode() + elif self._drawingMode == 'polygon': + self._activePolygonMode() + elif self._drawingMode == 'pencil': + self._activePencilMode() + + def _handleClearMask(self): + """Handle clear button clicked: reset current level mask""" + self._mask.clear(self.levelSpinBox.value()) + self._mask.commit() + + def _handleInvertMask(self): + """Invert the current mask level selection.""" + self._mask.invert(self.levelSpinBox.value()) + self._mask.commit() + + # Handle drawing tools UI events + + def _interactiveModeChanged(self, source): + """Handle plot interactive mode changed: + + If changed from elsewhere, disable drawing tool + """ + if source is not self: + # Do not trigger browseAction to avoid to call + # self.plot.setInteractiveMode + self.browseAction.setChecked(True) + self._releaseDrawingMode() + + def _releaseDrawingMode(self): + """Release the drawing mode if is was used""" + if self._drawingMode is None: + return + self.plot.sigPlotSignal.disconnect(self._plotDrawEvent) + self._drawingMode = None + + def _activeBrowseMode(self): + """Handle browse action mode triggered by user. + + Set plot interactive mode only when + the user is triggering the browse action. + """ + self._releaseDrawingMode() + self.plot.setInteractiveMode('zoom', source=self) + self._updateDrawingModeWidgets() + + def _activeRectMode(self): + """Handle rect action mode triggering""" + self._releaseDrawingMode() + self._drawingMode = 'rectangle' + self.plot.sigPlotSignal.connect(self._plotDrawEvent) + color = self.getCurrentMaskColor() + self.plot.setInteractiveMode( + 'draw', shape='rectangle', source=self, color=color) + self._updateDrawingModeWidgets() + + def _activePolygonMode(self): + """Handle polygon action mode triggering""" + self._releaseDrawingMode() + self._drawingMode = 'polygon' + self.plot.sigPlotSignal.connect(self._plotDrawEvent) + color = self.getCurrentMaskColor() + self.plot.setInteractiveMode('draw', shape='polygon', source=self, color=color) + self._updateDrawingModeWidgets() + + def _activePencilMode(self): + """Handle pencil action mode triggering""" + self._releaseDrawingMode() + self._drawingMode = 'pencil' + self.plot.sigPlotSignal.connect(self._plotDrawEvent) + color = self.getCurrentMaskColor() + width = self.pencilSpinBox.value() + self.plot.setInteractiveMode( + 'draw', shape='pencil', source=self, color=color, width=width) + self._updateDrawingModeWidgets() + + def _updateDrawingModeWidgets(self): + self.pencilSetting.setVisible(self._drawingMode == 'pencil') + + # Handle plot drawing events + + def _isMasking(self): + """Returns true if the tool is used for masking, else it is used for + unmasking. + + :rtype: bool""" + # First draw event, use current modifiers for all draw sequence + doMask = (self.maskStateGroup.checkedId() == 1) + if qt.QApplication.keyboardModifiers() & qt.Qt.ControlModifier: + doMask = not doMask + return doMask + + # Handle threshold UI events + def _belowThresholdActionTriggered(self, triggered): + if triggered: + self.minLineEdit.setEnabled(True) + self.maxLineEdit.setEnabled(False) + self.applyMaskBtn.setEnabled(True) + + def _betweenThresholdActionTriggered(self, triggered): + if triggered: + self.minLineEdit.setEnabled(True) + self.maxLineEdit.setEnabled(True) + self.applyMaskBtn.setEnabled(True) + + def _aboveThresholdActionTriggered(self, triggered): + if triggered: + self.minLineEdit.setEnabled(False) + self.maxLineEdit.setEnabled(True) + self.applyMaskBtn.setEnabled(True) + + def _thresholdActionGroupTriggered(self, triggeredAction): + """Threshold action group listener.""" + if triggeredAction.isChecked(): + # Uncheck other actions + for action in self.thresholdActionGroup.actions(): + if action is not triggeredAction and action.isChecked(): + action.setChecked(False) + else: + # Disable min/max edit + self.minLineEdit.setEnabled(False) + self.maxLineEdit.setEnabled(False) + self.applyMaskBtn.setEnabled(False) + + def _maskBtnClicked(self): + if self.belowThresholdAction.isChecked(): + if self.minLineEdit.text(): + self._mask.updateBelowThreshold(self.levelSpinBox.value(), + float(self.minLineEdit.text())) + self._mask.commit() + + elif self.betweenThresholdAction.isChecked(): + if self.minLineEdit.text() and self.maxLineEdit.text(): + min_ = float(self.minLineEdit.text()) + max_ = float(self.maxLineEdit.text()) + self._mask.updateBetweenThresholds(self.levelSpinBox.value(), + min_, max_) + self._mask.commit() + + elif self.aboveThresholdAction.isChecked(): + if self.maxLineEdit.text(): + max_ = float(self.maxLineEdit.text()) + self._mask.updateAboveThreshold(self.levelSpinBox.value(), + max_) + self._mask.commit() + + def _maskNotFiniteBtnClicked(self): + """Handle not finite mask button clicked: mask NaNs and inf""" + self._mask.updateNotFinite( + self.levelSpinBox.value()) + self._mask.commit() + + +class BaseMaskToolsDockWidget(qt.QDockWidget): + """Base class for :class:`MaskToolsWidget` and + :class:`ScatterMaskToolsWidget` + + For integration in a :class:`PlotWindow`. + + :param parent: See :class:`QDockWidget` + :paran str name: The title of this widget + """ + + sigMaskChanged = qt.Signal() + + def __init__(self, parent=None, name='Mask'): + super(BaseMaskToolsDockWidget, self).__init__(parent) + self.setWindowTitle(name) + + self.layout().setContentsMargins(0, 0, 0, 0) + self.dockLocationChanged.connect(self._dockLocationChanged) + self.topLevelChanged.connect(self._topLevelChanged) + + def _emitSigMaskChanged(self): + """Notify mask changes""" + # must be connected to self.widget().sigMaskChanged in child class + self.sigMaskChanged.emit() + + def getSelectionMask(self, copy=True): + """Get the current mask as a 2D array. + + :param bool copy: True (default) to get a copy of the mask. + If False, the returned array MUST not be modified. + :return: The array of the mask with dimension of the 'active' image. + If there is no active image, an empty array is returned. + :rtype: 2D numpy.ndarray of uint8 + """ + return self.widget().getSelectionMask(copy=copy) + + def setSelectionMask(self, mask, copy=True): + """Set the mask to a new array. + + :param numpy.ndarray mask: The array to use for the mask. + :type mask: numpy.ndarray of uint8 of dimension 2, C-contiguous. + Array of other types are converted. + :param bool copy: True (the default) to copy the array, + False to use it as is if possible. + :return: None if failed, shape of mask as 2-tuple if successful. + The mask can be cropped or padded to fit active image, + the returned shape is that of the active image. + """ + return self.widget().setSelectionMask(mask, copy=copy) + + def toggleViewAction(self): + """Returns a checkable action that shows or closes this widget. + + See :class:`QMainWindow`. + """ + action = super(BaseMaskToolsDockWidget, self).toggleViewAction() + action.setIcon(icons.getQIcon('image-mask')) + action.setToolTip("Display/hide mask tools") + return action + + def _dockLocationChanged(self, area): + if area in (qt.Qt.LeftDockWidgetArea, qt.Qt.RightDockWidgetArea): + direction = qt.QBoxLayout.TopToBottom + else: + direction = qt.QBoxLayout.LeftToRight + self.widget().setDirection(direction) + + def _topLevelChanged(self, topLevel): + if topLevel: + self.widget().setDirection(qt.QBoxLayout.LeftToRight) + self.resize(self.widget().minimumSize()) + self.adjustSize() + + def showEvent(self, event): + """Make sure this widget is raised when it is shown + (when it is first created as a tab in PlotWindow or when it is shown + again after hiding). + """ + self.raise_() diff --git a/silx/gui/plot/__init__.py b/silx/gui/plot/__init__.py new file mode 100644 index 0000000..06a24a7 --- /dev/null +++ b/silx/gui/plot/__init__.py @@ -0,0 +1,71 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This package provides a set of Qt widgets for plotting curves and images. + +The plotting API is inherited from the `PyMca <http://pymca.sourceforge.net/>`_ +plot API and is mostly compatible with it. + +Those widgets supports interaction (e.g., zoom, pan, selections). + +List of Qt widgets: + +.. currentmodule:: silx.gui.plot + +- :mod:`.PlotWidget`: A widget displaying a single plot. +- :mod:`.PlotWindow`: A :mod:`.PlotWidget` with a configurable set of tools. +- :class:`.Plot1D`: A widget with tools for curves. +- :class:`.Plot2D`: A widget with tools for images. +- :class:`.ImageView`: A widget with tools for images and a side histogram. +- :class:`.StackView`: A widget with tools for a stack of images. + +By default, those widget are using matplotlib_. +They can optionally use a faster OpenGL-based rendering (beta feature), +which is enabled by setting the ``backend`` argument to ``'gl'`` +when creating the widgets (See :class:`.Plot`). + +.. note:: + + This package depends on matplotlib_. + The OpenGL backend further depends on + `PyOpenGL <http://pyopengl.sourceforge.net/>`_ and OpenGL >= 2.1. + +.. _matplotlib: http://matplotlib.org/ +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "22/02/2016" + + +# First of all init matplotlib and set its backend +from .backends import _matplotlib # noqa + +from .PlotWidget import PlotWidget # noqa +from .PlotWindow import PlotWindow, Plot1D, Plot2D # noqa +from .ImageView import ImageView # noqa +from .StackView import StackView # noqa + +__all__ = ['ImageView', 'PlotWidget', 'PlotWindow', 'Plot1D', 'Plot2D', + 'StackView'] diff --git a/silx/gui/plot/_utils/__init__.py b/silx/gui/plot/_utils/__init__.py new file mode 100644 index 0000000..355bc02 --- /dev/null +++ b/silx/gui/plot/_utils/__init__.py @@ -0,0 +1,104 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Miscellaneous utility functions for the Plot""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "21/03/2017" + + +import numpy + +from .panzoom import FLOAT32_SAFE_MIN, FLOAT32_MINPOS, FLOAT32_SAFE_MAX +from .panzoom import applyZoomToPlot, applyPan + + +def clipColormapLogRange(colormap): + """Clip colormap vmin and vmax to 1, 10 if normalization is 'log' and vmin + or vmax <1 + + :param dict colormap: the colormap for which we want to clip vmin and vmax + """ + if colormap['normalization'] is 'log': + if colormap['vmin'] < 1. or colormap['vmax'] < 1.: + colormap['vmin'], colormap['vmax'] = 1., 10. + + +def addMarginsToLimits(margins, isXLog, isYLog, + xMin, xMax, yMin, yMax, y2Min=None, y2Max=None): + """Returns updated limits by extending them with margins. + + :param margins: The ratio of the margins to add or None for no margins. + :type margins: A 4-tuple of floats as + (xMinMargin, xMaxMargin, yMinMargin, yMaxMargin) + + :return: The updated limits + :rtype: tuple of 4 or 6 floats: Either (xMin, xMax, yMin, yMax) or + (xMin, xMax, yMin, yMax, y2Min, y2Max) if y2Min and y2Max + are provided. + """ + if margins is not None: + xMinMargin, xMaxMargin, yMinMargin, yMaxMargin = margins + + if not isXLog: + xRange = xMax - xMin + xMin -= xMinMargin * xRange + xMax += xMaxMargin * xRange + + elif xMin > 0. and xMax > 0.: # Log scale + # Do not apply margins if limits < 0 + xMinLog, xMaxLog = numpy.log10(xMin), numpy.log10(xMax) + xRangeLog = xMaxLog - xMinLog + xMin = pow(10., xMinLog - xMinMargin * xRangeLog) + xMax = pow(10., xMaxLog + xMaxMargin * xRangeLog) + + if not isYLog: + yRange = yMax - yMin + yMin -= yMinMargin * yRange + yMax += yMaxMargin * yRange + elif yMin > 0. and yMax > 0.: # Log scale + # Do not apply margins if limits < 0 + yMinLog, yMaxLog = numpy.log10(yMin), numpy.log10(yMax) + yRangeLog = yMaxLog - yMinLog + yMin = pow(10., yMinLog - yMinMargin * yRangeLog) + yMax = pow(10., yMaxLog + yMaxMargin * yRangeLog) + + if y2Min is not None and y2Max is not None: + if not isYLog: + yRange = y2Max - y2Min + y2Min -= yMinMargin * yRange + y2Max += yMaxMargin * yRange + elif y2Min > 0. and y2Max > 0.: # Log scale + # Do not apply margins if limits < 0 + yMinLog, yMaxLog = numpy.log10(y2Min), numpy.log10(y2Max) + yRangeLog = yMaxLog - yMinLog + y2Min = pow(10., yMinLog - yMinMargin * yRangeLog) + y2Max = pow(10., yMaxLog + yMaxMargin * yRangeLog) + + if y2Min is None or y2Max is None: + return xMin, xMax, yMin, yMax + else: + return xMin, xMax, yMin, yMax, y2Min, y2Max + diff --git a/silx/gui/plot/_utils/panzoom.py b/silx/gui/plot/_utils/panzoom.py new file mode 100644 index 0000000..bec31df --- /dev/null +++ b/silx/gui/plot/_utils/panzoom.py @@ -0,0 +1,156 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Functions to apply pan and zoom on a Plot""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "21/03/2017" + + +import math + +import numpy + + +# Float 32 info ############################################################### +# Using min/max value below limits of float32 +# so operation with such value (e.g., max - min) do not overflow + +FLOAT32_SAFE_MIN = -1e37 +FLOAT32_MINPOS = numpy.finfo(numpy.float32).tiny +FLOAT32_SAFE_MAX = 1e37 +# TODO double support + + +def scale1DRange(min_, max_, center, scale, isLog): + """Scale a 1D range given a scale factor and an center point. + + Keeps the values in a smaller range than float32. + + :param float min_: The current min value of the range. + :param float max_: The current max value of the range. + :param float center: The center of the zoom (i.e., invariant point). + :param float scale: The scale to use for zoom + :param bool isLog: Whether using log scale or not. + :return: The zoomed range. + :rtype: tuple of 2 floats: (min, max) + """ + if isLog: + # Min and center can be < 0 when + # autoscale is off and switch to log scale + # max_ < 0 should not happen + min_ = numpy.log10(min_) if min_ > 0. else FLOAT32_MINPOS + center = numpy.log10(center) if center > 0. else FLOAT32_MINPOS + max_ = numpy.log10(max_) if max_ > 0. else FLOAT32_MINPOS + + if min_ == max_: + return min_, max_ + + offset = (center - min_) / (max_ - min_) + range_ = (max_ - min_) / scale + newMin = center - offset * range_ + newMax = center + (1. - offset) * range_ + + if isLog: + # No overflow as exponent is log10 of a float32 + newMin = pow(10., newMin) + newMax = pow(10., newMax) + newMin = numpy.clip(newMin, FLOAT32_MINPOS, FLOAT32_SAFE_MAX) + newMax = numpy.clip(newMax, FLOAT32_MINPOS, FLOAT32_SAFE_MAX) + else: + newMin = numpy.clip(newMin, FLOAT32_SAFE_MIN, FLOAT32_SAFE_MAX) + newMax = numpy.clip(newMax, FLOAT32_SAFE_MIN, FLOAT32_SAFE_MAX) + return newMin, newMax + + +def applyZoomToPlot(plot, scaleF, center=None): + """Zoom in/out plot given a scale and a center point. + + :param plot: The plot on which to apply zoom. + :param float scaleF: Scale factor of zoom. + :param center: (x, y) coords in pixel coordinates of the zoom center. + :type center: 2-tuple of float + """ + xMin, xMax = plot.getGraphXLimits() + yMin, yMax = plot.getGraphYLimits() + + if center is None: + left, top, width, height = plot.getPlotBoundsInPixels() + cx, cy = left + width // 2, top + height // 2 + else: + cx, cy = center + + dataCenterPos = plot.pixelToData(cx, cy) + assert dataCenterPos is not None + + xMin, xMax = scale1DRange(xMin, xMax, dataCenterPos[0], scaleF, + plot.isXAxisLogarithmic()) + + yMin, yMax = scale1DRange(yMin, yMax, dataCenterPos[1], scaleF, + plot.isYAxisLogarithmic()) + + dataPos = plot.pixelToData(cx, cy, axis="right") + assert dataPos is not None + y2Center = dataPos[1] + y2Min, y2Max = plot.getGraphYLimits(axis="right") + y2Min, y2Max = scale1DRange(y2Min, y2Max, y2Center, scaleF, + plot.isYAxisLogarithmic()) + + plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max) + + +def applyPan(min_, max_, panFactor, isLog10): + """Returns a new range with applied panning. + + Moves the range according to panFactor. + If isLog10 is True, converts to log10 before moving. + + :param float min_: Min value of the data range to pan. + :param float max_: Max value of the data range to pan. + Must be >= min. + :param float panFactor: Signed proportion of the range to use for pan. + :param bool isLog10: True if log10 scale, False if linear scale. + :return: New min and max value with pan applied. + :rtype: 2-tuple of float. + """ + if isLog10 and min_ > 0.: + # Negative range and log scale can happen with matplotlib + logMin, logMax = math.log10(min_), math.log10(max_) + logOffset = panFactor * (logMax - logMin) + newMin = pow(10., logMin + logOffset) + newMax = pow(10., logMax + logOffset) + + # Takes care of out-of-range values + if newMin > 0. and newMax < float('inf'): + min_, max_ = newMin, newMax + + else: + offset = panFactor * (max_ - min_) + newMin, newMax = min_ + offset, max_ + offset + + # Takes care of out-of-range values + if newMin > - float('inf') and newMax < float('inf'): + min_, max_ = newMin, newMax + return min_, max_ diff --git a/silx/gui/plot/_utils/setup.py b/silx/gui/plot/_utils/setup.py new file mode 100644 index 0000000..0271745 --- /dev/null +++ b/silx/gui/plot/_utils/setup.py @@ -0,0 +1,42 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "21/03/2017" + + +from numpy.distutils.misc_util import Configuration + + +def configuration(parent_package='', top_path=None): + config = Configuration('_utils', parent_package, top_path) + config.add_subpackage('test') + return config + + +if __name__ == "__main__": + from numpy.distutils.core import setup + + setup(configuration=configuration) diff --git a/silx/gui/plot/_utils/test/__init__.py b/silx/gui/plot/_utils/test/__init__.py new file mode 100644 index 0000000..4a443ac --- /dev/null +++ b/silx/gui/plot/_utils/test/__init__.py @@ -0,0 +1,41 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ + +from __future__ import absolute_import, division, unicode_literals + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "18/10/2016" + + +import unittest + +from .test_ticklayout import suite as test_ticklayout_suite + + +def suite(): + testsuite = unittest.TestSuite() + testsuite.addTest(test_ticklayout_suite()) + return testsuite diff --git a/silx/gui/plot/_utils/test/test_ticklayout.py b/silx/gui/plot/_utils/test/test_ticklayout.py new file mode 100644 index 0000000..8c67620 --- /dev/null +++ b/silx/gui/plot/_utils/test/test_ticklayout.py @@ -0,0 +1,78 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ + +from __future__ import absolute_import, division, unicode_literals + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "18/10/2016" + + +import unittest + +from silx.test.utils import ParametricTestCase + +from silx.gui.plot._utils import ticklayout + + +class TestTickLayout(ParametricTestCase): + """Test ticks layout algorithms""" + + def testNiceNumbers(self): + """Minimalistic tests of :func:`niceNumbers`""" + tests = { # (vmin, vmax): ref_ticks + (0.5, 10.5): (0.0, 12.0, 2.0, 0), + (10000., 10000.5): (10000.0, 10000.5, 0.1, 1), + (0.001, 0.005): (0.001, 0.005, 0.001, 3) + } + + for (vmin, vmax), ref_ticks in tests.items(): + with self.subTest(vmin=vmin, vmax=vmax): + ticks = ticklayout.niceNumbers(vmin, vmax) + self.assertEqual(ticks, ref_ticks) + + def testNiceNumbersLog(self): + """Minimalistic tests of :func:`niceNumbersForLog10`""" + tests = { # (log10(min), log10(max): ref_ticks + (0., 3.): (0, 3, 1, 0), + (-3., 3): (-3, 3, 1, 0), + (-32., 0.): (-36, 0, 6, 0) + } + + for (vmin, vmax), ref_ticks in tests.items(): + with self.subTest(vmin=vmin, vmax=vmax): + ticks = ticklayout.niceNumbersForLog10(vmin, vmax) + self.assertEqual(ticks, ref_ticks) + + +def suite(): + testsuite = unittest.TestSuite() + testsuite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestTickLayout)) + return testsuite + + +if __name__ == '__main__': + unittest.main() diff --git a/silx/gui/plot/_utils/ticklayout.py b/silx/gui/plot/_utils/ticklayout.py new file mode 100644 index 0000000..5f4b636 --- /dev/null +++ b/silx/gui/plot/_utils/ticklayout.py @@ -0,0 +1,224 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module implements labels layout on graph axes.""" + +from __future__ import absolute_import, division, unicode_literals + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "18/10/2016" + + +import math + + +# utils ####################################################################### + +def numberOfDigits(tickSpacing): + """Returns the number of digits to display for text label. + + :param float tickSpacing: Step between ticks in data space. + :return: Number of digits to show for labels. + :rtype: int + """ + nfrac = int(-math.floor(math.log10(tickSpacing))) + if nfrac < 0: + nfrac = 0 + return nfrac + + +# Nice Numbers ################################################################ + +def _niceNum(value, isRound=False): + expvalue = math.floor(math.log10(value)) + frac = value/pow(10., expvalue) + if isRound: + if frac < 1.5: + nicefrac = 1. + elif frac < 3.: + nicefrac = 2. + elif frac < 7.: + nicefrac = 5. + else: + nicefrac = 10. + else: + if frac <= 1.: + nicefrac = 1. + elif frac <= 2.: + nicefrac = 2. + elif frac <= 5.: + nicefrac = 5. + else: + nicefrac = 10. + return nicefrac * pow(10., expvalue) + + +def niceNumbers(vMin, vMax, nTicks=5): + """Returns tick positions. + + This function implements graph labels layout using nice numbers + by Paul Heckbert from "Graphics Gems", Academic Press, 1990. + See `C code <http://tog.acm.org/resources/GraphicsGems/gems/Label.c>`_. + + :param float vMin: The min value on the axis + :param float vMax: The max value on the axis + :param int nTicks: The number of ticks to position + :returns: min, max, increment value of tick positions and + number of fractional digit to show + :rtype: tuple + """ + vrange = _niceNum(vMax - vMin, False) + spacing = _niceNum(vrange / nTicks, True) + graphmin = math.floor(vMin / spacing) * spacing + graphmax = math.ceil(vMax / spacing) * spacing + nfrac = numberOfDigits(spacing) + return graphmin, graphmax, spacing, nfrac + + +def _frange(start, stop, step): + """range for float (including stop).""" + assert step >= 0. + while start <= stop: + yield start + start += step + + +def ticks(vMin, vMax, nbTicks=5): + """Returns tick positions and labels using nice numbers algorithm. + + This enforces ticks to be within [vMin, vMax] range. + It returns at least 2 ticks. + + :param float vMin: The min value on the axis + :param float vMax: The max value on the axis + :param int nbTicks: The number of ticks to position + :returns: tick positions and corresponding text labels + :rtype: 2-tuple: list of float, list of string + """ + start, end, step, nfrac = niceNumbers(vMin, vMax, nbTicks) + positions = [t for t in _frange(start, end, step) if vMin <= t <= vMax] + + # Makes sure there is at least 2 ticks + if len(positions) < 2: + positions = [vMin, vMax] + nfrac = numberOfDigits(vMax - vMin) + + # Generate labels + format_ = '%g' if nfrac == 0 else '%.{}f'.format(nfrac) + labels = [format_ % tick for tick in positions] + return positions, labels + + +def niceNumbersAdaptative(vMin, vMax, axisLength, tickDensity): + """Returns tick positions using :func:`niceNumbers` and a + density of ticks. + + axisLength and tickDensity are based on the same unit (e.g., pixel). + + :param float vMin: The min value on the axis + :param float vMax: The max value on the axis + :param float axisLength: The length of the axis. + :param float tickDensity: The density of ticks along the axis. + :returns: min, max, increment value of tick positions and + number of fractional digit to show + :rtype: tuple + """ + # At least 2 ticks + nticks = max(2, int(round(tickDensity * axisLength))) + tickmin, tickmax, step, nfrac = niceNumbers(vMin, vMax, nticks) + + return tickmin, tickmax, step, nfrac + + +# Nice Numbers for log scale ################################################## + +def niceNumbersForLog10(minLog, maxLog, nTicks=5): + """Return tick positions for logarithmic scale + + :param float minLog: log10 of the min value on the axis + :param float maxLog: log10 of the max value on the axis + :param int nTicks: The number of ticks to position + :returns: log10 of min, max, increment value of tick positions and + number of fractional digit to show + :rtype: tuple of int + """ + graphminlog = math.floor(minLog) + graphmaxlog = math.ceil(maxLog) + rangelog = graphmaxlog - graphminlog + + if rangelog <= nTicks: + spacing = 1. + else: + spacing = math.floor(rangelog / nTicks) + + graphminlog = math.floor(graphminlog / spacing) * spacing + graphmaxlog = math.ceil(graphmaxlog / spacing) * spacing + + nfrac = numberOfDigits(spacing) + + return int(graphminlog), int(graphmaxlog), int(spacing), nfrac + + +def niceNumbersAdaptativeForLog10(vMin, vMax, axisLength, tickDensity): + """Returns tick positions using :func:`niceNumbers` and a + density of ticks. + + axisLength and tickDensity are based on the same unit (e.g., pixel). + + :param float vMin: The min value on the axis + :param float vMax: The max value on the axis + :param float axisLength: The length of the axis. + :param float tickDensity: The density of ticks along the axis. + :returns: log10 of min, max, increment value of tick positions and + number of fractional digit to show + :rtype: tuple + """ + # At least 2 ticks + nticks = max(2, int(round(tickDensity * axisLength))) + tickmin, tickmax, step, nfrac = niceNumbersForLog10(vMin, vMax, nticks) + + return tickmin, tickmax, step, nfrac + + +def computeLogSubTicks(ticks, lowBound, highBound): + """Return the sub ticks for the log scale for all given ticks if subtick + is in [lowBound, highBound] + + :param ticks: log10 of the ticks + :param lowBound: the lower boundary of ticks + :param highBound: the higher boundary of ticks + :return: all the sub ticks contained in ticks (log10) + """ + if len(ticks) < 1: + return [] + + res = [] + for logPos in ticks: + dataOrigPos = logPos + for index in range(2, 10): + dataPos = dataOrigPos * index + if lowBound <= dataPos <= highBound: + res.append(dataPos) + return res diff --git a/silx/gui/plot/backends/BackendBase.py b/silx/gui/plot/backends/BackendBase.py new file mode 100644 index 0000000..74f96af --- /dev/null +++ b/silx/gui/plot/backends/BackendBase.py @@ -0,0 +1,474 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +"""Base class for Plot backends. + +It documents the Plot backend API. + +This API is a simplified version of PyMca PlotBackend API. +""" + +__authors__ = ["V.A. Sole", "T. Vincent"] +__license__ = "MIT" +__date__ = "18/02/2016" + + +import weakref + + +# Names for setCursor +CURSOR_DEFAULT = 'default' +CURSOR_POINTING = 'pointing' +CURSOR_SIZE_HOR = 'size horizontal' +CURSOR_SIZE_VER = 'size vertical' +CURSOR_SIZE_ALL = 'size all' + + +class BackendBase(object): + """Class defining the API a backend of the Plot should provide.""" + + def __init__(self, plot, parent=None): + """Init. + + :param Plot plot: The Plot this backend is attached to + :param parent: The parent widget of the plot widget. + """ + self.__xLimits = 1., 100. + self.__yLimits = {'left': (1., 100.), 'right': (1., 100.)} + self.__yAxisInverted = False + self.__keepDataAspectRatio = False + # Store a weakref to get access to the plot state. + self._setPlot(plot) + + @property + def _plot(self): + """The plot this backend is attached to.""" + if self._plotRef is None: + raise RuntimeError('This backend is not attached to a Plot') + + plot = self._plotRef() + if plot is None: + raise RuntimeError('This backend is no more attached to a Plot') + return plot + + def _setPlot(self, plot): + """Allow to set plot after init. + + Use with caution, basically **immediately** after init. + """ + self._plotRef = weakref.ref(plot) + + # Add methods + + def addCurve(self, x, y, legend, + color, symbol, linewidth, linestyle, + yaxis, + xerror, yerror, z, selectable, + fill, alpha, symbolsize): + """Add a 1D curve given by x an y to the graph. + + :param numpy.ndarray x: The data corresponding to the x axis + :param numpy.ndarray y: The data corresponding to the y axis + :param str legend: The legend to be associated to the curve + :param color: color(s) to be used + :type color: string ("#RRGGBB") or (npoints, 4) unsigned byte array or + one of the predefined color names defined in Colors.py + :param str symbol: Symbol to be drawn at each (x, y) position:: + + - ' ' or '' no symbol + - 'o' circle + - '.' point + - ',' pixel + - '+' cross + - 'x' x-cross + - 'd' diamond + - 's' square + + :param float linewidth: The width of the curve in pixels + :param str linestyle: Type of line:: + + - ' ' or '' no line + - '-' solid line + - '--' dashed line + - '-.' dash-dot line + - ':' dotted line + + :param str yaxis: The Y axis this curve belongs to in: 'left', 'right' + :param xerror: Values with the uncertainties on the x values + :type xerror: numpy.ndarray or None + :param yerror: Values with the uncertainties on the y values + :type yerror: numpy.ndarray or None + :param int z: Layer on which to draw the cuve + :param bool selectable: indicate if the curve can be selected + :param bool fill: True to fill the curve, False otherwise + :param float alpha: Curve opacity, as a float in [0., 1.] + :param float symbolsize: Size of the symbol (if any) drawn + at each (x, y) position. + :returns: The handle used by the backend to univocally access the curve + """ + return legend + + def addImage(self, data, legend, + origin, scale, z, + selectable, draggable, + colormap, alpha): + """Add an image to the plot. + + :param numpy.ndarray data: (nrows, ncolumns) data or + (nrows, ncolumns, RGBA) ubyte array + :param str legend: The legend to be associated to the image + :param origin: (origin X, origin Y) of the data. + Default: (0., 0.) + :type origin: 2-tuple of float + :param scale: (scale X, scale Y) of the data. + Default: (1., 1.) + :type scale: 2-tuple of float + :param int z: Layer on which to draw the image + :param bool selectable: indicate if the image can be selected + :param bool draggable: indicate if the image can be moved + :param colormap: Dictionary describing the colormap to use. + Ignored if data is RGB(A). + :type colormap: dict or None + :param float alpha: Opacity of the image, as a float in range [0, 1]. + :returns: The handle used by the backend to univocally access the image + """ + return legend + + def addItem(self, x, y, legend, shape, color, fill, overlay, z): + """Add an item (i.e. a shape) to the plot. + + :param numpy.ndarray x: The X coords of the points of the shape + :param numpy.ndarray y: The Y coords of the points of the shape + :param str legend: The legend to be associated to the item + :param str shape: Type of item to be drawn in + hline, polygon, rectangle, vline, polylines + :param str color: Color of the item + :param bool fill: True to fill the shape + :param bool overlay: True if item is an overlay, False otherwise + :param int z: Layer on which to draw the item + :returns: The handle used by the backend to univocally access the item + """ + return legend + + def addMarker(self, x, y, legend, text, color, + selectable, draggable, + symbol, constraint, overlay): + """Add a point, vertical line or horizontal line marker to the plot. + + :param float x: Horizontal position of the marker in graph coordinates. + If None, the marker is a horizontal line. + :param float y: Vertical position of the marker in graph coordinates. + If None, the marker is a vertical line. + :param str legend: Legend associated to the marker + :param str text: Text associated to the marker (or None for no text) + :param str color: Color to be used for instance 'blue', 'b', '#FF0000' + :param bool selectable: indicate if the marker can be selected + :param bool draggable: indicate if the marker can be moved + :param str symbol: Symbol representing the marker. + Only relevant for point markers where X and Y are not None. + Value in: + + - 'o' circle + - '.' point + - ',' pixel + - '+' cross + - 'x' x-cross + - 'd' diamond + - 's' square + + :param constraint: A function filtering marker displacement by + dragging operations or None for no filter. + This function is called each time a marker is + moved. + This parameter is only used if draggable is True. + :type constraint: None or a callable that takes the coordinates of + the current cursor position in the plot as input + and that returns the filtered coordinates. + :param bool overlay: True if marker is an overlay (Default: False). + This allows for rendering optimization if this + marker is changed often. + :return: Handle used by the backend to univocally access the marker + """ + return legend + + # Remove methods + + def remove(self, item): + """Remove an existing item from the plot. + + :param item: A backend specific item handle returned by a add* method + """ + pass + + # Interaction methods + + def setGraphCursorShape(self, cursor): + """Set the cursor shape. + + To override in interactive backends. + + :param str cursor: Name of the cursor shape or None + """ + pass + + def setGraphCursor(self, flag, color, linewidth, linestyle): + """Toggle the display of a crosshair cursor and set its attributes. + + To override in interactive backends. + + :param bool flag: Toggle the display of a crosshair cursor. + :param color: The color to use for the crosshair. + :type color: A string (either a predefined color name in Colors.py + or "#RRGGBB")) or a 4 columns unsigned byte array. + :param int linewidth: The width of the lines of the crosshair. + :param linestyle: Type of line:: + + - ' ' no line + - '-' solid line + - '--' dashed line + - '-.' dash-dot line + - ':' dotted line + + :type linestyle: None or one of the predefined styles. + """ + pass + + def pickItems(self, x, y): + """Get a list of items at a pixel position. + + :param float x: The x pixel coord where to pick. + :param float y: The y pixel coord where to pick. + :return: All picked items from back to front. + One dict per item, + with 'kind' key in 'curve', 'marker', 'image'; + 'legend' key, the item legend. + and for curves, 'xdata' and 'ydata' keys storing picked + position on the curve. + :rtype: list of dict + """ + return [] + + # Update curve + + def setCurveColor(self, curve, color): + """Set the color of a curve. + + :param curve: The curve handle + :param str color: The color to use. + """ + pass + + # Misc. + + def getWidgetHandle(self): + """Return the widget this backend is drawing to.""" + return None + + def postRedisplay(self): + """Trigger a :meth:`Plot.replot`. + + Default implementation triggers a synchronous replot if plot is dirty. + This method should be overridden by the embedding widget in order to + provide an asynchronous call to replot in order to optimize the number + replot operations. + """ + # This method can be deferred and it might happen that plot has been + # destroyed in between, especially with unittests + + plot = self._plotRef() + if plot is not None and plot._getDirtyPlot(): + plot.replot() + + def replot(self): + """Redraw the plot.""" + pass + + def saveGraph(self, fileName, fileFormat, dpi): + """Save the graph to a file (or a StringIO) + + :param fileName: Destination + :type fileName: String or StringIO or BytesIO + :param str fileFormat: String specifying the format + :param int dpi: The resolution to use or None. + """ + pass + + # Graph labels + + def setGraphTitle(self, title): + """Set the main title of the plot. + + :param str title: Title associated to the plot + """ + pass + + def setGraphXLabel(self, label): + """Set the X axis label. + + :param str label: label associated to the plot bottom X axis + """ + pass + + def setGraphYLabel(self, label, axis): + """Set the left Y axis label. + + :param str label: label associated to the plot left Y axis + :param str axis: The axis for which to get the limits: left or right + """ + pass + + # Graph limits + + def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None): + """Set the limits of the X and Y axes at once. + + :param float xmin: minimum bottom axis value + :param float xmax: maximum bottom axis value + :param float ymin: minimum left axis value + :param float ymax: maximum left axis value + :param float y2min: minimum right axis value + :param float y2max: maximum right axis value + """ + self.__xLimits = xmin, xmax + self.__yLimits['left'] = ymin, ymax + if y2min is not None and y2max is not None: + self.__yLimits['right'] = y2min, y2max + + def getGraphXLimits(self): + """Get the graph X (bottom) limits. + + :return: Minimum and maximum values of the X axis + """ + return self.__xLimits + + def setGraphXLimits(self, xmin, xmax): + """Set the limits of X axis. + + :param float xmin: minimum bottom axis value + :param float xmax: maximum bottom axis value + """ + self.__xLimits = xmin, xmax + + def getGraphYLimits(self, axis): + """Get the graph Y (left) limits. + + :param str axis: The axis for which to get the limits: left or right + :return: Minimum and maximum values of the Y axis + """ + return self.__yLimits[axis] + + def setGraphYLimits(self, ymin, ymax, axis): + """Set the limits of the Y axis. + + :param float ymin: minimum left axis value + :param float ymax: maximum left axis value + :param str axis: The axis for which to get the limits: left or right + """ + self.__yLimits[axis] = ymin, ymax + + # Graph axes + + def setXAxisLogarithmic(self, flag): + """Set the X axis scale between linear and log. + + :param bool flag: If True, the bottom axis will use a log scale + """ + pass + + def setYAxisLogarithmic(self, flag): + """Set the Y axis scale between linear and log. + + :param bool flag: If True, the left axis will use a log scale + """ + pass + + def setYAxisInverted(self, flag): + """Invert the Y axis. + + :param bool flag: If True, put the vertical axis origin on the top + """ + self.__yAxisInverted = bool(flag) + + def isYAxisInverted(self): + """Return True if left Y axis is inverted, False otherwise.""" + return self.__yAxisInverted + + def isKeepDataAspectRatio(self): + """Returns whether the plot is keeping data aspect ratio or not.""" + return self.__keepDataAspectRatio + + def setKeepDataAspectRatio(self, flag): + """Set whether to keep data aspect ratio or not. + + :param flag: True to respect data aspect ratio + :type flag: Boolean, default True + """ + self.__keepDataAspectRatio = bool(flag) + + def setGraphGrid(self, which): + """Set grid. + + :param which: None to disable grid, 'major' for major grid, + 'both' for major and minor grid + """ + pass + + # Data <-> Pixel coordinates conversion + + def dataToPixel(self, x, y, axis): + """Convert a position in data space to a position in pixels + in the widget. + + :param float x: The X coordinate in data space. + :param float y: The Y coordinate in data space. + :param str axis: The Y axis to use for the conversion + ('left' or 'right'). + :returns: The corresponding position in pixels or + None if the data position is not in the displayed area. + :rtype: A tuple of 2 floats: (xPixel, yPixel) or None. + """ + raise NotImplementedError() + + def pixelToData(self, x, y, axis, check): + """Convert a position in pixels in the widget to a position in + the data space. + + :param float x: The X coordinate in pixels. + :param float y: The Y coordinate in pixels. + :param str axis: The Y axis to use for the conversion + ('left' or 'right'). + :param bool check: True to check if the coordinates are in the + plot area. + :returns: The corresponding position in data space or + None if the pixel position is not in the plot area. + :rtype: A tuple of 2 floats: (xData, yData) or None. + """ + raise NotImplementedError() + + def getPlotBoundsInPixels(self): + """Plot area bounds in widget coordinates in pixels. + + :return: bounds as a 4-tuple of int: (left, top, width, height) + """ + raise NotImplementedError() diff --git a/silx/gui/plot/backends/BackendMatplotlib.py b/silx/gui/plot/backends/BackendMatplotlib.py new file mode 100644 index 0000000..f9e60d5 --- /dev/null +++ b/silx/gui/plot/backends/BackendMatplotlib.py @@ -0,0 +1,821 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Matplotlib Plot backend.""" + +from __future__ import division + +__authors__ = ["V.A. Sole", "T. Vincent, H. Payno"] +__license__ = "MIT" +__date__ = "18/01/2017" + + +import logging + +import numpy + + +_logger = logging.getLogger(__name__) + + +from ... import qt + +from ._matplotlib import FigureCanvasQTAgg +import matplotlib +from matplotlib.container import Container +from matplotlib.figure import Figure +from matplotlib.patches import Rectangle, Polygon +from matplotlib.image import AxesImage +from matplotlib.backend_bases import MouseEvent +from matplotlib.lines import Line2D +from matplotlib.collections import PathCollection, LineCollection + +from .ModestImage import ModestImage +from . import BackendBase +from .. import Colors +from .._utils import FLOAT32_MINPOS + + +class BackendMatplotlib(BackendBase.BackendBase): + """Base class for Matplotlib backend without a FigureCanvas. + + For interactive on screen plot, see :class:`BackendMatplotlibQt`. + + See :class:`BackendBase.BackendBase` for public API documentation. + """ + + def __init__(self, plot, parent=None): + super(BackendMatplotlib, self).__init__(plot, parent) + + # matplotlib is handling keep aspect ratio at draw time + # When keep aspect ratio is on, and one changes the limits and + # ask them *before* next draw has been performed he will get the + # limits without applying keep aspect ratio. + # This attribute is used to ensure consistent values returned + # when getting the limits at the expense of a replot + self._dirtyLimits = True + + self.fig = Figure() + self.fig.set_facecolor("w") + + self.ax = self.fig.add_axes([.15, .15, .75, .75], label="left") + self.ax2 = self.ax.twinx() + self.ax2.set_label("right") + + # critical for picking!!!! + self.ax2.set_zorder(0) + self.ax2.set_autoscaley_on(True) + self.ax.set_zorder(1) + # this works but the figure color is left + if matplotlib.__version__[0] < '2': + self.ax.set_axis_bgcolor('none') + else: + self.ax.set_facecolor('none') + self.fig.sca(self.ax) + + self._overlays = set() + self._background = None + + self._colormaps = {} + + self._graphCursor = tuple() + self.matplotlibVersion = matplotlib.__version__ + + self.setGraphXLimits(0., 100.) + self.setGraphYLimits(0., 100., axis='right') + self.setGraphYLimits(0., 100., axis='left') + + self._enableAxis('right', False) + + # Add methods + + def addCurve(self, x, y, legend, + color, symbol, linewidth, linestyle, + yaxis, + xerror, yerror, z, selectable, + fill, alpha, symbolsize): + for parameter in (x, y, legend, color, symbol, linewidth, linestyle, + yaxis, z, selectable, fill, alpha, symbolsize): + assert parameter is not None + assert yaxis in ('left', 'right') + + if (len(color) == 4 and + type(color[3]) in [type(1), numpy.uint8, numpy.int8]): + color = numpy.array(color, dtype=numpy.float) / 255. + + if yaxis == "right": + axes = self.ax2 + self._enableAxis("right", True) + else: + axes = self.ax + + picker = 3 if selectable else None + + artists = [] # All the artists composing the curve + + # First add errorbars if any so they are behind the curve + if xerror is not None or yerror is not None: + if hasattr(color, 'dtype') and len(color) == len(x): + errorbarColor = 'k' + else: + errorbarColor = color + + # On Debian 7 at least, Nx1 array yerr does not seems supported + if (yerror is not None and yerror.ndim == 2 and + yerror.shape[1] == 1 and len(x) != 1): + yerror = numpy.ravel(yerror) + + errorbars = axes.errorbar(x, y, label=legend, + xerr=xerror, yerr=yerror, + linestyle=' ', color=errorbarColor) + artists += list(errorbars.get_children()) + + if hasattr(color, 'dtype') and len(color) == len(x): + # scatter plot + if color.dtype not in [numpy.float32, numpy.float]: + actualColor = color / 255. + else: + actualColor = color + + if linestyle not in ["", " ", None]: + # scatter plot with an actual line ... + # we need to assign a color ... + curveList = axes.plot(x, y, label=legend, + linestyle=linestyle, + color=actualColor[0], + linewidth=linewidth, + picker=picker, + marker=None) + artists += list(curveList) + + scatter = axes.scatter(x, y, + label=legend, + color=actualColor, + marker=symbol, + picker=picker, + s=symbolsize) + artists.append(scatter) + + if fill: + artists.append(axes.fill_between( + x, FLOAT32_MINPOS, y, facecolor=actualColor[0], linestyle='')) + + else: # Curve + curveList = axes.plot(x, y, + label=legend, + linestyle=linestyle, + color=color, + linewidth=linewidth, + marker=symbol, + picker=picker, + markersize=symbolsize) + artists += list(curveList) + + if fill: + artists.append( + axes.fill_between(x, FLOAT32_MINPOS, y, facecolor=color)) + + for artist in artists: + artist.set_zorder(z) + if alpha < 1: + artist.set_alpha(alpha) + + return Container(artists) + + def addImage(self, data, legend, + origin, scale, z, + selectable, draggable, + colormap, alpha): + # Non-uniform image + # http://wiki.scipy.org/Cookbook/Histograms + # Non-linear axes + # http://stackoverflow.com/questions/11488800/non-linear-axes-for-imshow-in-matplotlib + for parameter in (data, legend, origin, scale, z, + selectable, draggable): + assert parameter is not None + + origin = float(origin[0]), float(origin[1]) + scale = float(scale[0]), float(scale[1]) + height, width = data.shape[0:2] + + picker = (selectable or draggable) + + # Debian 7 specific support + # No transparent colormap with matplotlib < 1.2.0 + # Add support for transparent colormap for uint8 data with + # colormap with 256 colors, linear norm, [0, 255] range + if matplotlib.__version__ < '1.2.0': + if (len(data.shape) == 2 and colormap['name'] is None and + 'colors' in colormap): + colors = numpy.array(colormap['colors'], copy=False) + if (colors.shape[-1] == 4 and + not numpy.all(numpy.equal(colors[3], 255))): + # This is a transparent colormap + if (colors.shape == (256, 4) and + colormap['normalization'] == 'linear' and + not colormap['autoscale'] and + colormap['vmin'] == 0 and + colormap['vmax'] == 255 and + data.dtype == numpy.uint8): + # Supported case, convert data to RGBA + data = colors[data.reshape(-1)].reshape( + data.shape + (4,)) + else: + _logger.warning( + 'matplotlib %s does not support transparent ' + 'colormap.', matplotlib.__version__) + + if ((height * width) > 5.0e5 and + origin == (0., 0.) and scale == (1., 1.)): + imageClass = ModestImage + else: + imageClass = AxesImage + + # the normalization can be a source of time waste + # Two possibilities, we receive data or a ready to show image + if len(data.shape) == 3: # RGBA image + image = imageClass(self.ax, + label="__IMAGE__" + legend, + interpolation='nearest', + picker=picker, + zorder=z, + origin='lower') + + else: + # Convert colormap argument to matplotlib colormap + scalarMappable = Colors.getMPLScalarMappable(colormap, data) + + # try as data + image = imageClass(self.ax, + label="__IMAGE__" + legend, + interpolation='nearest', + cmap=scalarMappable.cmap, + picker=picker, + zorder=z, + norm=scalarMappable.norm, + origin='lower') + if alpha < 1: + image.set_alpha(alpha) + + # Set image extent + xmin = origin[0] + xmax = xmin + scale[0] * width + if scale[0] < 0.: + xmin, xmax = xmax, xmin + + ymin = origin[1] + ymax = ymin + scale[1] * height + if scale[1] < 0.: + ymin, ymax = ymax, ymin + + image.set_extent((xmin, xmax, ymin, ymax)) + + # Set image data + if scale[0] < 0. or scale[1] < 0.: + # For negative scale, step by -1 + xstep = 1 if scale[0] >= 0. else -1 + ystep = 1 if scale[1] >= 0. else -1 + data = data[::ystep, ::xstep] + + image.set_data(data) + + self.ax.add_artist(image) + + return image + + def addItem(self, x, y, legend, shape, color, fill, overlay, z): + xView = numpy.array(x, copy=False) + yView = numpy.array(y, copy=False) + + if shape == "line": + item = self.ax.plot(x, y, label=legend, color=color, + linestyle='-', marker=None)[0] + + elif shape == "hline": + if hasattr(y, "__len__"): + y = y[-1] + item = self.ax.axhline(y, label=legend, color=color) + + elif shape == "vline": + if hasattr(x, "__len__"): + x = x[-1] + item = self.ax.axvline(x, label=legend, color=color) + + elif shape == 'rectangle': + xMin = numpy.nanmin(xView) + xMax = numpy.nanmax(xView) + yMin = numpy.nanmin(yView) + yMax = numpy.nanmax(yView) + w = xMax - xMin + h = yMax - yMin + item = Rectangle(xy=(xMin, yMin), + width=w, + height=h, + fill=False, + color=color) + if fill: + item.set_hatch('.') + + self.ax.add_patch(item) + + elif shape in ('polygon', 'polylines'): + xView = xView.reshape(1, -1) + yView = yView.reshape(1, -1) + item = Polygon(numpy.vstack((xView, yView)).T, + closed=(shape == 'polygon'), + fill=False, + label=legend, + color=color) + if fill and shape == 'polygon': + item.set_hatch('/') + + self.ax.add_patch(item) + + else: + raise NotImplementedError("Unsupported item shape %s" % shape) + + item.set_zorder(z) + + if overlay: + item.set_animated(True) + self._overlays.add(item) + + return item + + def addMarker(self, x, y, legend, text, color, + selectable, draggable, + symbol, constraint, overlay): + legend = "__MARKER__" + legend + + if x is not None and y is not None: + line = self.ax.plot(x, y, label=legend, + linestyle=" ", + color=color, + marker=symbol, + markersize=10.)[-1] + + if text is not None: + xtmp, ytmp = self.ax.transData.transform_point((x, y)) + inv = self.ax.transData.inverted() + xtmp, ytmp = inv.transform_point((xtmp, ytmp)) + + if symbol is None: + valign = 'baseline' + else: + valign = 'top' + text = " " + text + + line._infoText = self.ax.text(x, ytmp, text, + color=color, + horizontalalignment='left', + verticalalignment=valign) + + elif x is not None: + line = self.ax.axvline(x, label=legend, color=color) + if text is not None: + text = " " + text + ymin, ymax = self.getGraphYLimits(axis='left') + delta = abs(ymax - ymin) + if ymin > ymax: + ymax = ymin + ymax -= 0.005 * delta + line._infoText = self.ax.text(x, ymax, text, + color=color, + horizontalalignment='left', + verticalalignment='top') + + elif y is not None: + line = self.ax.axhline(y, label=legend, color=color) + + if text is not None: + text = " " + text + xmin, xmax = self.getGraphXLimits() + delta = abs(xmax - xmin) + if xmin > xmax: + xmax = xmin + xmax -= 0.005 * delta + line._infoText = self.ax.text(xmax, y, text, + color=color, + horizontalalignment='right', + verticalalignment='top') + + else: + raise RuntimeError('A marker must at least have one coordinate') + + if selectable or draggable: + line.set_picker(5) + + if overlay: + line.set_animated(True) + self._overlays.add(line) + + return line + + # Remove methods + + def remove(self, item): + # Warning: It also needs to remove extra stuff if added as for markers + if hasattr(item, "_infoText"): # For markers text + item._infoText.remove() + item._infoText = None + self._overlays.discard(item) + item.remove() + + # Interaction methods + + def setGraphCursor(self, flag, color, linewidth, linestyle): + if flag: + lineh = self.ax.axhline( + self.ax.get_ybound()[0], visible=False, color=color, + linewidth=linewidth, linestyle=linestyle) + lineh.set_animated(True) + + linev = self.ax.axvline( + self.ax.get_xbound()[0], visible=False, color=color, + linewidth=linewidth, linestyle=linestyle) + linev.set_animated(True) + + self._graphCursor = lineh, linev + else: + if self._graphCursor is not None: + lineh, linev = self._graphCursor + lineh.remove() + linev.remove() + self._graphCursor = tuple() + + # Active curve + + def setCurveColor(self, curve, color): + # Store Line2D and PathCollection + for artist in curve.get_children(): + if isinstance(artist, (Line2D, LineCollection)): + artist.set_color(color) + elif isinstance(artist, PathCollection): + artist.set_facecolors(color) + artist.set_edgecolors(color) + else: + _logger.warning( + 'setActiveCurve ignoring artist %s', str(artist)) + + # Misc. + + def getWidgetHandle(self): + return self.fig.canvas + + def _enableAxis(self, axis, flag=True): + """Show/hide Y axis + + :param str axis: Axis name: 'left' or 'right' + :param bool flag: Default, True + """ + assert axis in ('right', 'left') + axes = self.ax2 if axis == 'right' else self.ax + axes.get_yaxis().set_visible(flag) + + def replot(self): + """Do not perform rendering. + + Override in subclass to actually draw something. + """ + # TODO images, markers? scatter plot? move in remove? + # Right Y axis only support curve for now + # Hide right Y axis if no line is present + self._dirtyLimits = False + if not self.ax2.lines: + self._enableAxis('right', False) + + def saveGraph(self, fileName, fileFormat, dpi): + # fileName can be also a StringIO or file instance + if dpi is not None: + self.fig.savefig(fileName, format=fileFormat, dpi=dpi) + else: + self.fig.savefig(fileName, format=fileFormat) + self._plot._setDirtyPlot() + + # Graph labels + + def setGraphTitle(self, title): + self.ax.set_title(title) + + def setGraphXLabel(self, label): + self.ax.set_xlabel(label) + + def setGraphYLabel(self, label, axis): + axes = self.ax if axis == 'left' else self.ax2 + axes.set_ylabel(label) + + # Graph limits + + def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None): + # Let matplotlib taking care of keep aspect ratio if any + self._dirtyLimits = True + self.ax.set_xlim(min(xmin, xmax), max(xmin, xmax)) + + if y2min is not None and y2max is not None: + if not self.isYAxisInverted(): + self.ax2.set_ylim(min(y2min, y2max), max(y2min, y2max)) + else: + self.ax2.set_ylim(max(y2min, y2max), min(y2min, y2max)) + + if not self.isYAxisInverted(): + self.ax.set_ylim(min(ymin, ymax), max(ymin, ymax)) + else: + self.ax.set_ylim(max(ymin, ymax), min(ymin, ymax)) + + def getGraphXLimits(self): + if self._dirtyLimits and self.isKeepDataAspectRatio(): + self.replot() # makes sure we get the right limits + return self.ax.get_xbound() + + def setGraphXLimits(self, xmin, xmax): + self._dirtyLimits = True + self.ax.set_xlim(min(xmin, xmax), max(xmin, xmax)) + + def getGraphYLimits(self, axis): + assert axis in ('left', 'right') + ax = self.ax2 if axis == 'right' else self.ax + + if not ax.get_visible(): + return None + + if self._dirtyLimits and self.isKeepDataAspectRatio(): + self.replot() # makes sure we get the right limits + + return ax.get_ybound() + + def setGraphYLimits(self, ymin, ymax, axis): + ax = self.ax2 if axis == 'right' else self.ax + if ymax < ymin: + ymin, ymax = ymax, ymin + self._dirtyLimits = True + + if self.isKeepDataAspectRatio(): + # matplotlib keeps limits of shared axis when keeping aspect ratio + # So x limits are kept when changing y limits.... + # Change x limits first by taking into account aspect ratio + # and then change y limits.. so matplotlib does not need + # to make change (to y) to keep aspect ratio + xmin, xmax = ax.get_xbound() + curYMin, curYMax = ax.get_ybound() + + newXRange = (xmax - xmin) * (ymax - ymin) / (curYMax - curYMin) + xcenter = 0.5 * (xmin + xmax) + ax.set_xlim(xcenter - 0.5 * newXRange, xcenter + 0.5 * newXRange) + + if not self.isYAxisInverted(): + ax.set_ylim(ymin, ymax) + else: + ax.set_ylim(ymax, ymin) + + # Graph axes + + def setXAxisLogarithmic(self, flag): + self.ax2.set_xscale('log' if flag else 'linear') + self.ax.set_xscale('log' if flag else 'linear') + + def setYAxisLogarithmic(self, flag): + self.ax2.set_yscale('log' if flag else 'linear') + self.ax.set_yscale('log' if flag else 'linear') + + def setYAxisInverted(self, flag): + if self.ax.yaxis_inverted() != bool(flag): + self.ax.invert_yaxis() + + def isYAxisInverted(self): + return self.ax.yaxis_inverted() + + def isKeepDataAspectRatio(self): + return self.ax.get_aspect() in (1.0, 'equal') + + def setKeepDataAspectRatio(self, flag): + self.ax.set_aspect(1.0 if flag else 'auto') + self.ax2.set_aspect(1.0 if flag else 'auto') + + def setGraphGrid(self, which): + self.ax.grid(False, which='both') # Disable all grid first + if which is not None: + self.ax.grid(True, which=which) + + # Data <-> Pixel coordinates conversion + + def dataToPixel(self, x, y, axis): + ax = self.ax2 if axis == "right" else self.ax + + pixels = ax.transData.transform_point((x, y)) + xPixel, yPixel = pixels.T + return xPixel, yPixel + + def pixelToData(self, x, y, axis, check): + ax = self.ax2 if axis == "right" else self.ax + + inv = ax.transData.inverted() + x, y = inv.transform_point((x, y)) + + if check: + xmin, xmax = self.getGraphXLimits() + ymin, ymax = self.getGraphYLimits(axis=axis) + + if x > xmax or x < xmin or y > ymax or y < ymin: + return None # (x, y) is out of plot area + + return x, y + + def getPlotBoundsInPixels(self): + bbox = self.ax.get_window_extent().transformed( + self.fig.dpi_scale_trans.inverted()) + dpi = self.fig.dpi + # Warning this is not returning int... + return (bbox.bounds[0] * dpi, bbox.bounds[1] * dpi, + bbox.bounds[2] * dpi, bbox.bounds[3] * dpi) + + +class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib): + """QWidget matplotlib backend using a QtAgg canvas. + + It adds fast overlay drawing and mouse event management. + """ + + _sigPostRedisplay = qt.Signal() + """Signal handling automatic asynchronous replot""" + + def __init__(self, plot, parent=None): + self._insideResizeEventMethod = False + + BackendMatplotlib.__init__(self, plot, parent) + FigureCanvasQTAgg.__init__(self, self.fig) + self.setParent(parent) + + FigureCanvasQTAgg.setSizePolicy( + self, qt.QSizePolicy.Expanding, qt.QSizePolicy.Expanding) + FigureCanvasQTAgg.updateGeometry(self) + + # Make postRedisplay asynchronous using Qt signal + self._sigPostRedisplay.connect( + super(BackendMatplotlibQt, self).postRedisplay, + qt.Qt.QueuedConnection) + + self._picked = None + + self.mpl_connect('button_press_event', self._onMousePress) + self.mpl_connect('button_release_event', self._onMouseRelease) + self.mpl_connect('motion_notify_event', self._onMouseMove) + self.mpl_connect('scroll_event', self._onMouseWheel) + + def postRedisplay(self): + self._sigPostRedisplay.emit() + + # Mouse event forwarding + + _MPL_TO_PLOT_BUTTONS = {1: 'left', 2: 'middle', 3: 'right'} + + def _onMousePress(self, event): + self._plot.onMousePress( + event.x, event.y, self._MPL_TO_PLOT_BUTTONS[event.button]) + + def _onMouseMove(self, event): + if self._graphCursor: + lineh, linev = self._graphCursor + if event.inaxes != self.ax and lineh.get_visible(): + lineh.set_visible(False) + linev.set_visible(False) + self._plot._setDirtyPlot(overlayOnly=True) + else: + linev.set_visible(True) + linev.set_xdata((event.xdata, event.xdata)) + lineh.set_visible(True) + lineh.set_ydata((event.ydata, event.ydata)) + self._plot._setDirtyPlot(overlayOnly=True) + # onMouseMove must trigger replot if dirty flag is raised + + self._plot.onMouseMove(event.x, event.y) + + def _onMouseRelease(self, event): + self._plot.onMouseRelease( + event.x, event.y, self._MPL_TO_PLOT_BUTTONS[event.button]) + + def _onMouseWheel(self, event): + self._plot.onMouseWheel(event.x, event.y, event.step) + + def leaveEvent(self, event): + """QWidget event handler""" + self._plot.onMouseLeaveWidget() + + # picking + + def _onPick(self, event): + # TODO not very nice and fragile, find a better way? + # Make a selection according to kind + if self._picked is None: + _logger.error('Internal picking error') + return + + label = event.artist.get_label() + if label.startswith('__MARKER__'): + self._picked.append({'kind': 'marker', 'legend': label[10:]}) + + elif label.startswith('__IMAGE__'): + self._picked.append({'kind': 'image', 'legend': label[9:]}) + + else: # it's a curve, item have no picker for now + if isinstance(event.artist, PathCollection): + data = event.artist.get_offsets()[event.ind, :] + xdata, ydata = data[:, 0], data[:, 1] + elif isinstance(event.artist, Line2D): + xdata = event.artist.get_xdata()[event.ind] + ydata = event.artist.get_ydata()[event.ind] + else: + _logger.info('Unsupported artist, ignored') + return + + self._picked.append({'kind': 'curve', 'legend': label, + 'xdata': xdata, 'ydata': ydata}) + + def pickItems(self, x, y): + self._picked = [] + + # Weird way to do an explicit picking: Simulate a button press event + mouseEvent = MouseEvent('button_press_event', self, x, y) + cid = self.mpl_connect('pick_event', self._onPick) + self.fig.pick(mouseEvent) + self.mpl_disconnect(cid) + picked = self._picked + self._picked = None + + return picked + + # replot control + + def resizeEvent(self, event): + self._insideResizeEventMethod = True + # Need to dirty the whole plot on resize. + self._plot._setDirtyPlot() + FigureCanvasQTAgg.resizeEvent(self, event) + self._insideResizeEventMethod = False + + def draw(self): + """Override canvas draw method to support faster draw of overlays.""" + if self._plot._getDirtyPlot(): # Need a full redraw + FigureCanvasQTAgg.draw(self) + self._background = None # Any saved background is dirty + + if (self._overlays or self._graphCursor or + self._plot._getDirtyPlot() == 'overlay'): + # There are overlays or crosshair, or they is just no more overlays + + # Specific case: called from resizeEvent: + # avoid store/restore background, just draw the overlay + if not self._insideResizeEventMethod: + if self._background is None: # First store the background + self._background = self.copy_from_bbox(self.fig.bbox) + + self.restore_region(self._background) + + # This assume that items are only on left/bottom Axes + for item in self._overlays: + self.ax.draw_artist(item) + + for item in self._graphCursor: + self.ax.draw_artist(item) + + self.blit(self.fig.bbox) + + def replot(self): + BackendMatplotlib.replot(self) + self.draw() + + # cursor + + _QT_CURSORS = { + None: qt.Qt.ArrowCursor, + BackendBase.CURSOR_DEFAULT: qt.Qt.ArrowCursor, + BackendBase.CURSOR_POINTING: qt.Qt.PointingHandCursor, + BackendBase.CURSOR_SIZE_HOR: qt.Qt.SizeHorCursor, + BackendBase.CURSOR_SIZE_VER: qt.Qt.SizeVerCursor, + BackendBase.CURSOR_SIZE_ALL: qt.Qt.SizeAllCursor, + } + + def setGraphCursorShape(self, cursor): + cursor = self._QT_CURSORS[cursor] + + FigureCanvasQTAgg.setCursor(self, qt.QCursor(cursor)) diff --git a/silx/gui/plot/backends/BackendOpenGL.py b/silx/gui/plot/backends/BackendOpenGL.py new file mode 100644 index 0000000..bc10eca --- /dev/null +++ b/silx/gui/plot/backends/BackendOpenGL.py @@ -0,0 +1,1631 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +"""OpenGL Plot backend.""" + +from __future__ import division + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "21/03/2017" + +from collections import OrderedDict, namedtuple +from ctypes import c_void_p +import logging + +import numpy + +from .._utils import FLOAT32_MINPOS +from . import BackendBase +from .. import Colors +from ... import qt + +from ..._glutils import gl +from ... import _glutils as glu +from .glutils import ( + GLPlotCurve2D, GLPlotColormap, GLPlotRGBAImage, GLPlotFrame2D, + mat4Ortho, mat4Identity, + LEFT, RIGHT, BOTTOM, TOP, + Text2D, Shape2D) +from .glutils.PlotImageFile import saveImageToFile + +_logger = logging.getLogger(__name__) + + +# TODO idea: BackendQtMixIn class to share code between mpl and gl +# TODO check if OpenGL is available +# TODO make an off-screen mesa backend + +# Bounds ###################################################################### + +class Range(namedtuple('Range', ('min_', 'max_'))): + """Describes a 1D range""" + + @property + def range_(self): + return self.max_ - self.min_ + + @property + def center(self): + return 0.5 * (self.min_ + self.max_) + + +class Bounds(object): + """Describes plot bounds with 2 y axis""" + + def __init__(self, xMin, xMax, yMin, yMax, y2Min, y2Max): + self._xAxis = Range(xMin, xMax) + self._yAxis = Range(yMin, yMax) + self._y2Axis = Range(y2Min, y2Max) + + def __repr__(self): + return "x: %s, y: %s, y2: %s" % (repr(self._xAxis), + repr(self._yAxis), + repr(self._y2Axis)) + + @property + def xAxis(self): + return self._xAxis + + @property + def yAxis(self): + return self._yAxis + + @property + def y2Axis(self): + return self._y2Axis + + +# Content ##################################################################### + +class PlotDataContent(object): + """Manage plot data content: images and curves. + + This class is only meant to work with _OpenGLPlotCanvas. + """ + + _PRIMITIVE_TYPES = 'curve', 'image' + + def __init__(self): + self._primitives = OrderedDict() # For images and curves + + def add(self, primitive): + """Add a curve or image to the content dictionary. + + This function generates the key in the dict from the primitive. + + :param primitive: The primitive to add. + :type primitive: Instance of GLPlotCurve2D, GLPlotColormap, + GLPlotRGBAImage. + """ + if isinstance(primitive, GLPlotCurve2D): + primitiveType = 'curve' + elif isinstance(primitive, (GLPlotColormap, GLPlotRGBAImage)): + primitiveType = 'image' + else: + raise RuntimeError('Unsupported object type: %s', primitive) + + key = primitiveType, primitive.info['legend'] + self._primitives[key] = primitive + + def get(self, primitiveType, legend): + """Get the corresponding primitive of given type with given legend. + + :param str primitiveType: Type of primitive ('curve' or 'image'). + :param str legend: The legend of the primitive to retrieve. + :return: The corresponding curve or None if no such curve. + """ + assert primitiveType in self._PRIMITIVE_TYPES + return self._primitives.get((primitiveType, legend)) + + def pop(self, primitiveType, key): + """Pop the corresponding curve or return None if no such curve. + + :param str primitiveType: + :param str key: + :return: + """ + assert primitiveType in self._PRIMITIVE_TYPES + return self._primitives.pop((primitiveType, key), None) + + def zOrderedPrimitives(self, reverse=False): + """List of primitives sorted according to their z order. + + It is a stable sort (as sorted): + Original order is preserved when key is the same. + + :param bool reverse: Ascending (True, default) or descending (False). + """ + return sorted(self._primitives.values(), + key=lambda primitive: primitive.info['zOrder'], + reverse=reverse) + + def primitives(self): + """Iterator over all primitives.""" + return self._primitives.values() + + def primitiveKeys(self, primitiveType): + """Iterator over primitives of a specific type.""" + assert primitiveType in self._PRIMITIVE_TYPES + for type_, key in self._primitives.keys(): + if type_ == primitiveType: + yield key + + def getBounds(self, xPositive=False, yPositive=False): + """Bounds of the data. + + Can return strictly positive bounds (for log scale). + In this case, curves are clipped to their smaller positive value + and images with negative min are ignored. + + :param bool xPositive: True to get strictly positive range. + :param bool yPositive: True to get strictly positive range. + :return: The range of data for x, y and y2, or default (1., 100.) + if no range found for one dimension. + :rtype: Bounds + """ + xMin, yMin, y2Min = float('inf'), float('inf'), float('inf') + xMax = 0. if xPositive else -float('inf') + if yPositive: + yMax, y2Max = 0., 0. + else: + yMax, y2Max = -float('inf'), -float('inf') + + for item in self._primitives.values(): + # To support curve <= 0. and log and bypass images: + # If positive only, uses x|yMinPos if available + # and bypass other data with negative min bounds + if xPositive: + itemXMin = getattr(item, 'xMinPos', item.xMin) + if itemXMin is None or itemXMin < FLOAT32_MINPOS: + continue + else: + itemXMin = item.xMin + + if yPositive: + itemYMin = getattr(item, 'yMinPos', item.yMin) + if itemYMin is None or itemYMin < FLOAT32_MINPOS: + continue + else: + itemYMin = item.yMin + + if itemXMin < xMin: + xMin = itemXMin + if item.xMax > xMax: + xMax = item.xMax + + if item.info.get('yAxis') == 'right': + if itemYMin < y2Min: + y2Min = itemYMin + if item.yMax > y2Max: + y2Max = item.yMax + else: + if itemYMin < yMin: + yMin = itemYMin + if item.yMax > yMax: + yMax = item.yMax + + # One of the limit has not been updated, return default range + if xMin >= xMax: + xMin, xMax = 1., 100. + if yMin >= yMax: + yMin, yMax = 1., 100. + if y2Min >= y2Max: + y2Min, y2Max = 1., 100. + + return Bounds(xMin, xMax, yMin, yMax, y2Min, y2Max) + + +# shaders ##################################################################### + +_baseVertShd = """ + attribute vec2 position; + uniform mat4 matrix; + uniform bvec2 isLog; + + const float oneOverLog10 = 0.43429448190325176; + + void main(void) { + vec2 posTransformed = position; + if (isLog.x) { + posTransformed.x = oneOverLog10 * log(position.x); + } + if (isLog.y) { + posTransformed.y = oneOverLog10 * log(position.y); + } + gl_Position = matrix * vec4(posTransformed, 0.0, 1.0); + } + """ + +_baseFragShd = """ + uniform vec4 color; + uniform int hatchStep; + uniform float tickLen; + + void main(void) { + if (tickLen != 0.) { + if (mod((gl_FragCoord.x + gl_FragCoord.y) / tickLen, 2.) < 1.) { + gl_FragColor = color; + } else { + discard; + } + } else if (hatchStep == 0 || + mod(gl_FragCoord.x - gl_FragCoord.y, float(hatchStep)) == 0.) { + gl_FragColor = color; + } else { + discard; + } + } + """ + +_texVertShd = """ + attribute vec2 position; + attribute vec2 texCoords; + uniform mat4 matrix; + + varying vec2 coords; + + void main(void) { + gl_Position = matrix * vec4(position, 0.0, 1.0); + coords = texCoords; + } + """ + +_texFragShd = """ + uniform sampler2D tex; + + varying vec2 coords; + + void main(void) { + gl_FragColor = texture2D(tex, coords); + } + """ + + +# BackendOpenGL ############################################################### + +_current_context = None + + +def _getContext(): + assert _current_context is not None + return _current_context + + +class BackendOpenGL(BackendBase.BackendBase, qt.QGLWidget): + """OpenGL-based Plot backend. + + WARNINGS: + Unless stated otherwise, this API is NOT thread-safe and MUST be + called from the main thread. + When numpy arrays are passed as arguments to the API (through + :func:`addCurve` and :func:`addImage`), they are copied only if + required. + So, the caller should not modify these arrays afterwards. + """ + + _sigPostRedisplay = qt.Signal() + """Signal handling automatic asynchronous replot""" + + def __init__(self, plot, parent=None): + qt.QGLWidget.__init__(self, parent) + BackendBase.BackendBase.__init__(self, plot, parent) + + self.matScreenProj = mat4Identity() + + self._progBase = glu.Program( + _baseVertShd, _baseFragShd, attrib0='position') + self._progTex = glu.Program( + _texVertShd, _texFragShd, attrib0='position') + self._plotFBOs = {} + + self._keepDataAspectRatio = False + + self._devicePixelRatio = 1.0 + + self._crosshairCursor = None + self._mousePosInPixels = None + + self._markers = OrderedDict() + self._items = OrderedDict() + self._plotContent = PlotDataContent() # For images and curves + self._selectionAreas = OrderedDict() + self._glGarbageCollector = [] + + self._plotFrame = GLPlotFrame2D( + margins={'left': 100, 'right': 50, 'top': 50, 'bottom': 50}) + + # Make postRedisplay asynchronous using Qt signal + self._sigPostRedisplay.connect( + super(BackendOpenGL, self).postRedisplay, + qt.Qt.QueuedConnection) + + # TODO is this needed? move it Plot? + self.setGraphXLimits(0., 100.) + self.setGraphYLimits(0., 100., axis='right') + self.setGraphYLimits(0., 100., axis='left') + + self.setAutoFillBackground(False) + self.setMouseTracking(True) + + # QWidget + + _MOUSE_BTNS = {1: 'left', 2: 'right', 4: 'middle'} + + def sizeHint(self): + return qt.QSize(8 * 80, 6 * 80) # Mimic MatplotlibBackend + + def mousePressEvent(self, event): + xPixel = event.x() * self._devicePixelRatio + yPixel = event.y() * self._devicePixelRatio + btn = self._MOUSE_BTNS[event.button()] + self._plot.onMousePress(xPixel, yPixel, btn) + event.accept() + + def mouseMoveEvent(self, event): + xPixel = event.x() * self._devicePixelRatio + yPixel = event.y() * self._devicePixelRatio + + # Handle crosshair + inXPixel, inYPixel = self._mouseInPlotArea(xPixel, yPixel) + isCursorInPlot = inXPixel == xPixel and inYPixel == yPixel + + previousMousePosInPixels = self._mousePosInPixels + self._mousePosInPixels = (xPixel, yPixel) if isCursorInPlot else None + if (self._crosshairCursor is not None and + previousMousePosInPixels != self._crosshairCursor): + # Avoid replot when cursor remains outside plot area + self._plot._setDirtyPlot(overlayOnly=True) + + self._plot.onMouseMove(xPixel, yPixel) + event.accept() + + def mouseReleaseEvent(self, event): + xPixel = event.x() * self._devicePixelRatio + yPixel = event.y() * self._devicePixelRatio + + btn = self._MOUSE_BTNS[event.button()] + self._plot.onMouseRelease(xPixel, yPixel, btn) + event.accept() + + def wheelEvent(self, event): + xPixel = event.x() * self._devicePixelRatio + yPixel = event.y() * self._devicePixelRatio + + if hasattr(event, 'angleDelta'): # Qt 5 + delta = event.angleDelta().y() + else: # Qt 4 support + delta = event.delta() + angleInDegrees = delta / 8. + self._plot.onMouseWheel(xPixel, yPixel, angleInDegrees) + event.accept() + + def leaveEvent(self, _): + self._plot.onMouseLeaveWidget() + + # QGLWidget API + + @staticmethod + def _setBlendFuncGL(): + # glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA) + gl.glBlendFuncSeparate(gl.GL_SRC_ALPHA, + gl.GL_ONE_MINUS_SRC_ALPHA, + gl.GL_ONE, + gl.GL_ONE) + + def initializeGL(self): + gl.testGL() + + gl.glClearColor(1., 1., 1., 1.) + gl.glClearStencil(0) + + gl.glEnable(gl.GL_BLEND) + self._setBlendFuncGL() + + # For lines + gl.glHint(gl.GL_LINE_SMOOTH_HINT, gl.GL_NICEST) + + # For points + gl.glEnable(gl.GL_VERTEX_PROGRAM_POINT_SIZE) # OpenGL 2 + gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2 + # gl.glEnable(gl.GL_PROGRAM_POINT_SIZE) + + def _paintDirectGL(self): + self._renderPlotAreaGL() + self._plotFrame.render() + self._renderMarkersGL() + self._renderOverlayGL() + + def _paintFBOGL(self): + context = glu.getGLContext() + plotFBOTex = self._plotFBOs.get(context) + if (self._plot._getDirtyPlot() or self._plotFrame.isDirty or + plotFBOTex is None): + self._plotVertices = numpy.array(((-1., -1., 0., 0.), + (1., -1., 1., 0.), + (-1., 1., 0., 1.), + (1., 1., 1., 1.)), + dtype=numpy.float32) + if plotFBOTex is None or \ + plotFBOTex.shape[1] != self._plotFrame.size[0] or \ + plotFBOTex.shape[0] != self._plotFrame.size[1]: + if plotFBOTex is not None: + plotFBOTex.discard() + plotFBOTex = glu.FramebufferTexture( + gl.GL_RGBA, + shape=(self._plotFrame.size[1], + self._plotFrame.size[0]), + minFilter=gl.GL_NEAREST, + magFilter=gl.GL_NEAREST, + wrap=(gl.GL_CLAMP_TO_EDGE, + gl.GL_CLAMP_TO_EDGE)) + self._plotFBOs[context] = plotFBOTex + + with plotFBOTex: + gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT) + self._renderPlotAreaGL() + self._plotFrame.render() + + # Render plot in screen coords + gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1]) + + self._progTex.use() + texUnit = 0 + + gl.glUniform1i(self._progTex.uniforms['tex'], texUnit) + gl.glUniformMatrix4fv(self._progTex.uniforms['matrix'], 1, gl.GL_TRUE, + mat4Identity()) + + stride = self._plotVertices.shape[-1] * self._plotVertices.itemsize + gl.glEnableVertexAttribArray(self._progTex.attributes['position']) + gl.glVertexAttribPointer(self._progTex.attributes['position'], + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + stride, self._plotVertices) + + texCoordsPtr = c_void_p(self._plotVertices.ctypes.data + + 2 * self._plotVertices.itemsize) # Better way? + gl.glEnableVertexAttribArray(self._progTex.attributes['texCoords']) + gl.glVertexAttribPointer(self._progTex.attributes['texCoords'], + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + stride, texCoordsPtr) + + with plotFBOTex.texture: + gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(self._plotVertices)) + + self._renderMarkersGL() + self._renderOverlayGL() + + def paintGL(self): + global _current_context + _current_context = self.context() + + glu.setGLContextGetter(_getContext) + + if hasattr(self, 'windowHandle'): # Qt 5 + devicePixelRatio = self.windowHandle().devicePixelRatio() + if devicePixelRatio != self._devicePixelRatio: + self._devicePixelRatio = devicePixelRatio + self.resizeGL(int(self.width() * devicePixelRatio), + int(self.height() * devicePixelRatio)) + + # Release OpenGL resources + for item in self._glGarbageCollector: + item.discard() + self._glGarbageCollector = [] + + gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT) + + # Check if window is large enough + plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:] + if plotWidth <= 2 or plotHeight <= 2: + return + + # self._paintDirectGL() + self._paintFBOGL() + + glu.setGLContextGetter() + _current_context = None + + def _nonOrthoAxesLineMarkerPrimitives(self, marker, pixelOffset): + """Generates the vertices and label for a line marker. + + :param dict marker: Description of a line marker + :param int pixelOffset: Offset of text from borders in pixels + :return: Line vertices and Text label or None + :rtype: 2-tuple (2x2 numpy.array of float, Text2D) + """ + label, vertices = None, None + + xCoord, yCoord = marker['x'], marker['y'] + assert xCoord is None or yCoord is None # Specific to line markers + + # Get plot corners in data coords + plotLeft, plotTop, plotWidth, plotHeight = self.getPlotBoundsInPixels() + + corners = [(plotLeft, plotTop), + (plotLeft, plotTop + plotHeight), + (plotLeft + plotWidth, plotTop + plotHeight), + (plotLeft + plotWidth, plotTop)] + corners = numpy.array([self.pixelToData(x, y, axis='left', check=False) + for (x, y) in corners]) + + borders = { + 'right': (corners[3], corners[2]), + 'top': (corners[0], corners[3]), + 'bottom': (corners[2], corners[1]), + 'left': (corners[1], corners[0]) + } + + textLayouts = { # align, valign, offsets + 'right': (RIGHT, BOTTOM, (-1., -1.)), + 'top': (LEFT, TOP, (1., 1.)), + 'bottom': (LEFT, BOTTOM, (1., -1.)), + 'left': (LEFT, BOTTOM, (1., -1.)) + } + + if xCoord is None: # Horizontal line in data space + if marker['text'] is not None: + # Find intersection of hline with borders in data + # Order is important as it stops at first intersection + for border_name in ('right', 'top', 'bottom', 'left'): + (x0, y0), (x1, y1) = borders[border_name] + + if min(y0, y1) <= yCoord < max(y0, y1): + xIntersect = (yCoord - y0) * (x1 - x0) / (y1 - y0) + x0 + + # Add text label + pixelPos = self.dataToPixel( + xIntersect, yCoord, axis='left', check=False) + + align, valign, offsets = textLayouts[border_name] + + x = pixelPos[0] + offsets[0] * pixelOffset + y = pixelPos[1] + offsets[1] * pixelOffset + label = Text2D(marker['text'], x, y, + color=marker['color'], + bgColor=(1., 1., 1., 0.5), + align=align, valign=valign) + break # Stop at first intersection + + xMin, xMax = corners[:, 0].min(), corners[:, 0].max() + vertices = numpy.array( + ((xMin, yCoord), (xMax, yCoord)), dtype=numpy.float32) + + else: # yCoord is None: vertical line in data space + if marker['text'] is not None: + # Find intersection of hline with borders in data + # Order is important as it stops at first intersection + for border_name in ('top', 'bottom', 'right', 'left'): + (x0, y0), (x1, y1) = borders[border_name] + if min(x0, x1) <= xCoord < max(x0, x1): + yIntersect = (xCoord - x0) * (y1 - y0) / (x1 - x0) + y0 + + # Add text label + pixelPos = self.dataToPixel( + xCoord, yIntersect, axis='left', check=False) + + align, valign, offsets = textLayouts[border_name] + + x = pixelPos[0] + offsets[0] * pixelOffset + y = pixelPos[1] + offsets[1] * pixelOffset + label = Text2D(marker['text'], x, y, + color=marker['color'], + bgColor=(1., 1., 1., 0.5), + align=align, valign=valign) + break # Stop at first intersection + + yMin, yMax = corners[:, 1].min(), corners[:, 1].max() + vertices = numpy.array( + ((xCoord, yMin), (xCoord, yMax)), dtype=numpy.float32) + + return vertices, label + + def _renderMarkersGL(self): + if len(self._markers) == 0: + return + + plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:] + + isXLog = self._plotFrame.xAxis.isLog + isYLog = self._plotFrame.yAxis.isLog + + # Render in plot area + gl.glScissor(self._plotFrame.margins.left, + self._plotFrame.margins.bottom, + plotWidth, plotHeight) + gl.glEnable(gl.GL_SCISSOR_TEST) + + gl.glViewport(self._plotFrame.margins.left, + self._plotFrame.margins.bottom, + plotWidth, plotHeight) + + # Prepare vertical and horizontal markers rendering + self._progBase.use() + gl.glUniformMatrix4fv(self._progBase.uniforms['matrix'], 1, gl.GL_TRUE, + self._plotFrame.transformedDataProjMat) + gl.glUniform2i(self._progBase.uniforms['isLog'], isXLog, isYLog) + gl.glUniform1i(self._progBase.uniforms['hatchStep'], 0) + gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.) + posAttrib = self._progBase.attributes['position'] + + labels = [] + pixelOffset = 3 + + for marker in self._markers.values(): + xCoord, yCoord = marker['x'], marker['y'] + + if ((isXLog and xCoord is not None and + xCoord < FLOAT32_MINPOS) or + (isYLog and yCoord is not None and + yCoord < FLOAT32_MINPOS)): + # Do not render markers with negative coords on log axis + continue + + if xCoord is None or yCoord is None: + if not self.isDefaultBaseVectors(): # Non-orthogonal axes + vertices, label = self._nonOrthoAxesLineMarkerPrimitives( + marker, pixelOffset) + if label is not None: + labels.append(label) + + else: # Orthogonal axes + pixelPos = self.dataToPixel( + xCoord, yCoord, axis='left', check=False) + + if xCoord is None: # Horizontal line in data space + if marker['text'] is not None: + x = self._plotFrame.size[0] - \ + self._plotFrame.margins.right - pixelOffset + y = pixelPos[1] - pixelOffset + label = Text2D(marker['text'], x, y, + color=marker['color'], + bgColor=(1., 1., 1., 0.5), + align=RIGHT, valign=BOTTOM) + labels.append(label) + + xMin, xMax = self._plotFrame.dataRanges.x + vertices = numpy.array(((xMin, yCoord), + (xMax, yCoord)), + dtype=numpy.float32) + + else: # yCoord is None: vertical line in data space + if marker['text'] is not None: + x = pixelPos[0] + pixelOffset + y = self._plotFrame.margins.top + pixelOffset + label = Text2D(marker['text'], x, y, + color=marker['color'], + bgColor=(1., 1., 1., 0.5), + align=LEFT, valign=TOP) + labels.append(label) + + yMin, yMax = self._plotFrame.dataRanges.y + vertices = numpy.array(((xCoord, yMin), + (xCoord, yMax)), + dtype=numpy.float32) + + self._progBase.use() + + gl.glUniform4f(self._progBase.uniforms['color'], + *marker['color']) + + gl.glEnableVertexAttribArray(posAttrib) + gl.glVertexAttribPointer(posAttrib, + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + 0, vertices) + gl.glLineWidth(1) + gl.glDrawArrays(gl.GL_LINES, 0, len(vertices)) + + else: + pixelPos = self.dataToPixel( + xCoord, yCoord, axis='left', check=True) + if pixelPos is None: + # Do not render markers outside visible plot area + continue + + if marker['text'] is not None: + x = pixelPos[0] + pixelOffset + y = pixelPos[1] + pixelOffset + label = Text2D(marker['text'], x, y, + color=marker['color'], + bgColor=(1., 1., 1., 0.5), + align=LEFT, valign=TOP) + labels.append(label) + + # For now simple implementation: using a curve for each marker + # Should pack all markers to a single set of points + markerCurve = GLPlotCurve2D( + numpy.array((xCoord,), dtype=numpy.float32), + numpy.array((yCoord,), dtype=numpy.float32), + marker=marker['symbol'], + markerColor=marker['color'], + markerSize=11) + markerCurve.render(self._plotFrame.transformedDataProjMat, + isXLog, isYLog) + markerCurve.discard() + + gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1]) + + # Render marker labels + for label in labels: + label.render(self.matScreenProj) + + gl.glDisable(gl.GL_SCISSOR_TEST) + + def _renderOverlayGL(self): + # Render selection area and crosshair cursor + if self._selectionAreas or self._crosshairCursor is not None: + plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:] + + # Scissor to plot area + gl.glScissor(self._plotFrame.margins.left, + self._plotFrame.margins.bottom, + plotWidth, plotHeight) + gl.glEnable(gl.GL_SCISSOR_TEST) + + self._progBase.use() + gl.glUniform2i(self._progBase.uniforms['isLog'], + self._plotFrame.xAxis.isLog, + self._plotFrame.yAxis.isLog) + gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.) + posAttrib = self._progBase.attributes['position'] + matrixUnif = self._progBase.uniforms['matrix'] + colorUnif = self._progBase.uniforms['color'] + hatchStepUnif = self._progBase.uniforms['hatchStep'] + + # Render selection area in plot area + if self._selectionAreas: + gl.glViewport(self._plotFrame.margins.left, + self._plotFrame.margins.bottom, + plotWidth, plotHeight) + + gl.glUniformMatrix4fv(matrixUnif, 1, gl.GL_TRUE, + self._plotFrame.transformedDataProjMat) + + for shape in self._selectionAreas.values(): + if shape.isVideoInverted: + gl.glBlendFunc(gl.GL_ONE_MINUS_DST_COLOR, gl.GL_ZERO) + + shape.render(posAttrib, colorUnif, hatchStepUnif) + + if shape.isVideoInverted: + self._setBlendFuncGL() + + # Render crosshair cursor is screen frame but with scissor + if (self._crosshairCursor is not None and + self._mousePosInPixels is not None): + gl.glViewport( + 0, 0, self._plotFrame.size[0], self._plotFrame.size[1]) + + gl.glUniformMatrix4fv(matrixUnif, 1, gl.GL_TRUE, + self.matScreenProj) + + color, lineWidth = self._crosshairCursor + gl.glUniform4f(colorUnif, *color) + gl.glUniform1i(hatchStepUnif, 0) + + xPixel, yPixel = self._mousePosInPixels + xPixel, yPixel = xPixel + 0.5, yPixel + 0.5 + vertices = numpy.array(((0., yPixel), + (self._plotFrame.size[0], yPixel), + (xPixel, 0.), + (xPixel, self._plotFrame.size[1])), + dtype=numpy.float32) + + gl.glEnableVertexAttribArray(posAttrib) + gl.glVertexAttribPointer(posAttrib, + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + 0, vertices) + gl.glLineWidth(lineWidth) + gl.glDrawArrays(gl.GL_LINES, 0, len(vertices)) + + gl.glDisable(gl.GL_SCISSOR_TEST) + + def _renderPlotAreaGL(self): + plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:] + + self._plotFrame.renderGrid() + + gl.glScissor(self._plotFrame.margins.left, + self._plotFrame.margins.bottom, + plotWidth, plotHeight) + gl.glEnable(gl.GL_SCISSOR_TEST) + + # Matrix + trBounds = self._plotFrame.transformedDataRanges + if trBounds.x[0] == trBounds.x[1] or \ + trBounds.y[0] == trBounds.y[1]: + return + + isXLog = self._plotFrame.xAxis.isLog + isYLog = self._plotFrame.yAxis.isLog + + gl.glViewport(self._plotFrame.margins.left, + self._plotFrame.margins.bottom, + plotWidth, plotHeight) + + # Render images and curves + # sorted is stable: original order is preserved when key is the same + for item in self._plotContent.zOrderedPrimitives(): + if item.info.get('yAxis') == 'right': + item.render(self._plotFrame.transformedDataY2ProjMat, + isXLog, isYLog) + else: + item.render(self._plotFrame.transformedDataProjMat, + isXLog, isYLog) + + # Render Items + self._progBase.use() + gl.glUniformMatrix4fv(self._progBase.uniforms['matrix'], 1, gl.GL_TRUE, + self._plotFrame.transformedDataProjMat) + gl.glUniform2i(self._progBase.uniforms['isLog'], + self._plotFrame.xAxis.isLog, + self._plotFrame.yAxis.isLog) + gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.) + + for item in self._items.values(): + shape2D = item.get('_shape2D') + if shape2D is None: + shape2D = Shape2D(tuple(zip(item['x'], item['y'])), + fill=item['fill'], + fillColor=item['color'], + stroke=True, + strokeColor=item['color']) + item['_shape2D'] = shape2D + + if ((isXLog and shape2D.xMin < FLOAT32_MINPOS) or + (isYLog and shape2D.yMin < FLOAT32_MINPOS)): + # Ignore items <= 0. on log axes + continue + + posAttrib = self._progBase.attributes['position'] + colorUnif = self._progBase.uniforms['color'] + hatchStepUnif = self._progBase.uniforms['hatchStep'] + shape2D.render(posAttrib, colorUnif, hatchStepUnif) + + gl.glDisable(gl.GL_SCISSOR_TEST) + + def resizeGL(self, width, height): + if width == 0 or height == 0: # Do not resize + return + self._plotFrame.size = width, height + + self.matScreenProj = mat4Ortho(0, self._plotFrame.size[0], + self._plotFrame.size[1], 0, + 1, -1) + + (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = \ + self._plotFrame.dataRanges + self.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max) + + # Add methods + + def addCurve(self, x, y, legend, + color, symbol, linewidth, linestyle, + yaxis, + xerror, yerror, z, selectable, + fill, alpha, symbolsize): + for parameter in (x, y, legend, color, symbol, linewidth, linestyle, + yaxis, z, selectable, fill, symbolsize): + assert parameter is not None + assert yaxis in ('left', 'right') + + x = numpy.array(x, dtype=numpy.float32, copy=False, order='C') + y = numpy.array(y, dtype=numpy.float32, copy=False, order='C') + if xerror is not None: + xerror = numpy.array( + xerror, dtype=numpy.float32, copy=False, order='C') + if yerror is not None: + yerror = numpy.array( + yerror, dtype=numpy.float32, copy=False, order='C') + + # TODO check and improve this + if (len(color) == 4 and + type(color[3]) in [type(1), numpy.uint8, numpy.int8]): + color = numpy.array(color, dtype=numpy.float32) / 255. + + if isinstance(color, numpy.ndarray) and color.ndim == 2: + colorArray = color + color = None + else: + colorArray = None + color = Colors.rgba(color) + + if alpha < 1.: # Apply image transparency + if colorArray is not None and colorArray.shape[1] == 4: + # multiply alpha channel + colorArray[:, 3] = colorArray[:, 3] * alpha + if color is not None: + color = color[0], color[1], color[2], color[3] * alpha + + behaviors = set() + if selectable: + behaviors.add('selectable') + + curve = GLPlotCurve2D(x, y, colorArray, + xError=xerror, + yError=yerror, + lineStyle=linestyle, + lineColor=color, + lineWidth=linewidth, + marker=symbol, + markerColor=color, + markerSize=symbolsize, + fillColor=color if fill else None) + curve.info = { + 'legend': legend, + 'zOrder': z, + 'behaviors': behaviors, + 'yAxis': 'left' if yaxis is None else yaxis, + } + + if yaxis == "right": + self._plotFrame.isY2Axis = True + + self._plotContent.add(curve) + + return legend, 'curve' + + def addImage(self, data, legend, + origin, scale, z, + selectable, draggable, + colormap, alpha): + for parameter in (data, legend, origin, scale, z, + selectable, draggable): + assert parameter is not None + + behaviors = set() + if selectable: + behaviors.add('selectable') + if draggable: + behaviors.add('draggable') + + if data.ndim == 2: + # Ensure array is contiguous and eventually convert its type + if data.dtype in (numpy.float32, numpy.uint8, numpy.uint16): + data = numpy.array(data, copy=False, order='C') + else: + _logger.info( + 'addImage: Convert %s data to float32', str(data.dtype)) + data = numpy.array(data, dtype=numpy.float32, order='C') + + colormapIsLog = colormap['normalization'].startswith('log') + + if colormap['autoscale']: + cmapRange = None + else: + cmapRange = colormap['vmin'], colormap['vmax'] + assert cmapRange[0] <= cmapRange[1] + + # Retrieve colormap LUT from name and color array + colormapLut = Colors.applyColormapToData( + numpy.arange(256, dtype=numpy.uint8), + name=colormap['name'], + normalization='linear', + autoscale=False, + vmin=0, + vmax=255, + colors=colormap.get('colors')) + + image = GLPlotColormap(data, + origin, + scale, + colormapLut, + colormapIsLog, + cmapRange, + alpha) + image.info = { + 'legend': legend, + 'zOrder': z, + 'behaviors': behaviors + } + self._plotContent.add(image) + + elif len(data.shape) == 3: + # For RGB, RGBA data + assert data.shape[2] in (3, 4) + assert data.dtype in (numpy.float32, numpy.uint8) + + image = GLPlotRGBAImage(data, origin, scale, alpha) + + image.info = { + 'legend': legend, + 'zOrder': z, + 'behaviors': behaviors + } + + if self._plotFrame.xAxis.isLog and image.xMin <= 0.: + raise RuntimeError( + 'Cannot add image with X <= 0 with X axis log scale') + if self._plotFrame.yAxis.isLog and image.yMin <= 0.: + raise RuntimeError( + 'Cannot add image with Y <= 0 with Y axis log scale') + + self._plotContent.add(image) + + else: + raise RuntimeError("Unsupported data shape {0}".format(data.shape)) + + return legend, 'image' + + def addItem(self, x, y, legend, shape, color, fill, overlay, z): + # TODO handle overlay + if shape not in ('polygon', 'rectangle', 'line', 'vline', 'hline'): + raise NotImplementedError("Unsupported shape {0}".format(shape)) + + x = numpy.array(x, copy=False) + y = numpy.array(y, copy=False) + + if shape == 'rectangle': + xMin, xMax = x + x = numpy.array((xMin, xMin, xMax, xMax)) + yMin, yMax = y + y = numpy.array((yMin, yMax, yMax, yMin)) + + # TODO is this needed? + if self._plotFrame.xAxis.isLog and x.min() <= 0.: + raise RuntimeError( + 'Cannot add item with X <= 0 with X axis log scale') + if self._plotFrame.yAxis.isLog and y.min() <= 0.: + raise RuntimeError( + 'Cannot add item with Y <= 0 with Y axis log scale') + + self._items[legend] = { + 'shape': shape, + 'color': Colors.rgba(color), + 'fill': 'hatch' if fill else None, + 'x': x, + 'y': y + } + + return legend, 'item' + + def addMarker(self, x, y, legend, text, color, + selectable, draggable, + symbol, constraint, overlay): + # TODO handle overlay + + if symbol is None: + symbol = '+' + + behaviors = set() + if selectable: + behaviors.add('selectable') + if draggable: + behaviors.add('draggable') + + # Apply constraint to provided position + isConstraint = (draggable and constraint is not None and + x is not None and y is not None) + if isConstraint: + x, y = constraint(x, y) + + if x is not None and self._plotFrame.xAxis.isLog and x <= 0.: + raise RuntimeError( + 'Cannot add marker with X <= 0 with X axis log scale') + if y is not None and self._plotFrame.yAxis.isLog and y <= 0.: + raise RuntimeError( + 'Cannot add marker with Y <= 0 with Y axis log scale') + + self._markers[legend] = { + 'x': x, + 'y': y, + 'legend': legend, + 'text': text, + 'color': Colors.rgba(color), + 'behaviors': behaviors, + 'constraint': constraint if isConstraint else None, + 'symbol': symbol, + } + + return legend, 'marker' + + # Remove methods + + def remove(self, item): + legend, kind = item + + if kind == 'curve': + curve = self._plotContent.pop('curve', legend) + if curve is not None: + # Check if some curves remains on the right Y axis + y2AxisItems = (item for item in self._plotContent.primitives() + if item.info.get('yAxis', 'left') == 'right') + self._plotFrame.isY2Axis = next(y2AxisItems, None) is not None + + self._glGarbageCollector.append(curve) + + elif kind == 'image': + image = self._plotContent.pop('image', legend) + if image is not None: + self._glGarbageCollector.append(image) + + elif kind == 'marker': + self._markers.pop(legend, False) + + elif kind == 'item': + self._items.pop(legend, False) + + else: + _logger.error('Unsupported kind: %s', str(kind)) + + # Interaction methods + + _QT_CURSORS = { + None: qt.Qt.ArrowCursor, + BackendBase.CURSOR_DEFAULT: qt.Qt.ArrowCursor, + BackendBase.CURSOR_POINTING: qt.Qt.PointingHandCursor, + BackendBase.CURSOR_SIZE_HOR: qt.Qt.SizeHorCursor, + BackendBase.CURSOR_SIZE_VER: qt.Qt.SizeVerCursor, + BackendBase.CURSOR_SIZE_ALL: qt.Qt.SizeAllCursor, + } + + def setGraphCursorShape(self, cursor): + cursor = self._QT_CURSORS[cursor] + + super(BackendOpenGL, self).setCursor(qt.QCursor(cursor)) + + def setGraphCursor(self, flag, color, linewidth, linestyle): + if linestyle is not '-': + _logger.warning( + "BackendOpenGL.setGraphCursor linestyle parameter ignored") + + if flag: + color = Colors.rgba(color) + crosshairCursor = color, linewidth + else: + crosshairCursor = None + + if crosshairCursor != self._crosshairCursor: + self._crosshairCursor = crosshairCursor + + _PICK_OFFSET = 3 # Offset in pixel used for picking + + def _mouseInPlotArea(self, x, y): + xPlot = numpy.clip( + x, self._plotFrame.margins.left, + self._plotFrame.size[0] - self._plotFrame.margins.right - 1) + yPlot = numpy.clip( + y, self._plotFrame.margins.top, + self._plotFrame.size[1] - self._plotFrame.margins.bottom - 1) + return xPlot, yPlot + + def pickItems(self, x, y): + picked = [] + + dataPos = self.pixelToData(x, y, axis='left', check=True) + if dataPos is not None: + # Pick markers + for marker in reversed(list(self._markers.values())): + pixelPos = self.dataToPixel( + marker['x'], marker['y'], axis='left', check=False) + if pixelPos is None: # negative coord on a log axis + continue + + if marker['x'] is None: # Horizontal line + pt1 = self.pixelToData( + x, y - self._PICK_OFFSET, axis='left', check=False) + pt2 = self.pixelToData( + x, y + self._PICK_OFFSET, axis='left', check=False) + isPicked = (min(pt1[1], pt2[1]) <= marker['y'] <= + max(pt1[1], pt2[1])) + + elif marker['y'] is None: # Vertical line + pt1 = self.pixelToData( + x - self._PICK_OFFSET, y, axis='left', check=False) + pt2 = self.pixelToData( + x + self._PICK_OFFSET, y, axis='left', check=False) + isPicked = (min(pt1[0], pt2[0]) <= marker['x'] <= + max(pt1[0], pt2[0])) + + else: + isPicked = ( + numpy.fabs(x - pixelPos[0]) <= self._PICK_OFFSET and + numpy.fabs(y - pixelPos[1]) <= self._PICK_OFFSET) + + if isPicked: + picked.append(dict(kind='marker', + legend=marker['legend'])) + + # Pick image and curves + for item in self._plotContent.zOrderedPrimitives(reverse=True): + if isinstance(item, (GLPlotColormap, GLPlotRGBAImage)): + pickedPos = item.pick(*dataPos) + if pickedPos is not None: + picked.append(dict(kind='image', + legend=item.info['legend'])) + + elif isinstance(item, GLPlotCurve2D): + offset = self._PICK_OFFSET + if item.marker is not None: + offset = max(item.markerSize / 2., offset) + if item.lineStyle is not None: + offset = max(item.lineWidth / 2., offset) + + yAxis = item.info['yAxis'] + + inAreaPos = self._mouseInPlotArea(x - offset, y - offset) + dataPos = self.pixelToData(inAreaPos[0], inAreaPos[1], + axis=yAxis, check=True) + if dataPos is None: + continue + xPick0, yPick0 = dataPos + + inAreaPos = self._mouseInPlotArea(x + offset, y + offset) + dataPos = self.pixelToData(inAreaPos[0], inAreaPos[1], + axis=yAxis, check=True) + if dataPos is None: + continue + xPick1, yPick1 = dataPos + + if xPick0 < xPick1: + xPickMin, xPickMax = xPick0, xPick1 + else: + xPickMin, xPickMax = xPick1, xPick0 + + if yPick0 < yPick1: + yPickMin, yPickMax = yPick0, yPick1 + else: + yPickMin, yPickMax = yPick1, yPick0 + + pickedIndices = item.pick(xPickMin, yPickMin, + xPickMax, yPickMax) + if pickedIndices: + picked.append(dict(kind='curve', + legend=item.info['legend'], + xdata=item.xData[pickedIndices], + ydata=item.yData[pickedIndices])) + + return picked + + # Update curve + + def setCurveColor(self, curve, color): + pass # TODO + + # Misc. + + def getWidgetHandle(self): + return self + + def postRedisplay(self): + self._sigPostRedisplay.emit() + + def replot(self): + self.update() # async redraw + # self.repaint() # immediate redraw + + def saveGraph(self, fileName, fileFormat, dpi): + if dpi is not None: + _logger.warning("saveGraph ignores dpi parameter") + + if fileFormat not in ['png', 'ppm', 'svg', 'tiff']: + raise NotImplementedError('Unsupported format: %s' % fileFormat) + + self.makeCurrent() + + data = numpy.empty( + (self._plotFrame.size[1], self._plotFrame.size[0], 3), + dtype=numpy.uint8, order='C') + + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0) + gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1) + gl.glReadPixels(0, 0, self._plotFrame.size[0], self._plotFrame.size[1], + gl.GL_RGB, gl.GL_UNSIGNED_BYTE, data) + + # glReadPixels gives bottom to top, + # while images are stored as top to bottom + data = numpy.flipud(data) + + # fileName is either a file-like object or a str + saveImageToFile(data, fileName, fileFormat) + + # Graph labels + + def setGraphTitle(self, title): + self._plotFrame.title = title + + def setGraphXLabel(self, label): + self._plotFrame.xAxis.title = label + + def setGraphYLabel(self, label, axis): + if axis == 'left': + self._plotFrame.yAxis.title = label + else: # right axis + if label: + _logger.warning('Right axis label not implemented') + + # Non orthogonal axes + + def setBaseVectors(self, x=(1., 0.), y=(0., 1.)): + """Set base vectors. + + Useful for non-orthogonal axes. + If an axis is in log scale, skew is applied to log transformed values. + + Base vector does not work well with log axes, to investi + """ + if x != (1., 0.) and y != (0., 1.): + if self._plotFrame.xAxis.isLog: + _logger.warning("setBaseVectors disables X axis logarithmic.") + self.setXAxisLogarithmic(False) + if self._plotFrame.yAxis.isLog: + _logger.warning("setBaseVectors disables Y axis logarithmic.") + self.setYAxisLogarithmic(False) + + if self.isKeepDataAspectRatio(): + _logger.warning("setBaseVectors disables keepDataAspectRatio.") + self.keepDataAspectRatio(False) + + self._plotFrame.baseVectors = x, y + + def getBaseVectors(self): + return self._plotFrame.baseVectors + + def isDefaultBaseVectors(self): + return self._plotFrame.baseVectors == \ + self._plotFrame.DEFAULT_BASE_VECTORS + + # Graph limits + + def _setDataRanges(self, xlim=None, ylim=None, y2lim=None): + """Set the visible range of data in the plot frame. + + This clips the ranges to possible values (takes care of float32 + range + positive range for log). + This also takes care of non-orthogonal axes. + + This should be moved to PlotFrame. + """ + # Update axes range with a clipped range if too wide + self._plotFrame.setDataRanges(xlim, ylim, y2lim) + + if not self.isDefaultBaseVectors(): + # Update axes range with axes bounds in data coords + plotLeft, plotTop, plotWidth, plotHeight = \ + self.getPlotBoundsInPixels() + + self._plotFrame.xAxis.dataRange = sorted([ + self.pixelToData(x, y, axis='left', check=False)[0] + for (x, y) in ((plotLeft, plotTop + plotHeight), + (plotLeft + plotWidth, plotTop + plotHeight))]) + + self._plotFrame.yAxis.dataRange = sorted([ + self.pixelToData(x, y, axis='left', check=False)[1] + for (x, y) in ((plotLeft, plotTop + plotHeight), + (plotLeft, plotTop))]) + + self._plotFrame.y2Axis.dataRange = sorted([ + self.pixelToData(x, y, axis='right', check=False)[1] + for (x, y) in ((plotLeft + plotWidth, plotTop + plotHeight), + (plotLeft + plotWidth, plotTop))]) + + def _ensureAspectRatio(self, keepDim=None): + """Update plot bounds in order to keep aspect ratio. + + Warning: keepDim on right Y axis is not implemented ! + + :param str keepDim: The dimension to maintain: 'x', 'y' or None. + If None (the default), the dimension with the largest range. + """ + plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:] + if plotWidth <= 2 or plotHeight <= 2: + return + + if keepDim is None: + dataBounds = self._plotContent.getBounds( + self._plotFrame.xAxis.isLog, self._plotFrame.yAxis.isLog) + if dataBounds.yAxis.range_ != 0.: + dataRatio = dataBounds.xAxis.range_ + dataRatio /= float(dataBounds.yAxis.range_) + + plotRatio = plotWidth / float(plotHeight) # Test != 0 before + + keepDim = 'x' if dataRatio > plotRatio else 'y' + else: # Limit case + keepDim = 'x' + + (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = \ + self._plotFrame.dataRanges + if keepDim == 'y': + dataW = (yMax - yMin) * plotWidth / float(plotHeight) + xCenter = 0.5 * (xMin + xMax) + xMin = xCenter - 0.5 * dataW + xMax = xCenter + 0.5 * dataW + elif keepDim == 'x': + dataH = (xMax - xMin) * plotHeight / float(plotWidth) + yCenter = 0.5 * (yMin + yMax) + yMin = yCenter - 0.5 * dataH + yMax = yCenter + 0.5 * dataH + y2Center = 0.5 * (y2Min + y2Max) + y2Min = y2Center - 0.5 * dataH + y2Max = y2Center + 0.5 * dataH + else: + raise RuntimeError('Unsupported dimension to keep: %s' % keepDim) + + # Update plot frame bounds + self._setDataRanges(xlim=(xMin, xMax), + ylim=(yMin, yMax), + y2lim=(y2Min, y2Max)) + + def _setPlotBounds(self, xRange=None, yRange=None, y2Range=None, + keepDim=None): + # Update axes range with a clipped range if too wide + self._setDataRanges(xlim=xRange, + ylim=yRange, + y2lim=y2Range) + + # Keep data aspect ratio + if self.isKeepDataAspectRatio(): + self._ensureAspectRatio(keepDim) + + def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None): + assert xmin < xmax + assert ymin < ymax + + if y2min is None or y2max is None: + y2Range = None + else: + assert y2min < y2max + y2Range = y2min, y2max + self._setPlotBounds((xmin, xmax), (ymin, ymax), y2Range) + + def getGraphXLimits(self): + return self._plotFrame.dataRanges.x + + def setGraphXLimits(self, xmin, xmax): + assert xmin < xmax + self._setPlotBounds(xRange=(xmin, xmax), keepDim='x') + + def getGraphYLimits(self, axis): + assert axis in ("left", "right") + if axis == "left": + return self._plotFrame.dataRanges.y + else: + return self._plotFrame.dataRanges.y2 + + def setGraphYLimits(self, ymin, ymax, axis): + assert ymin < ymax + assert axis in ("left", "right") + + if axis == "left": + self._setPlotBounds(yRange=(ymin, ymax), keepDim='y') + else: + self._setPlotBounds(y2Range=(ymin, ymax), keepDim='y') + + # Graph axes + + def setXAxisLogarithmic(self, flag): + if flag != self._plotFrame.xAxis.isLog: + if flag and self._keepDataAspectRatio: + _logger.warning( + "KeepDataAspectRatio is ignored with log axes") + + if flag and not self.isDefaultBaseVectors(): + _logger.warning( + "setXAxisLogarithmic ignored because baseVectors are set") + return + + self._plotFrame.xAxis.isLog = flag + + def setYAxisLogarithmic(self, flag): + if (flag != self._plotFrame.yAxis.isLog or + flag != self._plotFrame.y2Axis.isLog): + if flag and self._keepDataAspectRatio: + _logger.warning( + "KeepDataAspectRatio is ignored with log axes") + + if flag and not self.isDefaultBaseVectors(): + _logger.warning( + "setYAxisLogarithmic ignored because baseVectors are set") + return + + self._plotFrame.yAxis.isLog = flag + self._plotFrame.y2Axis.isLog = flag + + def setYAxisInverted(self, flag): + if flag != self._plotFrame.isYAxisInverted: + self._plotFrame.isYAxisInverted = flag + + def isYAxisInverted(self): + return self._plotFrame.isYAxisInverted + + def isKeepDataAspectRatio(self): + if self._plotFrame.xAxis.isLog or self._plotFrame.yAxis.isLog: + return False + else: + return self._keepDataAspectRatio + + def setKeepDataAspectRatio(self, flag): + if flag and (self._plotFrame.xAxis.isLog or + self._plotFrame.yAxis.isLog): + _logger.warning("KeepDataAspectRatio is ignored with log axes") + if flag and not self.isDefaultBaseVectors(): + _logger.warning( + "keepDataAspectRatio ignored because baseVectors are set") + + self._keepDataAspectRatio = flag + + def setGraphGrid(self, which): + assert which in (None, 'major', 'both') + self._plotFrame.grid = which is not None # TODO True grid support + + # Data <-> Pixel coordinates conversion + + def dataToPixel(self, x, y, axis, check=False): + assert axis in ('left', 'right') + + if x is None or y is None: + dataBounds = self._plotContent.getBounds( + self._plotFrame.xAxis.isLog, self._plotFrame.yAxis.isLog) + + if x is None: + x = dataBounds.xAxis.center + + if y is None: + if axis == 'left': + y = dataBounds.yAxis.center + else: + y = dataBounds.y2Axis.center + + result = self._plotFrame.dataToPixel(x, y, axis) + + if check and result is not None: + xPixel, yPixel = result + width, height = self._plotFrame.size + if (xPixel < self._plotFrame.margins.left or + xPixel > (width - self._plotFrame.margins.right) or + yPixel < self._plotFrame.margins.top or + yPixel > height - self._plotFrame.margins.bottom): + return None # (x, y) is out of plot area + + return result + + def pixelToData(self, x, y, axis, check): + assert axis in ("left", "right") + + if x is None: + x = self._plotFrame.size[0] / 2. + if y is None: + y = self._plotFrame.size[1] / 2. + + if check and (x < self._plotFrame.margins.left or + x > (self._plotFrame.size[0] - + self._plotFrame.margins.right) or + y < self._plotFrame.margins.top or + y > (self._plotFrame.size[1] - + self._plotFrame.margins.bottom)): + return None # (x, y) is out of plot area + + return self._plotFrame.pixelToData(x, y, axis) + + def getPlotBoundsInPixels(self): + return self._plotFrame.plotOrigin + self._plotFrame.plotSize diff --git a/silx/gui/plot/backends/ModestImage.py b/silx/gui/plot/backends/ModestImage.py new file mode 100644 index 0000000..93fba5a --- /dev/null +++ b/silx/gui/plot/backends/ModestImage.py @@ -0,0 +1,174 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +"""Matplotlib computationally modest image class.""" + +__authors__ = ["V.A. Sole", "T. Vincent"] +__license__ = "MIT" +__date__ = "16/02/2016" + + +import numpy + +from matplotlib import cbook +from matplotlib.image import AxesImage + + +class ModestImage(AxesImage): + """Computationally modest image class. + +Customization of https://github.com/ChrisBeaumont/ModestImage to allow +extent support. + +ModestImage is an extension of the Matplotlib AxesImage class +better suited for the interactive display of larger images. Before +drawing, ModestImage resamples the data array based on the screen +resolution and view window. This has very little affect on the +appearance of the image, but can substantially cut down on +computation since calculations of unresolved or clipped pixels +are skipped. + +The interface of ModestImage is the same as AxesImage. However, it +does not currently support setting the 'extent' property. There +may also be weird coordinate warping operations for images that +I'm not aware of. Don't expect those to work either. +""" + def __init__(self, *args, **kwargs): + self._full_res = None + self._sx, self._sy = None, None + self._bounds = (None, None, None, None) + self._origExtent = None + super(ModestImage, self).__init__(*args, **kwargs) + if 'extent' in kwargs and kwargs['extent'] is not None: + self.set_extent(kwargs['extent']) + + def set_extent(self, extent): + super(ModestImage, self).set_extent(extent) + if self._origExtent is None: + self._origExtent = self.get_extent() + + def get_image_extent(self): + """Returns the extent of the whole image. + + get_extent returns the extent of the drawn area and not of the full + image. + + :return: Bounds of the image (x0, x1, y0, y1). + :rtype: Tuple of 4 floats. + """ + if self._origExtent is not None: + return self._origExtent + else: + return self.get_extent() + + def set_data(self, A): + """ + Set the image array + + ACCEPTS: numpy/PIL Image A + """ + + self._full_res = A + self._A = A + + if (self._A.dtype != numpy.uint8 and + not numpy.can_cast(self._A.dtype, numpy.float)): + raise TypeError("Image data can not convert to float") + + if (self._A.ndim not in (2, 3) or + (self._A.ndim == 3 and self._A.shape[-1] not in (3, 4))): + raise TypeError("Invalid dimensions for image data") + + self._imcache = None + self._rgbacache = None + self._oldxslice = None + self._oldyslice = None + self._sx, self._sy = None, None + + def get_array(self): + """Override to return the full-resolution array""" + return self._full_res + + def _scale_to_res(self): + """ Change self._A and _extent to render an image whose +resolution is matched to the eventual rendering.""" + # extent has to be set BEFORE set_data + if self._origExtent is None: + if self.origin == "upper": + self._origExtent = (0, self._full_res.shape[1], + self._full_res.shape[0], 0) + else: + self._origExtent = (0, self._full_res.shape[1], + 0, self._full_res.shape[0]) + + if self.origin == "upper": + origXMin, origXMax, origYMax, origYMin = self._origExtent[0:4] + else: + origXMin, origXMax, origYMin, origYMax = self._origExtent[0:4] + ax = self.axes + ext = ax.transAxes.transform([1, 1]) - ax.transAxes.transform([0, 0]) + xlim, ylim = ax.get_xlim(), ax.get_ylim() + xlim = max(xlim[0], origXMin), min(xlim[1], origXMax) + if ylim[0] > ylim[1]: + ylim = max(ylim[1], origYMin), min(ylim[0], origYMax) + else: + ylim = max(ylim[0], origYMin), min(ylim[1], origYMax) + # print("THOSE LIMITS ARE TO BE COMPARED WITH THE EXTENT") + # print("IN ORDER TO KNOW WHAT IT IS LIMITING THE DISPLAY") + # print("IF THE AXES OR THE EXTENT") + dx, dy = xlim[1] - xlim[0], ylim[1] - ylim[0] + + y0 = max(0, ylim[0] - 5) + y1 = min(self._full_res.shape[0], ylim[1] + 5) + x0 = max(0, xlim[0] - 5) + x1 = min(self._full_res.shape[1], xlim[1] + 5) + y0, y1, x0, x1 = [int(a) for a in [y0, y1, x0, x1]] + + sy = int(max(1, min((y1 - y0) / 5., numpy.ceil(dy / ext[1])))) + sx = int(max(1, min((x1 - x0) / 5., numpy.ceil(dx / ext[0])))) + + # have we already calculated what we need? + if (self._sx is not None) and (self._sy is not None): + if (sx >= self._sx and sy >= self._sy and + x0 >= self._bounds[0] and x1 <= self._bounds[1] and + y0 >= self._bounds[2] and y1 <= self._bounds[3]): + return + + self._A = self._full_res[y0:y1:sy, x0:x1:sx] + self._A = cbook.safe_masked_invalid(self._A) + x1 = x0 + self._A.shape[1] * sx + y1 = y0 + self._A.shape[0] * sy + + if self.origin == "upper": + self.set_extent([x0, x1, y1, y0]) + else: + self.set_extent([x0, x1, y0, y1]) + self._sx = sx + self._sy = sy + self._bounds = (x0, x1, y0, y1) + self.changed() + + def draw(self, renderer, *args, **kwargs): + self._scale_to_res() + super(ModestImage, self).draw(renderer, *args, **kwargs) diff --git a/silx/gui/plot/backends/__init__.py b/silx/gui/plot/backends/__init__.py new file mode 100644 index 0000000..966d9df --- /dev/null +++ b/silx/gui/plot/backends/__init__.py @@ -0,0 +1,29 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This package implements the backend of the Plot.""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "21/03/2017" diff --git a/silx/gui/plot/backends/_matplotlib.py b/silx/gui/plot/backends/_matplotlib.py new file mode 100644 index 0000000..26732a0 --- /dev/null +++ b/silx/gui/plot/backends/_matplotlib.py @@ -0,0 +1,64 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module inits matplotlib and setups the backend to use. + +It MUST be imported prior to any other import of matplotlib. + +It provides the matplotlib :class:`FigureCanvasQTAgg` class corresponding +to the used backend. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "26/10/2016" + + +import sys +import logging + + +_logger = logging.getLogger(__name__) + +if 'matplotlib' in sys.modules: + _logger.warning( + 'matplotlib already loaded, setting its backend may not work') + + +from ... import qt + +import matplotlib + +if qt.BINDING == 'PySide': + matplotlib.rcParams['backend'] = 'Qt4Agg' + matplotlib.rcParams['backend.qt4'] = 'PySide' + from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg # noqa + +elif qt.BINDING == 'PyQt4': + matplotlib.rcParams['backend'] = 'Qt4Agg' + from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg # noqa + +elif qt.BINDING == 'PyQt5': + matplotlib.rcParams['backend'] = 'Qt5Agg' + from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg # noqa diff --git a/silx/gui/plot/backends/glutils/GLPlotCurve.py b/silx/gui/plot/backends/glutils/GLPlotCurve.py new file mode 100644 index 0000000..4f08054 --- /dev/null +++ b/silx/gui/plot/backends/glutils/GLPlotCurve.py @@ -0,0 +1,1317 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +""" +This module provides classes to render 2D lines and scatter plots +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "03/04/2017" + + +import math +import logging + +import numpy + +from silx.math.combo import min_max + +from ...._glutils import gl +from ...._glutils import numpyToGLType, Program, vertexBuffer +from ..._utils import FLOAT32_MINPOS +from .GLSupport import buildFillMaskIndices + + +_logger = logging.getLogger(__name__) + + +_MPL_NONES = None, 'None', '', ' ' + + +# fill ######################################################################## + +class _Fill2D(object): + _LINEAR, _LOG10_X, _LOG10_Y, _LOG10_X_Y = 0, 1, 2, 3 + + _SHADERS = { + 'vertexTransforms': { + _LINEAR: """ + vec4 transformXY(float x, float y) { + return vec4(x, y, 0.0, 1.0); + } + """, + _LOG10_X: """ + const float oneOverLog10 = 0.43429448190325176; + + vec4 transformXY(float x, float y) { + return vec4(oneOverLog10 * log(x), y, 0.0, 1.0); + } + """, + _LOG10_Y: """ + const float oneOverLog10 = 0.43429448190325176; + + vec4 transformXY(float x, float y) { + return vec4(x, oneOverLog10 * log(y), 0.0, 1.0); + } + """, + _LOG10_X_Y: """ + const float oneOverLog10 = 0.43429448190325176; + + vec4 transformXY(float x, float y) { + return vec4(oneOverLog10 * log(x), + oneOverLog10 * log(y), + 0.0, 1.0); + } + """ + }, + 'vertex': """ + #version 120 + + uniform mat4 matrix; + attribute float xPos; + attribute float yPos; + + %s + + void main(void) { + gl_Position = matrix * transformXY(xPos, yPos); + } + """, + 'fragment': """ + #version 120 + + uniform vec4 color; + + void main(void) { + gl_FragColor = color; + } + """ + } + + _programs = { + _LINEAR: Program( + _SHADERS['vertex'] % _SHADERS['vertexTransforms'][_LINEAR], + _SHADERS['fragment'], attrib0='xPos'), + _LOG10_X: Program( + _SHADERS['vertex'] % _SHADERS['vertexTransforms'][_LOG10_X], + _SHADERS['fragment'], attrib0='xPos'), + _LOG10_Y: Program( + _SHADERS['vertex'] % _SHADERS['vertexTransforms'][_LOG10_Y], + _SHADERS['fragment'], attrib0='xPos'), + _LOG10_X_Y: Program( + _SHADERS['vertex'] % _SHADERS['vertexTransforms'][_LOG10_X_Y], + _SHADERS['fragment'], attrib0='xPos'), + } + + def __init__(self, xFillVboData=None, yFillVboData=None, + xMin=None, yMin=None, xMax=None, yMax=None, + color=(0., 0., 0., 1.)): + self.xFillVboData = xFillVboData + self.yFillVboData = yFillVboData + self.xMin, self.yMin = xMin, yMin + self.xMax, self.yMax = xMax, yMax + self.color = color + + self._bboxVertices = None + self._indices = None + self._indicesType = None + + def prepare(self): + if self._indices is None: + self._indices = buildFillMaskIndices(self.xFillVboData.size) + self._indicesType = numpyToGLType(self._indices.dtype) + + if self._bboxVertices is None: + yMin, yMax = min(self.yMin, 1e-32), max(self.yMax, 1e-32) + self._bboxVertices = numpy.array(((self.xMin, self.xMin, + self.xMax, self.xMax), + (yMin, yMax, yMin, yMax)), + dtype=numpy.float32) + + def render(self, matrix, isXLog, isYLog): + self.prepare() + + if isXLog: + transform = self._LOG10_X_Y if isYLog else self._LOG10_X + else: + transform = self._LOG10_Y if isYLog else self._LINEAR + + prog = self._programs[transform] + prog.use() + + gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matrix) + + gl.glUniform4f(prog.uniforms['color'], *self.color) + + xPosAttrib = prog.attributes['xPos'] + yPosAttrib = prog.attributes['yPos'] + + gl.glEnableVertexAttribArray(xPosAttrib) + self.xFillVboData.setVertexAttrib(xPosAttrib) + + gl.glEnableVertexAttribArray(yPosAttrib) + self.yFillVboData.setVertexAttrib(yPosAttrib) + + # Prepare fill mask + gl.glEnable(gl.GL_STENCIL_TEST) + gl.glStencilMask(1) + gl.glStencilFunc(gl.GL_ALWAYS, 1, 1) + gl.glStencilOp(gl.GL_INVERT, gl.GL_INVERT, gl.GL_INVERT) + gl.glColorMask(gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE) + gl.glDepthMask(gl.GL_FALSE) + + gl.glDrawElements(gl.GL_TRIANGLE_STRIP, self._indices.size, + self._indicesType, self._indices) + + gl.glStencilFunc(gl.GL_EQUAL, 1, 1) + # Reset stencil while drawing + gl.glStencilOp(gl.GL_ZERO, gl.GL_ZERO, gl.GL_ZERO) + gl.glColorMask(gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE) + gl.glDepthMask(gl.GL_TRUE) + + gl.glVertexAttribPointer(xPosAttrib, 1, gl.GL_FLOAT, gl.GL_FALSE, 0, + self._bboxVertices[0]) + gl.glVertexAttribPointer(yPosAttrib, 1, gl.GL_FLOAT, gl.GL_FALSE, 0, + self._bboxVertices[1]) + gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, self._bboxVertices[0].size) + + gl.glDisable(gl.GL_STENCIL_TEST) + + +# line ######################################################################## + +SOLID, DASHED, DASHDOT, DOTTED = '-', '--', '-.', ':' + + +class _Lines2D(object): + STYLES = SOLID, DASHED, DASHDOT, DOTTED + """Supported line styles""" + + _LINEAR, _LOG10_X, _LOG10_Y, _LOG10_X_Y = 0, 1, 2, 3 + + _SHADERS = { + 'vertexTransforms': { + _LINEAR: """ + vec4 transformXY(float x, float y) { + return vec4(x, y, 0.0, 1.0); + } + """, + _LOG10_X: """ + const float oneOverLog10 = 0.43429448190325176; + + vec4 transformXY(float x, float y) { + return vec4(oneOverLog10 * log(x), y, 0.0, 1.0); + } + """, + _LOG10_Y: """ + const float oneOverLog10 = 0.43429448190325176; + + vec4 transformXY(float x, float y) { + return vec4(x, oneOverLog10 * log(y), 0.0, 1.0); + } + """, + _LOG10_X_Y: """ + const float oneOverLog10 = 0.43429448190325176; + + vec4 transformXY(float x, float y) { + return vec4(oneOverLog10 * log(x), + oneOverLog10 * log(y), + 0.0, 1.0); + } + """ + }, + 'solid': { + 'vertex': """ + #version 120 + + uniform mat4 matrix; + attribute float xPos; + attribute float yPos; + attribute vec4 color; + + varying vec4 vColor; + + %s + + void main(void) { + gl_Position = matrix * transformXY(xPos, yPos); + vColor = color; + } + """, + 'fragment': """ + #version 120 + + varying vec4 vColor; + + void main(void) { + gl_FragColor = vColor; + } + """ + }, + + + # Limitation: Dash using an estimate of distance in screen coord + # to avoid computing distance when viewport is resized + # results in inequal dashes when viewport aspect ratio is far from 1 + 'dashed': { + 'vertex': """ + #version 120 + + uniform mat4 matrix; + uniform vec2 halfViewportSize; + attribute float xPos; + attribute float yPos; + attribute vec4 color; + attribute float distance; + + varying float vDist; + varying vec4 vColor; + + %s + + void main(void) { + gl_Position = matrix * transformXY(xPos, yPos); + //Estimate distance in pixels + vec2 probe = vec2(matrix * vec4(1., 1., 0., 0.)) * + halfViewportSize; + float pixelPerDataEstimate = length(probe)/sqrt(2.); + vDist = distance * pixelPerDataEstimate; + vColor = color; + } + """, + 'fragment': """ + #version 120 + + /* Dashes: [0, x], [y, z] + Dash period: w */ + uniform vec4 dash; + + varying float vDist; + varying vec4 vColor; + + void main(void) { + float dist = mod(vDist, dash.w); + if ((dist > dash.x && dist < dash.y) || dist > dash.z) { + discard; + } + gl_FragColor = vColor; + } + """ + } + } + + _programs = {} + + def __init__(self, xVboData=None, yVboData=None, + colorVboData=None, distVboData=None, + style=SOLID, color=(0., 0., 0., 1.), + width=1, dashPeriod=20, drawMode=None): + self.xVboData = xVboData + self.yVboData = yVboData + self.distVboData = distVboData + self.colorVboData = colorVboData + self.useColorVboData = colorVboData is not None + + self.color = color + self._width = 1 + self.width = width + self._style = None + self.style = style + self.dashPeriod = dashPeriod + + self._drawMode = drawMode if drawMode is not None else gl.GL_LINE_STRIP + + @property + def style(self): + return self._style + + @style.setter + def style(self, style): + if style in _MPL_NONES: + self._style = None + self.render = self._renderNone + else: + assert style in self.STYLES + self._style = style + if style == SOLID: + self.render = self._renderSolid + else: # DASHED, DASHDOT, DOTTED + self.render = self._renderDash + + @property + def width(self): + return self._width + + @width.setter + def width(self, width): + # try: + # widthRange = self._widthRange + # except AttributeError: + # widthRange = gl.glGetFloatv(gl.GL_ALIASED_LINE_WIDTH_RANGE) + # # Shared among contexts, this should be enough.. + # _Lines2D._widthRange = widthRange + # assert width >= widthRange[0] and width <= widthRange[1] + self._width = width + + @classmethod + def _getProgram(cls, transform, style): + try: + prgm = cls._programs[(transform, style)] + except KeyError: + sources = cls._SHADERS[style] + vertexShdr = sources['vertex'] % \ + cls._SHADERS['vertexTransforms'][transform] + prgm = Program(vertexShdr, sources['fragment'], attrib0='xPos') + cls._programs[(transform, style)] = prgm + return prgm + + @classmethod + def init(cls): + gl.glHint(gl.GL_LINE_SMOOTH_HINT, gl.GL_NICEST) + + def _renderNone(self, matrix, isXLog, isYLog): + pass + + render = _renderNone # Overridden in style setter + + def _renderSolid(self, matrix, isXLog, isYLog): + if isXLog: + transform = self._LOG10_X_Y if isYLog else self._LOG10_X + else: + transform = self._LOG10_Y if isYLog else self._LINEAR + + prog = self._getProgram(transform, 'solid') + prog.use() + + gl.glEnable(gl.GL_LINE_SMOOTH) + + gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matrix) + + colorAttrib = prog.attributes['color'] + if self.useColorVboData and self.colorVboData is not None: + gl.glEnableVertexAttribArray(colorAttrib) + self.colorVboData.setVertexAttrib(colorAttrib) + else: + gl.glDisableVertexAttribArray(colorAttrib) + gl.glVertexAttrib4f(colorAttrib, *self.color) + + xPosAttrib = prog.attributes['xPos'] + gl.glEnableVertexAttribArray(xPosAttrib) + self.xVboData.setVertexAttrib(xPosAttrib) + + yPosAttrib = prog.attributes['yPos'] + gl.glEnableVertexAttribArray(yPosAttrib) + self.yVboData.setVertexAttrib(yPosAttrib) + + gl.glLineWidth(self.width) + gl.glDrawArrays(self._drawMode, 0, self.xVboData.size) + + gl.glDisable(gl.GL_LINE_SMOOTH) + + def _renderDash(self, matrix, isXLog, isYLog): + if isXLog: + transform = self._LOG10_X_Y if isYLog else self._LOG10_X + else: + transform = self._LOG10_Y if isYLog else self._LINEAR + + prog = self._getProgram(transform, 'dashed') + prog.use() + + gl.glEnable(gl.GL_LINE_SMOOTH) + + gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matrix) + x, y, viewWidth, viewHeight = gl.glGetFloatv(gl.GL_VIEWPORT) + gl.glUniform2f(prog.uniforms['halfViewportSize'], + 0.5 * viewWidth, 0.5 * viewHeight) + + if self.style == DOTTED: + dash = (0.1 * self.dashPeriod, + 0.6 * self.dashPeriod, + 0.7 * self.dashPeriod, + self.dashPeriod) + elif self.style == DASHDOT: + dash = (0.3 * self.dashPeriod, + 0.5 * self.dashPeriod, + 0.6 * self.dashPeriod, + self.dashPeriod) + else: + dash = (0.5 * self.dashPeriod, + self.dashPeriod, + self.dashPeriod, + self.dashPeriod) + + gl.glUniform4f(prog.uniforms['dash'], *dash) + + colorAttrib = prog.attributes['color'] + if self.useColorVboData and self.colorVboData is not None: + gl.glEnableVertexAttribArray(colorAttrib) + self.colorVboData.setVertexAttrib(colorAttrib) + else: + gl.glDisableVertexAttribArray(colorAttrib) + gl.glVertexAttrib4f(colorAttrib, *self.color) + + distAttrib = prog.attributes['distance'] + gl.glEnableVertexAttribArray(distAttrib) + self.distVboData.setVertexAttrib(distAttrib) + + xPosAttrib = prog.attributes['xPos'] + gl.glEnableVertexAttribArray(xPosAttrib) + self.xVboData.setVertexAttrib(xPosAttrib) + + yPosAttrib = prog.attributes['yPos'] + gl.glEnableVertexAttribArray(yPosAttrib) + self.yVboData.setVertexAttrib(yPosAttrib) + + gl.glLineWidth(self.width) + gl.glDrawArrays(self._drawMode, 0, self.xVboData.size) + + gl.glDisable(gl.GL_LINE_SMOOTH) + + +def _distancesFromArrays(xData, yData): + deltas = numpy.dstack(( + numpy.ediff1d(xData, to_begin=numpy.float32(0.)), + numpy.ediff1d(yData, to_begin=numpy.float32(0.))))[0] + return numpy.cumsum(numpy.sqrt(numpy.sum(deltas ** 2, axis=1))) + + +# points ###################################################################### + +DIAMOND, CIRCLE, SQUARE, PLUS, X_MARKER, POINT, PIXEL, ASTERISK = \ + 'd', 'o', 's', '+', 'x', '.', ',', '*' + +H_LINE, V_LINE = '_', '|' + + +class _Points2D(object): + MARKERS = (DIAMOND, CIRCLE, SQUARE, PLUS, X_MARKER, POINT, PIXEL, ASTERISK, + H_LINE, V_LINE) + + _LINEAR, _LOG10_X, _LOG10_Y, _LOG10_X_Y = 0, 1, 2, 3 + + _SHADERS = { + 'vertexTransforms': { + _LINEAR: """ + vec4 transformXY(float x, float y) { + return vec4(x, y, 0.0, 1.0); + } + """, + _LOG10_X: """ + const float oneOverLog10 = 0.43429448190325176; + + vec4 transformXY(float x, float y) { + return vec4(oneOverLog10 * log(x), y, 0.0, 1.0); + } + """, + _LOG10_Y: """ + const float oneOverLog10 = 0.43429448190325176; + + vec4 transformXY(float x, float y) { + return vec4(x, oneOverLog10 * log(y), 0.0, 1.0); + } + """, + _LOG10_X_Y: """ + const float oneOverLog10 = 0.43429448190325176; + + vec4 transformXY(float x, float y) { + return vec4(oneOverLog10 * log(x), + oneOverLog10 * log(y), + 0.0, 1.0); + } + """ + }, + 'vertex': """ + #version 120 + + uniform mat4 matrix; + uniform int transform; + uniform float size; + attribute float xPos; + attribute float yPos; + attribute vec4 color; + + varying vec4 vColor; + + %s + + void main(void) { + gl_Position = matrix * transformXY(xPos, yPos); + vColor = color; + gl_PointSize = size; + } + """, + + 'fragmentSymbols': { + DIAMOND: """ + float alphaSymbol(vec2 coord, float size) { + vec2 centerCoord = abs(coord - vec2(0.5, 0.5)); + float f = centerCoord.x + centerCoord.y; + return clamp(size * (0.5 - f), 0.0, 1.0); + } + """, + CIRCLE: """ + float alphaSymbol(vec2 coord, float size) { + float radius = 0.5; + float r = distance(coord, vec2(0.5, 0.5)); + return clamp(size * (radius - r), 0.0, 1.0); + } + """, + SQUARE: """ + float alphaSymbol(vec2 coord, float size) { + return 1.0; + } + """, + PLUS: """ + float alphaSymbol(vec2 coord, float size) { + vec2 d = abs(size * (coord - vec2(0.5, 0.5))); + if (min(d.x, d.y) < 0.5) { + return 1.0; + } else { + return 0.0; + } + } + """, + X_MARKER: """ + float alphaSymbol(vec2 coord, float size) { + vec2 pos = floor(size * coord) + 0.5; + vec2 d_x = abs(pos.x + vec2(- pos.y, pos.y - size)); + if (min(d_x.x, d_x.y) <= 0.5) { + return 1.0; + } else { + return 0.0; + } + } + """, + ASTERISK: """ + float alphaSymbol(vec2 coord, float size) { + /* Combining +, x and cirle */ + vec2 d_plus = abs(size * (coord - vec2(0.5, 0.5))); + vec2 pos = floor(size * coord) + 0.5; + vec2 d_x = abs(pos.x + vec2(- pos.y, pos.y - size)); + if (min(d_plus.x, d_plus.y) < 0.5) { + return 1.0; + } else if (min(d_x.x, d_x.y) <= 0.5) { + float r = distance(coord, vec2(0.5, 0.5)); + return clamp(size * (0.5 - r), 0.0, 1.0); + } else { + return 0.0; + } + } + """, + H_LINE: """ + float alphaSymbol(vec2 coord, float size) { + float dy = abs(size * (coord.y - 0.5)); + if (dy < 0.5) { + return 1.0; + } else { + return 0.0; + } + } + """, + V_LINE: """ + float alphaSymbol(vec2 coord, float size) { + float dx = abs(size * (coord.x - 0.5)); + if (dx < 0.5) { + return 1.0; + } else { + return 0.0; + } + } + """ + }, + + 'fragment': """ + #version 120 + + uniform float size; + + varying vec4 vColor; + + %s + + void main(void) { + float alpha = alphaSymbol(gl_PointCoord, size); + if (alpha <= 0.0) { + discard; + } else { + gl_FragColor = vec4(vColor.rgb, alpha * clamp(vColor.a, 0.0, 1.0)); + } + } + """ + } + + _programs = {} + + def __init__(self, xVboData=None, yVboData=None, colorVboData=None, + marker=SQUARE, color=(0., 0., 0., 1.), size=7): + self.color = color + self._marker = None + self.marker = marker + self._size = 1 + self.size = size + + self.xVboData = xVboData + self.yVboData = yVboData + self.colorVboData = colorVboData + self.useColorVboData = colorVboData is not None + + @property + def marker(self): + return self._marker + + @marker.setter + def marker(self, marker): + if marker in _MPL_NONES: + self._marker = None + self.render = self._renderNone + else: + assert marker in self.MARKERS + self._marker = marker + self.render = self._renderMarkers + + @property + def size(self): + return self._size + + @size.setter + def size(self, size): + # try: + # sizeRange = self._sizeRange + # except AttributeError: + # sizeRange = gl.glGetFloatv(gl.GL_POINT_SIZE_RANGE) + # # Shared among contexts, this should be enough.. + # _Points2D._sizeRange = sizeRange + # assert size >= sizeRange[0] and size <= sizeRange[1] + self._size = size + + @classmethod + def _getProgram(cls, transform, marker): + """On-demand shader program creation.""" + if marker == PIXEL: + marker = SQUARE + elif marker == POINT: + marker = CIRCLE + try: + prgm = cls._programs[(transform, marker)] + except KeyError: + vertShdr = cls._SHADERS['vertex'] % \ + cls._SHADERS['vertexTransforms'][transform] + fragShdr = cls._SHADERS['fragment'] % \ + cls._SHADERS['fragmentSymbols'][marker] + prgm = Program(vertShdr, fragShdr, attrib0='xPos') + + cls._programs[(transform, marker)] = prgm + return prgm + + @classmethod + def init(cls): + version = gl.glGetString(gl.GL_VERSION) + majorVersion = int(version[0]) + assert majorVersion >= 2 + gl.glEnable(gl.GL_VERTEX_PROGRAM_POINT_SIZE) # OpenGL 2 + gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2 + if majorVersion >= 3: # OpenGL 3 + gl.glEnable(gl.GL_PROGRAM_POINT_SIZE) + + def _renderNone(self, matrix, isXLog, isYLog): + pass + + render = _renderNone + + def _renderMarkers(self, matrix, isXLog, isYLog): + if isXLog: + transform = self._LOG10_X_Y if isYLog else self._LOG10_X + else: + transform = self._LOG10_Y if isYLog else self._LINEAR + + prog = self._getProgram(transform, self.marker) + prog.use() + gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matrix) + if self.marker == PIXEL: + size = 1 + elif self.marker == POINT: + size = math.ceil(0.5 * self.size) + 1 # Mimic Matplotlib point + else: + size = self.size + gl.glUniform1f(prog.uniforms['size'], size) + # gl.glPointSize(self.size) + + cAttrib = prog.attributes['color'] + if self.useColorVboData and self.colorVboData is not None: + gl.glEnableVertexAttribArray(cAttrib) + self.colorVboData.setVertexAttrib(cAttrib) + else: + gl.glDisableVertexAttribArray(cAttrib) + gl.glVertexAttrib4f(cAttrib, *self.color) + + xAttrib = prog.attributes['xPos'] + gl.glEnableVertexAttribArray(xAttrib) + self.xVboData.setVertexAttrib(xAttrib) + + yAttrib = prog.attributes['yPos'] + gl.glEnableVertexAttribArray(yAttrib) + self.yVboData.setVertexAttrib(yAttrib) + + gl.glDrawArrays(gl.GL_POINTS, 0, self.xVboData.size) + + gl.glUseProgram(0) + + +# error bars ################################################################## + +class _ErrorBars(object): + """Display errors bars. + + This is using its own VBO as opposed to fill/points/lines. + There is no picking on error bars. + As is, there is no way to update data and errors, but it handles + log scales by removing data <= 0 and clipping error bars to positive + range. + + It uses 2 vertices per error bars and uses :class:`_Lines2D` to + render error bars and :class:`_Points2D` to render the ends. + """ + + def __init__(self, xData, yData, xError, yError, + xMin, yMin, + color=(0., 0., 0., 1.)): + """Initialization. + + :param numpy.ndarray xData: X coordinates of the data. + :param numpy.ndarray yData: Y coordinates of the data. + :param xError: The absolute error on the X axis. + :type xError: A float, or a numpy.ndarray of float32. + If it is an array, it can either be a 1D array of + same length as the data or a 2D array with 2 rows + of same length as the data: row 0 for negative errors, + row 1 for positive errors. + :param yError: The absolute error on the Y axis. + :type yError: A float, or a numpy.ndarray of float32. See xError. + :param float xMin: The min X value already computed by GLPlotCurve2D. + :param float yMin: The min Y value already computed by GLPlotCurve2D. + :param color: The color to use for both lines and ending points. + :type color: tuple of 4 floats + """ + self._attribs = None + self._isXLog, self._isYLog = False, False + self._xMin, self._yMin = xMin, yMin + + if xError is not None or yError is not None: + assert len(xData) == len(yData) + self._xData = numpy.array( + xData, order='C', dtype=numpy.float32, copy=False) + self._yData = numpy.array( + yData, order='C', dtype=numpy.float32, copy=False) + + # This also works if xError, yError is a float/int + self._xError = numpy.array( + xError, order='C', dtype=numpy.float32, copy=False) + self._yError = numpy.array( + yError, order='C', dtype=numpy.float32, copy=False) + else: + self._xData, self._yData = None, None + self._xError, self._yError = None, None + + self._lines = _Lines2D(None, None, color=color, drawMode=gl.GL_LINES) + self._xErrPoints = _Points2D(None, None, color=color, marker=V_LINE) + self._yErrPoints = _Points2D(None, None, color=color, marker=H_LINE) + + def _positiveValueFilter(self, onlyXPos, onlyYPos): + """Filter data (x, y) and errors (xError, yError) to remove + negative and null data values on required axis (onlyXPos, onlyYPos). + + Returned arrays might be NOT contiguous. + + :return: Filtered xData, yData, xError and yError arrays. + """ + if ((not onlyXPos or self._xMin > 0.) and + (not onlyYPos or self._yMin > 0.)): + # No need to filter, all values are > 0 on log axes + return self._xData, self._yData, self._xError, self._yError + + _logger.warning( + 'Removing values <= 0 of curve with error bars on a log axis.') + + x, y = self._xData, self._yData + xError, yError = self._xError, self._yError + + # First remove negative data + if onlyXPos and onlyYPos: + mask = (x > 0.) & (y > 0.) + elif onlyXPos: + mask = x > 0. + else: # onlyYPos + mask = y > 0. + x, y = x[mask], y[mask] + + # Remove corresponding values from error arrays + if xError is not None and xError.size != 1: + if len(xError.shape) == 1: + xError = xError[mask] + else: # 2 rows + xError = xError[:, mask] + if yError is not None and yError.size != 1: + if len(yError.shape) == 1: + yError = yError[mask] + else: # 2 rows + yError = yError[:, mask] + + return x, y, xError, yError + + def _buildVertices(self, isXLog, isYLog): + """Generates error bars vertices according to log scales.""" + xData, yData, xError, yError = self._positiveValueFilter( + isXLog, isYLog) + + nbLinesPerDataPts = 1 if xError is not None else 0 + nbLinesPerDataPts += 1 if yError is not None else 0 + + nbDataPts = len(xData) + + # interleave coord+error, coord-error. + # xError vertices first if any, then yError vertices if any. + xCoords = numpy.empty(nbDataPts * nbLinesPerDataPts * 2, + dtype=numpy.float32) + yCoords = numpy.empty(nbDataPts * nbLinesPerDataPts * 2, + dtype=numpy.float32) + + if xError is not None: # errors on the X axis + if len(xError.shape) == 2: + xErrorMinus, xErrorPlus = xError[0], xError[1] + else: + # numpy arrays of len 1 or len(xData) + xErrorMinus, xErrorPlus = xError, xError + + # Interleave vertices for xError + endXError = 2 * nbDataPts + xCoords[0:endXError-1:2] = xData + xErrorPlus + + minValues = xData - xErrorMinus + if isXLog: + # Clip min bounds to positive value + minValues[minValues <= 0] = FLOAT32_MINPOS + xCoords[1:endXError:2] = minValues + + yCoords[0:endXError-1:2] = yData + yCoords[1:endXError:2] = yData + else: + endXError = 0 + + if yError is not None: # errors on the Y axis + if len(yError.shape) == 2: + yErrorMinus, yErrorPlus = yError[0], yError[1] + else: + # numpy arrays of len 1 or len(yData) + yErrorMinus, yErrorPlus = yError, yError + + # Interleave vertices for yError + xCoords[endXError::2] = xData + xCoords[endXError+1::2] = xData + yCoords[endXError::2] = yData + yErrorPlus + minValues = yData - yErrorMinus + if isYLog: + # Clip min bounds to positive value + minValues[minValues <= 0] = FLOAT32_MINPOS + yCoords[endXError+1::2] = minValues + + return xCoords, yCoords + + def prepare(self, isXLog, isYLog): + if self._xData is None: + return + + if self._isXLog != isXLog or self._isYLog != isYLog: + # Log state has changed + self._isXLog, self._isYLog = isXLog, isYLog + + self.discard() # discard existing VBOs + + if self._attribs is None: + xCoords, yCoords = self._buildVertices(isXLog, isYLog) + + xAttrib, yAttrib = vertexBuffer((xCoords, yCoords)) + self._attribs = xAttrib, yAttrib + + self._lines.xVboData, self._lines.yVboData = xAttrib, yAttrib + + # Set xError points using the same VBO as lines + self._xErrPoints.xVboData = xAttrib.copy() + self._xErrPoints.xVboData.size //= 2 + self._xErrPoints.yVboData = yAttrib.copy() + self._xErrPoints.yVboData.size //= 2 + + # Set yError points using the same VBO as lines + self._yErrPoints.xVboData = xAttrib.copy() + self._yErrPoints.xVboData.size //= 2 + self._yErrPoints.xVboData.offset += (xAttrib.itemsize * + xAttrib.size // 2) + self._yErrPoints.yVboData = yAttrib.copy() + self._yErrPoints.yVboData.size //= 2 + self._yErrPoints.yVboData.offset += (yAttrib.itemsize * + yAttrib.size // 2) + + def render(self, matrix, isXLog, isYLog): + if self._attribs is not None: + self._lines.render(matrix, isXLog, isYLog) + self._xErrPoints.render(matrix, isXLog, isYLog) + self._yErrPoints.render(matrix, isXLog, isYLog) + + def discard(self): + if self._attribs is not None: + self._lines.xVboData, self._lines.yVboData = None, None + self._xErrPoints.xVboData, self._xErrPoints.yVboData = None, None + self._yErrPoints.xVboData, self._yErrPoints.yVboData = None, None + self._attribs[0].vbo.discard() + self._attribs = None + + +# curves ###################################################################### + +def _proxyProperty(*componentsAttributes): + """Create a property to access an attribute of attribute(s). + Useful for composition. + Supports multiple components this way: + getter returns the first found, setter sets all + """ + def getter(self): + for compName, attrName in componentsAttributes: + try: + component = getattr(self, compName) + except AttributeError: + pass + else: + return getattr(component, attrName) + + def setter(self, value): + for compName, attrName in componentsAttributes: + component = getattr(self, compName) + setattr(component, attrName, value) + return property(getter, setter) + + +class GLPlotCurve2D(object): + def __init__(self, xData, yData, colorData=None, + xError=None, yError=None, + lineStyle=None, lineColor=None, + lineWidth=None, lineDashPeriod=None, + marker=None, markerColor=None, markerSize=None, + fillColor=None): + self._isXLog = False + self._isYLog = False + self.xData, self.yData, self.colorData = xData, yData, colorData + + if fillColor is not None: + self.fill = _Fill2D(color=fillColor) + else: + self.fill = None + + # Compute x bounds + if xError is None: + result = min_max(xData, min_positive=True) + self.xMin = result.minimum + self.xMinPos = result.min_positive + self.xMax = result.maximum + else: + # Takes the error into account + if hasattr(xError, 'shape') and len(xError.shape) == 2: + xErrorPlus, xErrorMinus = xError[0], xError[1] + else: + xErrorPlus, xErrorMinus = xError, xError + result = min_max(xData - xErrorMinus, min_positive=True) + self.xMin = result.minimum + self.xMinPos = result.min_positive + self.xMax = (xData + xErrorPlus).max() + + # Compute y bounds + if yError is None: + result = min_max(yData, min_positive=True) + self.yMin = result.minimum + self.yMinPos = result.min_positive + self.yMax = result.maximum + else: + # Takes the error into account + if hasattr(yError, 'shape') and len(yError.shape) == 2: + yErrorPlus, yErrorMinus = yError[0], yError[1] + else: + yErrorPlus, yErrorMinus = yError, yError + result = min_max(yData - yErrorMinus, min_positive=True) + self.yMin = result.minimum + self.yMinPos = result.min_positive + self.yMax = (yData + yErrorPlus).max() + + self._errorBars = _ErrorBars(xData, yData, xError, yError, + self.xMin, self.yMin) + + kwargs = {'style': lineStyle} + if lineColor is not None: + kwargs['color'] = lineColor + if lineWidth is not None: + kwargs['width'] = lineWidth + if lineDashPeriod is not None: + kwargs['dashPeriod'] = lineDashPeriod + self.lines = _Lines2D(**kwargs) + + kwargs = {'marker': marker} + if markerColor is not None: + kwargs['color'] = markerColor + if markerSize is not None: + kwargs['size'] = markerSize + self.points = _Points2D(**kwargs) + + xVboData = _proxyProperty(('lines', 'xVboData'), ('points', 'xVboData')) + + yVboData = _proxyProperty(('lines', 'yVboData'), ('points', 'yVboData')) + + colorVboData = _proxyProperty(('lines', 'colorVboData'), + ('points', 'colorVboData')) + + useColorVboData = _proxyProperty(('lines', 'useColorVboData'), + ('points', 'useColorVboData')) + + distVboData = _proxyProperty(('lines', 'distVboData')) + + lineStyle = _proxyProperty(('lines', 'style')) + + lineColor = _proxyProperty(('lines', 'color')) + + lineWidth = _proxyProperty(('lines', 'width')) + + lineDashPeriod = _proxyProperty(('lines', 'dashPeriod')) + + marker = _proxyProperty(('points', 'marker')) + + markerColor = _proxyProperty(('points', 'color')) + + markerSize = _proxyProperty(('points', 'size')) + + @classmethod + def init(cls): + _Lines2D.init() + _Points2D.init() + + @staticmethod + def _logFilterData(x, y, color=None, xLog=False, yLog=False): + # Copied from Plot.py + if xLog and yLog: + idx = numpy.nonzero((x > 0) & (y > 0))[0] + x = numpy.take(x, idx) + y = numpy.take(y, idx) + elif yLog: + idx = numpy.nonzero(y > 0)[0] + x = numpy.take(x, idx) + y = numpy.take(y, idx) + elif xLog: + idx = numpy.nonzero(x > 0)[0] + x = numpy.take(x, idx) + y = numpy.take(y, idx) + else: + idx = None + + if idx is not None and isinstance(color, numpy.ndarray): + colors = numpy.zeros((x.size, 4), color.dtype) + colors[:, 0] = color[idx, 0] + colors[:, 1] = color[idx, 1] + colors[:, 2] = color[idx, 2] + colors[:, 3] = color[idx, 3] + else: + colors = color + return x, y, colors + + def prepare(self, isXLog, isYLog): + # init only supports updating isXLog, isYLog + xData, yData, colorData = self.xData, self.yData, self.colorData + + if self._isXLog != isXLog or self._isYLog != isYLog: + # Log state has changed + self._isXLog, self._isYLog = isXLog, isYLog + + # Check if data <= 0. with log scale + if (isXLog and self.xMin <= 0.) or (isYLog and self.yMin <= 0.): + # Filtering data is needed + xData, yData, colorData = self._logFilterData( + self.xData, self.yData, self.colorData, + self._isXLog, self._isYLog) + + self.discard() # discard existing VBOs + + if self.xVboData is None: + xAttrib, yAttrib, cAttrib, dAttrib = None, None, None, None + if self.lineStyle in (DASHED, DASHDOT, DOTTED): + dists = _distancesFromArrays(xData, yData) + if self.colorData is None: + xAttrib, yAttrib, dAttrib = vertexBuffer( + (xData, yData, dists), + prefix=(1, 1, 0), suffix=(1, 1, 0)) + else: + xAttrib, yAttrib, cAttrib, dAttrib = vertexBuffer( + (xData, yData, colorData, dists), + prefix=(1, 1, 0, 0), suffix=(1, 1, 0, 0)) + elif self.colorData is None: + xAttrib, yAttrib = vertexBuffer( + (xData, yData), prefix=(1, 1), suffix=(1, 1)) + else: + xAttrib, yAttrib, cAttrib = vertexBuffer( + (xData, yData, colorData), prefix=(1, 1, 0)) + + # Shrink VBO + self.xVboData = xAttrib.copy() + self.xVboData.size -= 2 + self.xVboData.offset += xAttrib.itemsize + + self.yVboData = yAttrib.copy() + self.yVboData.size -= 2 + self.yVboData.offset += yAttrib.itemsize + + if cAttrib is not None and colorData.dtype.kind == 'u': + cAttrib.normalisation = True # Normalise uint to [0, 1] + self.colorVboData = cAttrib + self.useColorVboData = cAttrib is not None + self.distVboData = dAttrib + + if self.fill is not None: + xData = xData.reshape(xData.size, 1) + zero = numpy.array((1e-32,), dtype=self.yData.dtype) + + # Add one point before data: (x0, 0.) + xAttrib.vbo.update(xData[0], xAttrib.offset, + xData[0].itemsize) + yAttrib.vbo.update(zero, yAttrib.offset, zero.itemsize) + + # Add one point after data: (xN, 0.) + xAttrib.vbo.update(xData[-1], + xAttrib.offset + + (xAttrib.size - 1) * xAttrib.itemsize, + xData[-1].itemsize) + yAttrib.vbo.update(zero, + yAttrib.offset + + (yAttrib.size - 1) * yAttrib.itemsize, + zero.itemsize) + + self.fill.xFillVboData = xAttrib + self.fill.yFillVboData = yAttrib + self.fill.xMin, self.fill.yMin = self.xMin, self.yMin + self.fill.xMax, self.fill.yMax = self.xMax, self.yMax + + self._errorBars.prepare(isXLog, isYLog) + + def render(self, matrix, isXLog, isYLog): + self.prepare(isXLog, isYLog) + if self.fill is not None: + self.fill.render(matrix, isXLog, isYLog) + self._errorBars.render(matrix, isXLog, isYLog) + self.lines.render(matrix, isXLog, isYLog) + self.points.render(matrix, isXLog, isYLog) + + def discard(self): + if self.xVboData is not None: + self.xVboData.vbo.discard() + + self.xVboData = None + self.yVboData = None + self.colorVboData = None + self.distVboData = None + + self._errorBars.discard() + + def pick(self, xPickMin, yPickMin, xPickMax, yPickMax): + """Perform picking on the curve according to its rendering. + + The picking area is [xPickMin, xPickMax], [yPickMin, yPickMax]. + + In case a segment between 2 points with indices i, i+1 is picked, + only its lower index end point (i.e., i) is added to the result. + In case an end point with index i is picked it is added to the result, + and the segment [i-1, i] is not tested for picking. + + :return: The indices of the picked data + :rtype: list of int + """ + if (self.marker is None and self.lineStyle is None) or \ + self.xMin > xPickMax or xPickMin > self.xMax or \ + self.yMin > yPickMax or yPickMin > self.yMax: + # Note: With log scale the bounding box is too large if + # some data <= 0. + return None + + elif self.lineStyle is not None: + # Using Cohen-Sutherland algorithm for line clipping + codes = ((self.yData > yPickMax) << 3) | \ + ((self.yData < yPickMin) << 2) | \ + ((self.xData > xPickMax) << 1) | \ + (self.xData < xPickMin) + + # Add all points that are inside the picking area + indices = numpy.nonzero(codes == 0)[0].tolist() + + # Segment that might cross the area with no end point inside it + segToTestIdx = numpy.nonzero((codes[:-1] != 0) & + (codes[1:] != 0) & + ((codes[:-1] & codes[1:]) == 0))[0] + + TOP, BOTTOM, RIGHT, LEFT = (1 << 3), (1 << 2), (1 << 1), (1 << 0) + + for index in segToTestIdx: + if index not in indices: + x0, y0 = self.xData[index], self.yData[index] + x1, y1 = self.xData[index + 1], self.yData[index + 1] + code1 = codes[index + 1] + + # check for crossing with horizontal bounds + # y0 == y1 is a never event: + # => pt0 and pt1 in same vertical area are not in segToTest + if code1 & TOP: + x = x0 + (x1 - x0) * (yPickMax - y0) / (y1 - y0) + elif code1 & BOTTOM: + x = x0 + (x1 - x0) * (yPickMin - y0) / (y1 - y0) + else: + x = None # No horizontal bounds intersection test + + if x is not None and xPickMin <= x <= xPickMax: + # Intersection + indices.append(index) + + else: + # check for crossing with vertical bounds + # x0 == x1 is a never event (see remark for y) + if code1 & RIGHT: + y = y0 + (y1 - y0) * (xPickMax - x0) / (x1 - x0) + elif code1 & LEFT: + y = y0 + (y1 - y0) * (xPickMin - x0) / (x1 - x0) + else: + y = None # No vertical bounds intersection test + + if y is not None and yPickMin <= y <= yPickMax: + # Intersection + indices.append(index) + + indices.sort() + + else: + indices = numpy.nonzero((self.xData >= xPickMin) & + (self.xData <= xPickMax) & + (self.yData >= yPickMin) & + (self.yData <= yPickMax))[0].tolist() + + return indices diff --git a/silx/gui/plot/backends/glutils/GLPlotFrame.py b/silx/gui/plot/backends/glutils/GLPlotFrame.py new file mode 100644 index 0000000..367419c --- /dev/null +++ b/silx/gui/plot/backends/glutils/GLPlotFrame.py @@ -0,0 +1,1039 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +""" +This modules provides the rendering of plot titles, axes and grid. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "03/04/2017" + + +# TODO +# keep aspect ratio managed here? +# smarter dirty flag handling? + +import math +import weakref +import logging +from collections import namedtuple + +import numpy + +from ...._glutils import gl, Program +from ..._utils import FLOAT32_SAFE_MIN, FLOAT32_MINPOS, FLOAT32_SAFE_MAX +from .GLSupport import mat4Ortho +from .GLText import Text2D, CENTER, BOTTOM, TOP, LEFT, RIGHT, ROTATE_270 +from ..._utils.ticklayout import niceNumbersAdaptative, niceNumbersForLog10 + + +_logger = logging.getLogger(__name__) + + +# PlotAxis #################################################################### + +class PlotAxis(object): + """Represents a 1D axis of the plot. + This class is intended to be used with :class:`GLPlotFrame`. + """ + + def __init__(self, plot, + tickLength=(0., 0.), + labelAlign=CENTER, labelVAlign=CENTER, + titleAlign=CENTER, titleVAlign=CENTER, + titleRotate=0, titleOffset=(0., 0.)): + self._ticks = None + + self._plot = weakref.ref(plot) + + self._isLog = False + self._dataRange = 1., 100. + self._displayCoords = (0., 0.), (1., 0.) + self._title = '' + + self._tickLength = tickLength + self._labelAlign = labelAlign + self._labelVAlign = labelVAlign + self._titleAlign = titleAlign + self._titleVAlign = titleVAlign + self._titleRotate = titleRotate + self._titleOffset = titleOffset + + @property + def dataRange(self): + """The range of the data represented on the axis as a tuple + of 2 floats: (min, max).""" + return self._dataRange + + @dataRange.setter + def dataRange(self, dataRange): + assert len(dataRange) == 2 + assert dataRange[0] <= dataRange[1] + dataRange = float(dataRange[0]), float(dataRange[1]) + + if dataRange != self._dataRange: + self._dataRange = dataRange + self._dirtyTicks() + + @property + def isLog(self): + """Whether the axis is using a log10 scale or not as a bool.""" + return self._isLog + + @isLog.setter + def isLog(self, isLog): + isLog = bool(isLog) + if isLog != self._isLog: + self._isLog = isLog + self._dirtyTicks() + + @property + def displayCoords(self): + """The coordinates of the start and end points of the axis + in display space (i.e., in pixels) as a tuple of 2 tuples of + 2 floats: ((x0, y0), (x1, y1)). + """ + return self._displayCoords + + @displayCoords.setter + def displayCoords(self, displayCoords): + assert len(displayCoords) == 2 + assert len(displayCoords[0]) == 2 + assert len(displayCoords[1]) == 2 + displayCoords = tuple(displayCoords[0]), tuple(displayCoords[1]) + if displayCoords != self._displayCoords: + self._displayCoords = displayCoords + self._dirtyTicks() + + @property + def title(self): + """The text label associated with this axis as a str in latin-1.""" + return self._title + + @title.setter + def title(self, title): + if title != self._title: + self._title = title + + plot = self._plot() + if plot is not None: + plot._dirty() + + @property + def ticks(self): + """Ticks as tuples: ((x, y) in display, dataPos, textLabel).""" + if self._ticks is None: + self._ticks = tuple(self._ticksGenerator()) + return self._ticks + + def getVerticesAndLabels(self): + """Create the list of vertices for axis and associated text labels. + + :returns: A tuple: List of 2D line vertices, List of Text2D labels. + """ + vertices = list(self.displayCoords) # Add start and end points + labels = [] + tickLabelsSize = [0., 0.] + + xTickLength, yTickLength = self._tickLength + for (xPixel, yPixel), dataPos, text in self.ticks: + if text is None: + tickScale = 0.5 + else: + tickScale = 1. + + label = Text2D(text=text, + x=xPixel - xTickLength, + y=yPixel - yTickLength, + align=self._labelAlign, + valign=self._labelVAlign) + + width, height = label.size + if width > tickLabelsSize[0]: + tickLabelsSize[0] = width + if height > tickLabelsSize[1]: + tickLabelsSize[1] = height + + labels.append(label) + + vertices.append((xPixel, yPixel)) + vertices.append((xPixel + tickScale * xTickLength, + yPixel + tickScale * yTickLength)) + + (x0, y0), (x1, y1) = self.displayCoords + xAxisCenter = 0.5 * (x0 + x1) + yAxisCenter = 0.5 * (y0 + y1) + + xOffset, yOffset = self._titleOffset + + # Adaptative title positioning: + # tickNorm = math.sqrt(xTickLength ** 2 + yTickLength ** 2) + # xOffset = -tickLabelsSize[0] * xTickLength / tickNorm + # xOffset -= 3 * xTickLength + # yOffset = -tickLabelsSize[1] * yTickLength / tickNorm + # yOffset -= 3 * yTickLength + + axisTitle = Text2D(text=self.title, + x=xAxisCenter + xOffset, + y=yAxisCenter + yOffset, + align=self._titleAlign, + valign=self._titleVAlign, + rotate=self._titleRotate) + labels.append(axisTitle) + + return vertices, labels + + def _dirtyTicks(self): + """Mark ticks as dirty and notify listener (i.e., background).""" + self._ticks = None + plot = self._plot() + if plot is not None: + plot._dirty() + + @staticmethod + def _frange(start, stop, step): + """range for float (including stop).""" + while start <= stop: + yield start + start += step + + def _ticksGenerator(self): + """Generator of ticks as tuples: + ((x, y) in display, dataPos, textLabel). + """ + dataMin, dataMax = self.dataRange + if self.isLog and dataMin <= 0.: + _logger.warning( + 'Getting ticks while isLog=True and dataRange[0]<=0.') + dataMin = 1. + if dataMax < dataMin: + dataMax = 1. + + if dataMin != dataMax: # data range is not null + (x0, y0), (x1, y1) = self.displayCoords + + if self.isLog: + logMin, logMax = math.log10(dataMin), math.log10(dataMax) + tickMin, tickMax, step, _ = niceNumbersForLog10(logMin, logMax) + + xScale = (x1 - x0) / (logMax - logMin) + yScale = (y1 - y0) / (logMax - logMin) + + for logPos in self._frange(tickMin, tickMax, step): + if logMin <= logPos <= logMax: + dataPos = 10 ** logPos + xPixel = x0 + (logPos - logMin) * xScale + yPixel = y0 + (logPos - logMin) * yScale + text = '1e%+03d' % logPos + yield ((xPixel, yPixel), dataPos, text) + + if step == 1: + ticks = list(self._frange(tickMin, tickMax, step))[:-1] + for logPos in ticks: + dataOrigPos = 10 ** logPos + for index in range(2, 10): + dataPos = dataOrigPos * index + if dataMin <= dataPos <= dataMax: + logSubPos = math.log10(dataPos) + xPixel = x0 + (logSubPos - logMin) * xScale + yPixel = y0 + (logSubPos - logMin) * yScale + yield ((xPixel, yPixel), dataPos, None) + + else: + xScale = (x1 - x0) / (dataMax - dataMin) + yScale = (y1 - y0) / (dataMax - dataMin) + + nbPixels = math.sqrt(pow(x1 - x0, 2) + pow(y1 - y0, 2)) + + # Density of 1.3 label per 92 pixels + # i.e., 1.3 label per inch on a 92 dpi screen + tickMin, tickMax, step, nbFrac = niceNumbersAdaptative( + dataMin, dataMax, nbPixels, 1.3 / 92) + + for dataPos in self._frange(tickMin, tickMax, step): + if dataMin <= dataPos <= dataMax: + xPixel = x0 + (dataPos - dataMin) * xScale + yPixel = y0 + (dataPos - dataMin) * yScale + + if nbFrac == 0: + text = '%g' % dataPos + else: + text = ('%.' + str(nbFrac) + 'f') % dataPos + yield ((xPixel, yPixel), dataPos, text) + + +# GLPlotFrame ################################################################# + +class GLPlotFrame(object): + """Base class for rendering a 2D frame surrounded by axes.""" + + _TICK_LENGTH_IN_PIXELS = 5 + _LINE_WIDTH = 1 + + _SHADERS = { + 'vertex': """ + attribute vec2 position; + uniform mat4 matrix; + + void main(void) { + gl_Position = matrix * vec4(position, 0.0, 1.0); + } + """, + 'fragment': """ + uniform vec4 color; + uniform float tickFactor; /* = 1./tickLength or 0. for solid line */ + + void main(void) { + if (mod(tickFactor * (gl_FragCoord.x + gl_FragCoord.y), 2.) < 1.) { + gl_FragColor = color; + } else { + discard; + } + } + """ + } + + _Margins = namedtuple('Margins', ('left', 'right', 'top', 'bottom')) + + def __init__(self, margins): + """ + :param margins: The margins around plot area for axis and labels. + :type margins: dict with 'left', 'right', 'top', 'bottom' keys and + values as ints. + """ + self._renderResources = None + + self._margins = self._Margins(**margins) + + self.axes = [] # List of PlotAxis to be updated by subclasses + + self._grid = False + self._size = 0., 0. + self._title = '' + + @property + def isDirty(self): + """True if it need to refresh graphic rendering, False otherwise.""" + return self._renderResources is None + + GRID_NONE = 0 + GRID_MAIN_TICKS = 1 + GRID_SUB_TICKS = 2 + GRID_ALL_TICKS = (GRID_MAIN_TICKS + GRID_SUB_TICKS) + + @property + def margins(self): + """Margins in pixels around the plot.""" + return self._margins + + @property + def grid(self): + """Grid display mode: + - 0: No grid. + - 1: Grid on main ticks. + - 2: Grid on sub-ticks for log scale axes. + - 3: Grid on main and sub ticks.""" + return self._grid + + @grid.setter + def grid(self, grid): + assert grid in (self.GRID_NONE, self.GRID_MAIN_TICKS, + self.GRID_SUB_TICKS, self.GRID_ALL_TICKS) + if grid != self._grid: + self._grid = grid + self._dirty() + + @property + def size(self): + """Size in pixels of the plot area including margins.""" + return self._size + + @size.setter + def size(self, size): + assert len(size) == 2 + size = tuple(size) + if size != self._size: + self._size = size + self._dirty() + + @property + def plotOrigin(self): + """Plot area origin (left, top) in widget coordinates in pixels.""" + return self.margins.left, self.margins.top + + @property + def plotSize(self): + """Plot area size (width, height) in pixels.""" + w, h = self.size + w -= self.margins.left + self.margins.right + h -= self.margins.top + self.margins.bottom + return w, h + + @property + def title(self): + """Main title as a str in latin-1.""" + return self._title + + @title.setter + def title(self, title): + if title != self._title: + self._title = title + self._dirty() + + # In-place update + # if self._renderResources is not None: + # self._renderResources[-1][-1].text = title + + def _dirty(self): + # When Text2D require discard we need to handle it + self._renderResources = None + + def _buildGridVertices(self): + if self._grid == self.GRID_NONE: + return [] + + elif self._grid == self.GRID_MAIN_TICKS: + def test(text): + return text is not None + elif self._grid == self.GRID_SUB_TICKS: + def test(text): + return text is None + elif self._grid == self.GRID_ALL_TICKS: + def test(_): + return True + else: + logging.warning('Wrong grid mode: %d' % self._grid) + return [] + + return self._buildGridVerticesWithTest(test) + + def _buildGridVerticesWithTest(self, test): + """Override in subclass to generate grid vertices""" + return [] + + def _buildVerticesAndLabels(self): + # To fill with copy of axes lists + vertices = [] + labels = [] + + for axis in self.axes: + axisVertices, axisLabels = axis.getVerticesAndLabels() + vertices += axisVertices + labels += axisLabels + + vertices = numpy.array(vertices, dtype=numpy.float32) + + # Add main title + xTitle = (self.size[0] + self.margins.left - + self.margins.right) // 2 + yTitle = self.margins.top - self._TICK_LENGTH_IN_PIXELS + labels.append(Text2D(text=self.title, + x=xTitle, + y=yTitle, + align=CENTER, + valign=BOTTOM)) + + # grid + gridVertices = numpy.array(self._buildGridVertices(), + dtype=numpy.float32) + + self._renderResources = (vertices, gridVertices, labels) + + _program = Program( + _SHADERS['vertex'], _SHADERS['fragment'], attrib0='position') + + def render(self): + if self._renderResources is None: + self._buildVerticesAndLabels() + vertices, gridVertices, labels = self._renderResources + + width, height = self.size + matProj = mat4Ortho(0, width, height, 0, 1, -1) + + gl.glViewport(0, 0, width, height) + + prog = self._program + prog.use() + + gl.glLineWidth(self._LINE_WIDTH) + + gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matProj) + gl.glUniform4f(prog.uniforms['color'], 0., 0., 0., 1.) + gl.glUniform1f(prog.uniforms['tickFactor'], 0.) + + gl.glEnableVertexAttribArray(prog.attributes['position']) + gl.glVertexAttribPointer(prog.attributes['position'], + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + 0, vertices) + + gl.glDrawArrays(gl.GL_LINES, 0, len(vertices)) + + for label in labels: + label.render(matProj) + + def renderGrid(self): + if self._grid == self.GRID_NONE: + return + + if self._renderResources is None: + self._buildVerticesAndLabels() + vertices, gridVertices, labels = self._renderResources + + width, height = self.size + matProj = mat4Ortho(0, width, height, 0, 1, -1) + + gl.glViewport(0, 0, width, height) + + prog = self._program + prog.use() + + gl.glLineWidth(self._LINE_WIDTH) + gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matProj) + gl.glUniform4f(prog.uniforms['color'], 0.7, 0.7, 0.7, 1.) + gl.glUniform1f(prog.uniforms['tickFactor'], 0.) # 1/2.) # 1/tickLen + + gl.glEnableVertexAttribArray(prog.attributes['position']) + gl.glVertexAttribPointer(prog.attributes['position'], + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + 0, gridVertices) + + gl.glDrawArrays(gl.GL_LINES, 0, len(gridVertices)) + + +# GLPlotFrame2D ############################################################### + +class GLPlotFrame2D(GLPlotFrame): + def __init__(self, margins): + """ + :param margins: The margins around plot area for axis and labels. + :type margins: dict with 'left', 'right', 'top', 'bottom' keys and + values as ints. + """ + super(GLPlotFrame2D, self).__init__(margins) + self.axes.append(PlotAxis(self, + tickLength=(0., -5.), + labelAlign=CENTER, labelVAlign=TOP, + titleAlign=CENTER, titleVAlign=TOP, + titleRotate=0, + titleOffset=(0, self.margins.bottom // 2))) + + self._x2AxisCoords = () + + self.axes.append(PlotAxis(self, + tickLength=(5., 0.), + labelAlign=RIGHT, labelVAlign=CENTER, + titleAlign=CENTER, titleVAlign=BOTTOM, + titleRotate=ROTATE_270, + titleOffset=(-3 * self.margins.left // 4, + 0))) + + self._y2Axis = PlotAxis(self, + tickLength=(-5., 0.), + labelAlign=LEFT, labelVAlign=CENTER, + titleAlign=CENTER, titleVAlign=TOP, + titleRotate=ROTATE_270, + titleOffset=(3 * self.margins.right // 4, + 0)) + + self._isYAxisInverted = False + + self._dataRanges = { + 'x': (1., 100.), 'y': (1., 100.), 'y2': (1., 100.)} + + self._baseVectors = (1., 0.), (0., 1.) + + self._transformedDataRanges = None + self._transformedDataProjMat = None + self._transformedDataY2ProjMat = None + + def _dirty(self): + super(GLPlotFrame2D, self)._dirty() + self._transformedDataRanges = None + self._transformedDataProjMat = None + self._transformedDataY2ProjMat = None + + @property + def isDirty(self): + """True if it need to refresh graphic rendering, False otherwise.""" + return (super(GLPlotFrame2D, self).isDirty or + self._transformedDataRanges is None or + self._transformedDataProjMat is None or + self._transformedDataY2ProjMat is None) + + @property + def xAxis(self): + return self.axes[0] + + @property + def yAxis(self): + return self.axes[1] + + @property + def y2Axis(self): + return self._y2Axis + + @property + def isY2Axis(self): + """Whether to display the left Y axis or not.""" + return len(self.axes) == 3 + + @isY2Axis.setter + def isY2Axis(self, isY2Axis): + if isY2Axis != self.isY2Axis: + if isY2Axis: + self.axes.append(self._y2Axis) + else: + self.axes = self.axes[:2] + + self._dirty() + + @property + def isYAxisInverted(self): + """Whether Y axes are inverted or not as a bool.""" + return self._isYAxisInverted + + @isYAxisInverted.setter + def isYAxisInverted(self, value): + value = bool(value) + if value != self._isYAxisInverted: + self._isYAxisInverted = value + self._dirty() + + DEFAULT_BASE_VECTORS = (1., 0.), (0., 1.) + """Values of baseVectors for orthogonal axes.""" + + @property + def baseVectors(self): + """Coordinates of the X and Y axes in the orthogonal plot coords. + + Raises ValueError if corresponding matrix is singular. + + 2 tuples of 2 floats: (xx, xy), (yx, yy) + """ + return self._baseVectors + + @baseVectors.setter + def baseVectors(self, baseVectors): + self._dirty() + + (xx, xy), (yx, yy) = baseVectors + vectors = (float(xx), float(xy)), (float(yx), float(yy)) + + det = (vectors[0][0] * vectors[1][1] - vectors[1][0] * vectors[0][1]) + if det == 0.: + raise ValueError("Singular matrix for base vectors: " + + str(vectors)) + + if vectors != self._baseVectors: + self._baseVectors = vectors + self._dirty() + + @property + def dataRanges(self): + """Ranges of data visible in the plot on x, y and y2 axes. + + This is different to the axes range when axes are not orthogonal. + + Type: ((xMin, xMax), (yMin, yMax), (y2Min, y2Max)) + """ + return self._DataRanges(self._dataRanges['x'], + self._dataRanges['y'], + self._dataRanges['y2']) + + @staticmethod + def _clipToSafeRange(min_, max_, isLog): + # Clip range if needed + minLimit = FLOAT32_MINPOS if isLog else FLOAT32_SAFE_MIN + min_ = numpy.clip(min_, minLimit, FLOAT32_SAFE_MAX) + max_ = numpy.clip(max_, minLimit, FLOAT32_SAFE_MAX) + assert min_ < max_ + return min_, max_ + + def setDataRanges(self, x=None, y=None, y2=None): + """Set data range over each axes. + + The provided ranges are clipped to possible values + (i.e., 32 float range + positive range for log scale). + + :param x: (min, max) data range over X axis + :param y: (min, max) data range over Y axis + :param y2: (min, max) data range over Y2 axis + """ + if x is not None: + self._dataRanges['x'] = \ + self._clipToSafeRange(x[0], x[1], self.xAxis.isLog) + + if y is not None: + self._dataRanges['y'] = \ + self._clipToSafeRange(y[0], y[1], self.yAxis.isLog) + + if y2 is not None: + self._dataRanges['y2'] = \ + self._clipToSafeRange(y2[0], y2[1], self.y2Axis.isLog) + + self.xAxis.dataRange = self._dataRanges['x'] + self.yAxis.dataRange = self._dataRanges['y'] + self.y2Axis.dataRange = self._dataRanges['y2'] + + _DataRanges = namedtuple('dataRanges', ('x', 'y', 'y2')) + + @property + def transformedDataRanges(self): + """Bounds of the displayed area in transformed data coordinates + (i.e., log scale applied if any as well as skew) + + 3-tuple of 2-tuple (min, max) for each axis: x, y, y2. + """ + if self._transformedDataRanges is None: + (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = self.dataRanges + + if self.xAxis.isLog: + try: + xMin = math.log10(xMin) + except ValueError: + _logger.info('xMin: warning log10(%f)', xMin) + xMin = 0. + try: + xMax = math.log10(xMax) + except ValueError: + _logger.info('xMax: warning log10(%f)', xMax) + xMax = 0. + + if self.yAxis.isLog: + try: + yMin = math.log10(yMin) + except ValueError: + _logger.info('yMin: warning log10(%f)', yMin) + yMin = 0. + try: + yMax = math.log10(yMax) + except ValueError: + _logger.info('yMax: warning log10(%f)', yMax) + yMax = 0. + + try: + y2Min = math.log10(y2Min) + except ValueError: + _logger.info('yMin: warning log10(%f)', y2Min) + y2Min = 0. + try: + y2Max = math.log10(y2Max) + except ValueError: + _logger.info('yMax: warning log10(%f)', y2Max) + y2Max = 0. + + # Non-orthogonal axes + if self.baseVectors != self.DEFAULT_BASE_VECTORS: + (xx, xy), (yx, yy) = self.baseVectors + skew_mat = numpy.array(((xx, yx), (xy, yy))) + + corners = [(xMin, yMin), (xMin, yMax), + (xMax, yMin), (xMax, yMax), + (xMin, y2Min), (xMin, y2Max), + (xMax, y2Min), (xMax, y2Max)] + + corners = numpy.array( + [numpy.dot(skew_mat, corner) for corner in corners], + dtype=numpy.float32) + xMin, xMax = corners[:, 0].min(), corners[:, 0].max() + yMin, yMax = corners[0:4, 1].min(), corners[0:4, 1].max() + y2Min, y2Max = corners[4:, 1].min(), corners[4:, 1].max() + + self._transformedDataRanges = self._DataRanges( + (xMin, xMax), (yMin, yMax), (y2Min, y2Max)) + + return self._transformedDataRanges + + @property + def transformedDataProjMat(self): + """Orthographic projection matrix for rendering transformed data + + :type: numpy.matrix + """ + if self._transformedDataProjMat is None: + xMin, xMax = self.transformedDataRanges.x + yMin, yMax = self.transformedDataRanges.y + + if self.isYAxisInverted: + mat = mat4Ortho(xMin, xMax, yMax, yMin, 1, -1) + else: + mat = mat4Ortho(xMin, xMax, yMin, yMax, 1, -1) + + # Non-orthogonal axes + if self.baseVectors != self.DEFAULT_BASE_VECTORS: + (xx, xy), (yx, yy) = self.baseVectors + mat = mat * numpy.matrix(( + (xx, yx, 0., 0.), + (xy, yy, 0., 0.), + (0., 0., 1., 0.), + (0., 0., 0., 1.)), dtype=numpy.float32) + + self._transformedDataProjMat = mat + + return self._transformedDataProjMat + + @property + def transformedDataY2ProjMat(self): + """Orthographic projection matrix for rendering transformed data + for the 2nd Y axis + + :type: numpy.matrix + """ + if self._transformedDataY2ProjMat is None: + xMin, xMax = self.transformedDataRanges.x + y2Min, y2Max = self.transformedDataRanges.y2 + + if self.isYAxisInverted: + mat = mat4Ortho(xMin, xMax, y2Max, y2Min, 1, -1) + else: + mat = mat4Ortho(xMin, xMax, y2Min, y2Max, 1, -1) + + # Non-orthogonal axes + if self.baseVectors != self.DEFAULT_BASE_VECTORS: + (xx, xy), (yx, yy) = self.baseVectors + mat = mat * numpy.matrix(( + (xx, yx, 0., 0.), + (xy, yy, 0., 0.), + (0., 0., 1., 0.), + (0., 0., 0., 1.)), dtype=numpy.float32) + + self._transformedDataY2ProjMat = mat + + return self._transformedDataY2ProjMat + + def dataToPixel(self, x, y, axis='left'): + """Convert data coordinate to widget pixel coordinate. + """ + assert axis in ('left', 'right') + + trBounds = self.transformedDataRanges + + if self.xAxis.isLog: + if x < FLOAT32_MINPOS: + return None + xDataTr = math.log10(x) + else: + xDataTr = x + + if self.yAxis.isLog: + if y < FLOAT32_MINPOS: + return None + yDataTr = math.log10(y) + else: + yDataTr = y + + # Non-orthogonal axes + if self.baseVectors != self.DEFAULT_BASE_VECTORS: + (xx, xy), (yx, yy) = self.baseVectors + skew_mat = numpy.array(((xx, yx), (xy, yy))) + + coords = numpy.dot(skew_mat, numpy.array((xDataTr, yDataTr))) + xDataTr, yDataTr = coords + + plotWidth, plotHeight = self.plotSize + + xPixel = int(self.margins.left + + plotWidth * (xDataTr - trBounds.x[0]) / + (trBounds.x[1] - trBounds.x[0])) + + usedAxis = trBounds.y if axis == "left" else trBounds.y2 + yOffset = (plotHeight * (yDataTr - usedAxis[0]) / + (usedAxis[1] - usedAxis[0])) + + if self.isYAxisInverted: + yPixel = int(self.margins.top + yOffset) + else: + yPixel = int(self.size[1] - self.margins.bottom - yOffset) + + return xPixel, yPixel + + def pixelToData(self, x, y, axis="left"): + """Convert pixel position to data coordinates. + + :param float x: X coord + :param float y: Y coord + :param str axis: Y axis to use in ('left', 'right') + :return: (x, y) position in data coords + """ + assert axis in ("left", "right") + + plotWidth, plotHeight = self.plotSize + + trBounds = self.transformedDataRanges + + xData = (x - self.margins.left + 0.5) / float(plotWidth) + xData = trBounds.x[0] + xData * (trBounds.x[1] - trBounds.x[0]) + + usedAxis = trBounds.y if axis == "left" else trBounds.y2 + if self.isYAxisInverted: + yData = (y - self.margins.top + 0.5) / float(plotHeight) + yData = usedAxis[0] + yData * (usedAxis[1] - usedAxis[0]) + else: + yData = self.size[1] - self.margins.bottom - y - 0.5 + yData /= float(plotHeight) + yData = usedAxis[0] + yData * (usedAxis[1] - usedAxis[0]) + + # non-orthogonal axis + if self.baseVectors != self.DEFAULT_BASE_VECTORS: + (xx, xy), (yx, yy) = self.baseVectors + skew_mat = numpy.array(((xx, yx), (xy, yy))) + skew_mat = numpy.linalg.inv(skew_mat) + + coords = numpy.dot(skew_mat, numpy.array((xData, yData))) + xData, yData = coords + + if self.xAxis.isLog: + xData = pow(10, xData) + if self.yAxis.isLog: + yData = pow(10, yData) + + return xData, yData + + def _buildGridVerticesWithTest(self, test): + vertices = [] + + if self.baseVectors == self.DEFAULT_BASE_VECTORS: + for axis in self.axes: + for (xPixel, yPixel), data, text in axis.ticks: + if test(text): + vertices.append((xPixel, yPixel)) + if axis == self.xAxis: + vertices.append((xPixel, self.margins.top)) + elif axis == self.yAxis: + vertices.append((self.size[0] - self.margins.right, + yPixel)) + else: # axis == self.y2Axis + vertices.append((self.margins.left, yPixel)) + + else: + # Get plot corners in data coords + plotLeft, plotTop = self.plotOrigin + plotWidth, plotHeight = self.plotSize + + corners = [(plotLeft, plotTop), + (plotLeft, plotTop + plotHeight), + (plotLeft + plotWidth, plotTop + plotHeight), + (plotLeft + plotWidth, plotTop)] + + for axis in self.axes: + if axis == self.xAxis: + cornersInData = numpy.array([ + self.pixelToData(x, y) for (x, y) in corners]) + borders = ((cornersInData[0], cornersInData[3]), # top + (cornersInData[1], cornersInData[0]), # left + (cornersInData[3], cornersInData[2])) # right + + for (xPixel, yPixel), data, text in axis.ticks: + if test(text): + for (x0, y0), (x1, y1) in borders: + if min(x0, x1) <= data < max(x0, x1): + yIntersect = (data - x0) * \ + (y1 - y0) / (x1 - x0) + y0 + + pixelPos = self.dataToPixel( + data, yIntersect) + if pixelPos is not None: + vertices.append((xPixel, yPixel)) + vertices.append(pixelPos) + break # Stop at first intersection + + else: # y or y2 axes + if axis == self.yAxis: + axis_name = 'left' + cornersInData = numpy.array([ + self.pixelToData(x, y) for (x, y) in corners]) + borders = ( + (cornersInData[3], cornersInData[2]), # right + (cornersInData[0], cornersInData[3]), # top + (cornersInData[2], cornersInData[1])) # bottom + + else: # axis == self.y2Axis + axis_name = 'right' + corners = numpy.array([self.pixelToData( + x, y, axis='right') for (x, y) in corners]) + borders = ( + (cornersInData[1], cornersInData[0]), # left + (cornersInData[0], cornersInData[3]), # top + (cornersInData[2], cornersInData[1])) # bottom + + for (xPixel, yPixel), data, text in axis.ticks: + if test(text): + for (x0, y0), (x1, y1) in borders: + if min(y0, y1) <= data < max(y0, y1): + xIntersect = (data - y0) * \ + (x1 - x0) / (y1 - y0) + x0 + + pixelPos = self.dataToPixel( + xIntersect, data, axis=axis_name) + if pixelPos is not None: + vertices.append((xPixel, yPixel)) + vertices.append(pixelPos) + break # Stop at first intersection + + return vertices + + def _buildVerticesAndLabels(self): + width, height = self.size + + xCoords = (self.margins.left - 0.5, + width - self.margins.right + 0.5) + yCoords = (height - self.margins.bottom + 0.5, + self.margins.top - 0.5) + + self.axes[0].displayCoords = ((xCoords[0], yCoords[0]), + (xCoords[1], yCoords[0])) + + self._x2AxisCoords = ((xCoords[0], yCoords[1]), + (xCoords[1], yCoords[1])) + + if self.isYAxisInverted: + # Y axes are inverted, axes coordinates are inverted + yCoords = yCoords[1], yCoords[0] + + self.axes[1].displayCoords = ((xCoords[0], yCoords[0]), + (xCoords[0], yCoords[1])) + + self._y2Axis.displayCoords = ((xCoords[1], yCoords[0]), + (xCoords[1], yCoords[1])) + + super(GLPlotFrame2D, self)._buildVerticesAndLabels() + + vertices, gridVertices, labels = self._renderResources + + # Adds vertices for borders without axis + extraVertices = [] + extraVertices += self._x2AxisCoords + if not self.isY2Axis: + extraVertices += self._y2Axis.displayCoords + + extraVertices = numpy.array( + extraVertices, copy=False, dtype=numpy.float32) + vertices = numpy.append(vertices, extraVertices, axis=0) + + self._renderResources = (vertices, gridVertices, labels) diff --git a/silx/gui/plot/backends/glutils/GLPlotImage.py b/silx/gui/plot/backends/glutils/GLPlotImage.py new file mode 100644 index 0000000..8fff82b --- /dev/null +++ b/silx/gui/plot/backends/glutils/GLPlotImage.py @@ -0,0 +1,707 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +""" +This module provides a class to render 2D array as a colormap or RGB(A) image +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "03/04/2017" + + +import math +import numpy + +from silx.math.combo import min_max + +from ...._glutils import gl, Program, Texture +from ..._utils import FLOAT32_MINPOS +from .GLSupport import mat4Translate, mat4Scale +from .GLTexture import Image + + +class _GLPlotData2D(object): + def __init__(self, data, origin, scale): + self.data = data + assert len(origin) == 2 + self.origin = tuple(origin) + assert len(scale) == 2 + self.scale = tuple(scale) + + def pick(self, x, y): + if self.xMin <= x <= self.xMax and self.yMin <= y <= self.yMax: + ox, oy = self.origin + sx, sy = self.scale + col = int((x - ox) / sx) + row = int((y - oy) / sy) + return col, row + else: + return None + + @property + def xMin(self): + ox, sx = self.origin[0], self.scale[0] + return ox if sx >= 0. else ox + sx * self.data.shape[1] + + @property + def yMin(self): + oy, sy = self.origin[1], self.scale[1] + return oy if sy >= 0. else oy + sy * self.data.shape[0] + + @property + def xMax(self): + ox, sx = self.origin[0], self.scale[0] + return ox + sx * self.data.shape[1] if sx >= 0. else ox + + @property + def yMax(self): + oy, sy = self.origin[1], self.scale[1] + return oy + sy * self.data.shape[0] if sy >= 0. else oy + + def discard(self): + pass + + def prepare(self): + pass + + def render(self, matrix, isXLog, isYLog): + pass + + +class GLPlotColormap(_GLPlotData2D): + + _SHADERS = { + 'linear': { + 'vertex': """ + #version 120 + + uniform mat4 matrix; + attribute vec2 texCoords; + attribute vec2 position; + + varying vec2 coords; + + void main(void) { + coords = texCoords; + gl_Position = matrix * vec4(position, 0.0, 1.0); + } + """, + 'fragTransform': """ + vec2 textureCoords(void) { + return coords; + } + """}, + + 'log': { + 'vertex': """ + #version 120 + + attribute vec2 position; + uniform mat4 matrix; + uniform mat4 matOffset; + uniform bvec2 isLog; + + varying vec2 coords; + + const float oneOverLog10 = 0.43429448190325176; + + void main(void) { + vec4 dataPos = matOffset * vec4(position, 0.0, 1.0); + if (isLog.x) { + dataPos.x = oneOverLog10 * log(dataPos.x); + } + if (isLog.y) { + dataPos.y = oneOverLog10 * log(dataPos.y); + } + coords = dataPos.xy; + gl_Position = matrix * dataPos; + } + """, + 'fragTransform': """ + uniform bvec2 isLog; + uniform struct { + vec2 oneOverRange; + vec2 originOverRange; + } bounds; + + vec2 textureCoords(void) { + vec2 pos = coords; + if (isLog.x) { + pos.x = pow(10., coords.x); + } + if (isLog.y) { + pos.y = pow(10., coords.y); + } + return pos * bounds.oneOverRange - bounds.originOverRange; + // TODO texture coords in range different from [0, 1] + } + """}, + + 'fragment': """ + #version 120 + + uniform sampler2D data; + uniform struct { + sampler2D texture; + bool isLog; + float min; + float oneOverRange; + } cmap; + uniform float alpha; + + varying vec2 coords; + + %s + + const float oneOverLog10 = 0.43429448190325176; + + void main(void) { + float value = texture2D(data, textureCoords()).r; + if (cmap.isLog) { + if (value > 0.) { + value = clamp(cmap.oneOverRange * + (oneOverLog10 * log(value) - cmap.min), + 0., 1.); + } else { + value = 0.; + } + } else { /*Linear mapping*/ + value = clamp(cmap.oneOverRange * (value - cmap.min), 0., 1.); + } + + gl_FragColor = texture2D(cmap.texture, vec2(value, 0.5)); + gl_FragColor.a *= alpha; + } + """ + } + + _DATA_TEX_UNIT = 0 + _CMAP_TEX_UNIT = 1 + + _INTERNAL_FORMATS = { + numpy.dtype(numpy.float32): gl.GL_R32F, + # Use normalized integer for unsigned int formats + numpy.dtype(numpy.uint16): gl.GL_R16, + numpy.dtype(numpy.uint8): gl.GL_R8, + } + + _linearProgram = Program(_SHADERS['linear']['vertex'], + _SHADERS['fragment'] % + _SHADERS['linear']['fragTransform'], + attrib0='position') + + _logProgram = Program(_SHADERS['log']['vertex'], + _SHADERS['fragment'] % + _SHADERS['log']['fragTransform'], + attrib0='position') + + def __init__(self, data, origin, scale, + colormap, cmapIsLog=False, cmapRange=None, + alpha=1.0): + """Create a 2D colormap + + :param data: The 2D scalar data array to display + :type data: numpy.ndarray with 2 dimensions (dtype=numpy.float32) + :param origin: (x, y) coordinates of the origin of the data array + :type origin: 2-tuple of floats. + :param scale: (sx, sy) scale factors of the data array. + This is the size of a data pixel in plot data space. + :type scale: 2-tuple of floats. + :param str colormap: Name of the colormap to use + TODO: Accept a 1D scalar array as the colormap + :param bool cmapIsLog: If True, uses log10 of the data value + :param cmapRange: The range of colormap or None for autoscale colormap + For logarithmic colormap, the range is in the untransformed data + TODO: check consistency with matplotlib + :type cmapRange: (float, float) or None + :param float alpha: Opacity from 0 (transparent) to 1 (opaque) + """ + assert data.dtype in self._INTERNAL_FORMATS + + super(GLPlotColormap, self).__init__(data, origin, scale) + self.colormap = numpy.array(colormap, copy=False) + self.cmapIsLog = cmapIsLog + self._cmapRange = None # User-provided range info + self._cmapRangeCache = None # Store extra data for range + self.cmapRange = cmapRange # Update _cmapRange + self._alpha = numpy.clip(alpha, 0., 1.) + + self._cmap_texture = None + self._texture = None + self._textureIsDirty = False + + def discard(self): + if self._cmap_texture is not None: + self._cmap_texture.discard() + self._cmap_texture = None + + if self._texture is not None: + self._texture.discard() + self._texture = None + self._textureIsDirty = False + + @property + def cmapRange(self): + if self._cmapRange is None: # Auto-scale mode + if self._cmapRangeCache is None: + # Build data , positive ranges + result = min_max(self.data, min_positive=True) + min_ = result.minimum + minPos = result.min_positive + max_ = result.maximum + maxPos = max_ if max_ > 0. else 1. + if minPos is None: + minPos = maxPos + self._cmapRangeCache = {'range': (min_, max_), + 'pos': (minPos, maxPos)} + + return self._cmapRangeCache['pos' if self.cmapIsLog else 'range'] + + else: + if not self.cmapIsLog: + return self._cmapRange # Return range as is + else: + if self._cmapRangeCache is None: + # Build a strictly positive range from cmapRange + min_, max_ = self._cmapRange + if min_ > 0. and max_ > 0.: + minPos, maxPos = min_, max_ + else: + result = min_max(self.data, min_positive=True) + minPos = result.min_positive + dataMax = result.maximum + if max_ > 0.: + maxPos = max_ + elif dataMax > 0.: + maxPos = dataMax + else: + maxPos = 1. # Arbitrary fallback + if minPos is None: + minPos = maxPos + self._cmapRangeCache = minPos, maxPos + return self._cmapRangeCache # Strictly positive range + + @cmapRange.setter + def cmapRange(self, cmapRange): + self._cmapRangeCache = None + if cmapRange is None: + self._cmapRange = None + else: + assert len(cmapRange) == 2 + assert cmapRange[0] <= cmapRange[1] + self._cmapRange = tuple(cmapRange) + + @property + def alpha(self): + return self._alpha + + def updateData(self, data): + assert data.dtype in self._INTERNAL_FORMATS + oldData = self.data + self.data = data + + self._cmapRangeCache = None + + if self._texture is not None: + if (self.data.shape != oldData.shape or + self.data.dtype != oldData.dtype): + self.discard() + else: + self._textureIsDirty = True + + def prepare(self): + if self._cmap_texture is None: + # TODO share cmap texture accross Images + # put all cmaps in one texture + colormap = numpy.empty((16, 256, self.colormap.shape[1]), + dtype=self.colormap.dtype) + colormap[:] = self.colormap + format_ = gl.GL_RGBA if colormap.shape[-1] == 4 else gl.GL_RGB + self._cmap_texture = Texture(internalFormat=format_, + data=colormap, + format_=format_, + texUnit=self._CMAP_TEX_UNIT, + minFilter=gl.GL_NEAREST, + magFilter=gl.GL_NEAREST, + wrap=(gl.GL_CLAMP_TO_EDGE, + gl.GL_CLAMP_TO_EDGE)) + + if self._texture is None: + internalFormat = self._INTERNAL_FORMATS[self.data.dtype] + + self._texture = Image(internalFormat, + self.data, + format_=gl.GL_RED, + texUnit=self._DATA_TEX_UNIT) + elif self._textureIsDirty: + self._textureIsDirty = True + self._texture.updateAll(format_=gl.GL_RED, data=self.data) + + def _setCMap(self, prog): + dataMin, dataMax = self.cmapRange # If log, it is stricly positive + + if self.data.dtype in (numpy.uint16, numpy.uint8): + # Using unsigned int as normalized integer in OpenGL + # So normalize range + maxInt = float(numpy.iinfo(self.data.dtype).max) + dataMin, dataMax = dataMin / maxInt, dataMax / maxInt + + if self.cmapIsLog: + dataMin = math.log10(dataMin) + dataMax = math.log10(dataMax) + + gl.glUniform1i(prog.uniforms['cmap.texture'], + self._cmap_texture.texUnit) + gl.glUniform1i(prog.uniforms['cmap.isLog'], self.cmapIsLog) + gl.glUniform1f(prog.uniforms['cmap.min'], dataMin) + if dataMax > dataMin: + oneOverRange = 1. / (dataMax - dataMin) + else: + oneOverRange = 0. # Fall-back + gl.glUniform1f(prog.uniforms['cmap.oneOverRange'], oneOverRange) + + self._cmap_texture.bind() + + def _renderLinear(self, matrix): + self.prepare() + + prog = self._linearProgram + prog.use() + + gl.glUniform1i(prog.uniforms['data'], self._DATA_TEX_UNIT) + + mat = matrix * mat4Translate(*self.origin) * mat4Scale(*self.scale) + gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, mat) + + gl.glUniform1f(prog.uniforms['alpha'], self.alpha) + + self._setCMap(prog) + + self._texture.render(prog.attributes['position'], + prog.attributes['texCoords'], + self._DATA_TEX_UNIT) + + def _renderLog10(self, matrix, isXLog, isYLog): + xMin, yMin = self.xMin, self.yMin + if ((isXLog and xMin < FLOAT32_MINPOS) or + (isYLog and yMin < FLOAT32_MINPOS)): + # Do not render images that are partly or totally <= 0 + return + + self.prepare() + + prog = self._logProgram + prog.use() + + ox, oy = self.origin + + gl.glUniform1i(prog.uniforms['data'], self._DATA_TEX_UNIT) + + gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matrix) + mat = mat4Translate(ox, oy) * mat4Scale(*self.scale) + gl.glUniformMatrix4fv(prog.uniforms['matOffset'], 1, gl.GL_TRUE, mat) + + gl.glUniform2i(prog.uniforms['isLog'], isXLog, isYLog) + + ex = ox + self.scale[0] * self.data.shape[1] + ey = oy + self.scale[1] * self.data.shape[0] + + xOneOverRange = 1. / (ex - ox) + yOneOverRange = 1. / (ey - oy) + gl.glUniform2f(prog.uniforms['bounds.originOverRange'], + ox * xOneOverRange, oy * yOneOverRange) + gl.glUniform2f(prog.uniforms['bounds.oneOverRange'], + xOneOverRange, yOneOverRange) + + gl.glUniform1f(prog.uniforms['alpha'], self.alpha) + + self._setCMap(prog) + + try: + tiles = self._texture.tiles + except AttributeError: + raise RuntimeError("No texture, discard has already been called") + if len(tiles) > 1: + raise NotImplementedError( + "Image over multiple textures not supported with log scale") + + texture, vertices, info = tiles[0] + + texture.bind(self._DATA_TEX_UNIT) + + posAttrib = prog.attributes['position'] + stride = vertices.shape[-1] * vertices.itemsize + gl.glEnableVertexAttribArray(posAttrib) + gl.glVertexAttribPointer(posAttrib, + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + stride, vertices) + + gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(vertices)) + + def render(self, matrix, isXLog, isYLog): + if any((isXLog, isYLog)): + self._renderLog10(matrix, isXLog, isYLog) + else: + self._renderLinear(matrix) + + # Unbind colormap texture + gl.glActiveTexture(gl.GL_TEXTURE0 + self._cmap_texture.texUnit) + gl.glBindTexture(self._cmap_texture.target, 0) + + +# image ####################################################################### + +class GLPlotRGBAImage(_GLPlotData2D): + + _SHADERS = { + 'linear': { + 'vertex': """ + #version 120 + + attribute vec2 position; + attribute vec2 texCoords; + uniform mat4 matrix; + + varying vec2 coords; + + void main(void) { + gl_Position = matrix * vec4(position, 0.0, 1.0); + coords = texCoords; + } + """, + 'fragment': """ + #version 120 + + uniform sampler2D tex; + uniform float alpha; + + varying vec2 coords; + + void main(void) { + gl_FragColor = texture2D(tex, coords); + gl_FragColor.a *= alpha; + } + """}, + + 'log': { + 'vertex': """ + #version 120 + + attribute vec2 position; + uniform mat4 matrix; + uniform mat4 matOffset; + uniform bvec2 isLog; + + varying vec2 coords; + + const float oneOverLog10 = 0.43429448190325176; + + void main(void) { + vec4 dataPos = matOffset * vec4(position, 0.0, 1.0); + if (isLog.x) { + dataPos.x = oneOverLog10 * log(dataPos.x); + } + if (isLog.y) { + dataPos.y = oneOverLog10 * log(dataPos.y); + } + coords = dataPos.xy; + gl_Position = matrix * dataPos; + } + """, + 'fragment': """ + #version 120 + + uniform sampler2D tex; + uniform bvec2 isLog; + uniform struct { + vec2 oneOverRange; + vec2 originOverRange; + } bounds; + uniform float alpha; + + varying vec2 coords; + + vec2 textureCoords(void) { + vec2 pos = coords; + if (isLog.x) { + pos.x = pow(10., coords.x); + } + if (isLog.y) { + pos.y = pow(10., coords.y); + } + return pos * bounds.oneOverRange - bounds.originOverRange; + // TODO texture coords in range different from [0, 1] + } + + void main(void) { + gl_FragColor = texture2D(tex, textureCoords()); + gl_FragColor.a *= alpha; + } + """} + } + + _DATA_TEX_UNIT = 0 + + _SUPPORTED_DTYPES = (numpy.dtype(numpy.float32), + numpy.dtype(numpy.uint8)) + + _linearProgram = Program(_SHADERS['linear']['vertex'], + _SHADERS['linear']['fragment'], + attrib0='position') + + _logProgram = Program(_SHADERS['log']['vertex'], + _SHADERS['log']['fragment'], + attrib0='position') + + def __init__(self, data, origin, scale, alpha): + """Create a 2D RGB(A) image from data + + :param data: The 2D image data array to display + :type data: numpy.ndarray with 3 dimensions + (dtype=numpy.uint8 or numpy.float32) + :param origin: (x, y) coordinates of the origin of the data array + :type origin: 2-tuple of floats. + :param scale: (sx, sy) scale factors of the data array. + This is the size of a data pixel in plot data space. + :type scale: 2-tuple of floats. + :param float alpha: Opacity from 0 (transparent) to 1 (opaque) + """ + assert data.dtype in self._SUPPORTED_DTYPES + super(GLPlotRGBAImage, self).__init__(data, origin, scale) + self._texture = None + self._textureIsDirty = False + self._alpha = numpy.clip(alpha, 0., 1.) + + @property + def alpha(self): + return self._alpha + + def discard(self): + if self._texture is not None: + self._texture.discard() + self._texture = None + self._textureIsDirty = False + + def updateData(self, data): + assert data.dtype in self._SUPPORTED_DTYPES + oldData = self.data + self.data = data + + if self._texture is not None: + if self.data.shape != oldData.shape: + self.discard() + else: + self._textureIsDirty = True + + def prepare(self): + if self._texture is None: + format_ = gl.GL_RGBA if self.data.shape[2] == 4 else gl.GL_RGB + + self._texture = Image(format_, + self.data, + format_=format_, + texUnit=self._DATA_TEX_UNIT) + elif self._textureIsDirty: + self._textureIsDirty = False + + # We should check that internal format is the same + format_ = gl.GL_RGBA if self.data.shape[2] == 4 else gl.GL_RGB + self._texture.updateAll(format_=format_, data=self.data) + + def _renderLinear(self, matrix): + self.prepare() + + prog = self._linearProgram + prog.use() + + gl.glUniform1i(prog.uniforms['tex'], self._DATA_TEX_UNIT) + + mat = matrix * mat4Translate(*self.origin) * mat4Scale(*self.scale) + gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, mat) + + gl.glUniform1f(prog.uniforms['alpha'], self.alpha) + + self._texture.render(prog.attributes['position'], + prog.attributes['texCoords'], + self._DATA_TEX_UNIT) + + def _renderLog(self, matrix, isXLog, isYLog): + self.prepare() + + prog = self._logProgram + prog.use() + + ox, oy = self.origin + + gl.glUniform1i(prog.uniforms['tex'], self._DATA_TEX_UNIT) + + gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matrix) + mat = mat4Translate(ox, oy) * mat4Scale(*self.scale) + gl.glUniformMatrix4fv(prog.uniforms['matOffset'], 1, gl.GL_TRUE, mat) + + gl.glUniform2i(prog.uniforms['isLog'], isXLog, isYLog) + + gl.glUniform1f(prog.uniforms['alpha'], self.alpha) + + ex = ox + self.scale[0] * self.data.shape[1] + ey = oy + self.scale[1] * self.data.shape[0] + + xOneOverRange = 1. / (ex - ox) + yOneOverRange = 1. / (ey - oy) + gl.glUniform2f(prog.uniforms['bounds.originOverRange'], + ox * xOneOverRange, oy * yOneOverRange) + gl.glUniform2f(prog.uniforms['bounds.oneOverRange'], + xOneOverRange, yOneOverRange) + + try: + tiles = self._texture.tiles + except AttributeError: + raise RuntimeError("No texture, discard has already been called") + if len(tiles) > 1: + raise NotImplementedError( + "Image over multiple textures not supported with log scale") + + texture, vertices, info = tiles[0] + + texture.bind(self._DATA_TEX_UNIT) + + posAttrib = prog.attributes['position'] + stride = vertices.shape[-1] * vertices.itemsize + gl.glEnableVertexAttribArray(posAttrib) + gl.glVertexAttribPointer(posAttrib, + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + stride, vertices) + + gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(vertices)) + + def render(self, matrix, isXLog, isYLog): + if any((isXLog, isYLog)): + self._renderLog(matrix, isXLog, isYLog) + else: + self._renderLinear(matrix) diff --git a/silx/gui/plot/backends/glutils/GLSupport.py b/silx/gui/plot/backends/glutils/GLSupport.py new file mode 100644 index 0000000..3f473be --- /dev/null +++ b/silx/gui/plot/backends/glutils/GLSupport.py @@ -0,0 +1,192 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +""" +This module provides convenient classes and functions for OpenGL rendering. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "03/04/2017" + + +import numpy + +from ...._glutils import gl + + +def buildFillMaskIndices(nIndices): + if nIndices <= numpy.iinfo(numpy.uint16).max + 1: + dtype = numpy.uint16 + else: + dtype = numpy.uint32 + + lastIndex = nIndices - 1 + splitIndex = lastIndex // 2 + 1 + indices = numpy.empty(nIndices, dtype=dtype) + indices[::2] = numpy.arange(0, splitIndex, step=1, dtype=dtype) + indices[1::2] = numpy.arange(lastIndex, splitIndex - 1, step=-1, + dtype=dtype) + return indices + + +class Shape2D(object): + _NO_HATCH = 0 + _HATCH_STEP = 20 + + def __init__(self, points, fill='solid', stroke=True, + fillColor=(0., 0., 0., 1.), strokeColor=(0., 0., 0., 1.), + strokeClosed=True): + self.vertices = numpy.array(points, dtype=numpy.float32, copy=False) + self.strokeClosed = strokeClosed + + self._indices = buildFillMaskIndices(len(self.vertices)) + + tVertex = numpy.transpose(self.vertices) + xMin, xMax = min(tVertex[0]), max(tVertex[0]) + yMin, yMax = min(tVertex[1]), max(tVertex[1]) + self.bboxVertices = numpy.array(((xMin, yMin), (xMin, yMax), + (xMax, yMin), (xMax, yMax)), + dtype=numpy.float32) + self._xMin, self._xMax = xMin, xMax + self._yMin, self._yMax = yMin, yMax + + self.fill = fill + self.fillColor = fillColor + self.stroke = stroke + self.strokeColor = strokeColor + + @property + def xMin(self): + return self._xMin + + @property + def xMax(self): + return self._xMax + + @property + def yMin(self): + return self._yMin + + @property + def yMax(self): + return self._yMax + + def prepareFillMask(self, posAttrib): + gl.glEnableVertexAttribArray(posAttrib) + gl.glVertexAttribPointer(posAttrib, + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + 0, self.vertices) + + gl.glEnable(gl.GL_STENCIL_TEST) + gl.glStencilMask(1) + gl.glStencilFunc(gl.GL_ALWAYS, 1, 1) + gl.glStencilOp(gl.GL_INVERT, gl.GL_INVERT, gl.GL_INVERT) + gl.glColorMask(gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE) + gl.glDepthMask(gl.GL_FALSE) + + gl.glDrawElements(gl.GL_TRIANGLE_STRIP, len(self._indices), + gl.GL_UNSIGNED_SHORT, self._indices) + + gl.glStencilFunc(gl.GL_EQUAL, 1, 1) + # Reset stencil while drawing + gl.glStencilOp(gl.GL_ZERO, gl.GL_ZERO, gl.GL_ZERO) + gl.glColorMask(gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE) + gl.glDepthMask(gl.GL_TRUE) + + def renderFill(self, posAttrib): + self.prepareFillMask(posAttrib) + + gl.glVertexAttribPointer(posAttrib, + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + 0, self.bboxVertices) + gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(self.bboxVertices)) + + gl.glDisable(gl.GL_STENCIL_TEST) + + def renderStroke(self, posAttrib): + gl.glEnableVertexAttribArray(posAttrib) + gl.glVertexAttribPointer(posAttrib, + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + 0, self.vertices) + gl.glLineWidth(1) + drawMode = gl.GL_LINE_LOOP if self.strokeClosed else gl.GL_LINE_STRIP + gl.glDrawArrays(drawMode, 0, len(self.vertices)) + + def render(self, posAttrib, colorUnif, hatchStepUnif): + assert self.fill in ['hatch', 'solid', None] + if self.fill is not None: + gl.glUniform4f(colorUnif, *self.fillColor) + step = self._HATCH_STEP if self.fill == 'hatch' else self._NO_HATCH + gl.glUniform1i(hatchStepUnif, step) + self.renderFill(posAttrib) + + if self.stroke: + gl.glUniform4f(colorUnif, *self.strokeColor) + gl.glUniform1i(hatchStepUnif, self._NO_HATCH) + self.renderStroke(posAttrib) + + +# matrix ###################################################################### + +def mat4Ortho(left, right, bottom, top, near, far): + """Orthographic projection matrix (row-major)""" + return numpy.matrix(( + (2./(right - left), 0., 0., -(right+left)/float(right-left)), + (0., 2./(top - bottom), 0., -(top+bottom)/float(top-bottom)), + (0., 0., -2./(far-near), -(far+near)/float(far-near)), + (0., 0., 0., 1.)), dtype=numpy.float32) + + +def mat4Translate(x=0., y=0., z=0.): + """Translation matrix (row-major)""" + return numpy.matrix(( + (1., 0., 0., x), + (0., 1., 0., y), + (0., 0., 1., z), + (0., 0., 0., 1.)), dtype=numpy.float32) + + +def mat4Scale(sx=1., sy=1., sz=1.): + """Scale matrix (row-major)""" + return numpy.matrix(( + (sx, 0., 0., 0.), + (0., sy, 0., 0.), + (0., 0., sz, 0.), + (0., 0., 0., 1.)), dtype=numpy.float32) + + +def mat4Identity(): + """Identity matrix""" + return numpy.matrix(( + (1., 0., 0., 0.), + (0., 1., 0., 0.), + (0., 0., 1., 0.), + (0., 0., 0., 1.)), dtype=numpy.float32) diff --git a/silx/gui/plot/backends/glutils/GLText.py b/silx/gui/plot/backends/glutils/GLText.py new file mode 100644 index 0000000..495882c --- /dev/null +++ b/silx/gui/plot/backends/glutils/GLText.py @@ -0,0 +1,222 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +""" +This module provides minimalistic text support for OpenGL. +It provides Latin-1 (ISO8859-1) characters for one monospace font at one size. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "03/04/2017" + + +import numpy + +from ...._glutils import font, gl, getGLContext, Program, Texture +from .GLSupport import mat4Translate + + +# TODO: Font should be configurable by the main program: using mpl.rcParams? + + +# Text2D ###################################################################### + +LEFT, CENTER, RIGHT = 'left', 'center', 'right' +TOP, BASELINE, BOTTOM = 'top', 'baseline', 'bottom' +ROTATE_90, ROTATE_180, ROTATE_270 = 90, 180, 270 + + +class Text2D(object): + + _SHADERS = { + 'vertex': """ + #version 120 + + attribute vec2 position; + attribute vec2 texCoords; + uniform mat4 matrix; + + varying vec2 vCoords; + + void main(void) { + gl_Position = matrix * vec4(position, 0.0, 1.0); + vCoords = texCoords; + } + """, + 'fragment': """ + #version 120 + + uniform sampler2D texText; + uniform vec4 color; + uniform vec4 bgColor; + + varying vec2 vCoords; + + void main(void) { + gl_FragColor = mix(bgColor, color, texture2D(texText, vCoords).r); + } + """ + } + + _TEX_COORDS = numpy.array(((0., 0.), (1., 0.), (0., 1.), (1., 1.)), + dtype=numpy.float32).ravel() + + _program = Program(_SHADERS['vertex'], + _SHADERS['fragment'], + attrib0='position') + + _textures = {} + + _rasterTextCache = {} + """Internal cache storing already rasterized text""" + # TODO limit cache size and discard least recent used + + def __init__(self, text, x=0, y=0, + color=(0., 0., 0., 1.), + bgColor=None, + align=LEFT, valign=BASELINE, + rotate=0): + self._vertices = None + self._text = text + self.x = x + self.y = y + self.color = color + self.bgColor = bgColor + + if align not in (LEFT, CENTER, RIGHT): + raise ValueError( + "Horizontal alignment not supported: {0}".format(align)) + self._align = align + + if valign not in (TOP, CENTER, BASELINE, BOTTOM): + raise ValueError( + "Vertical alignment not supported: {0}".format(valign)) + self._valign = valign + + self._rotate = numpy.radians(rotate) + + @classmethod + def _getTexture(cls, text): + key = getGLContext(), text + if key not in cls._textures: + image, offset = font.rasterText(text, + font.getDefaultFontFamily()) + cls._textures[key] = (Texture(gl.GL_RED, + data=image, + minFilter=gl.GL_NEAREST, + magFilter=gl.GL_NEAREST, + wrap=(gl.GL_CLAMP_TO_EDGE, + gl.GL_CLAMP_TO_EDGE)), + offset) + + return cls._textures[key] + + @property + def text(self): + return self._text + + @property + def size(self): # TODO very poor implementation + image, offset = font.rasterText(self.text, + font.getDefaultFontFamily()) + return image.shape[1], image.shape[0] + + def getVertices(self, offset, shape): + height, width = shape + + if self._align == LEFT: + xOrig = 0 + elif self._align == RIGHT: + xOrig = - width + else: # CENTER + xOrig = - width // 2 + + if self._valign == BASELINE: + yOrig = - offset + elif self._valign == TOP: + yOrig = 0 + elif self._valign == BOTTOM: + yOrig = - height + else: # CENTER + yOrig = - height // 2 + + vertices = numpy.array(( + (xOrig, yOrig), + (xOrig + width, yOrig), + (xOrig, yOrig + height), + (xOrig + width, yOrig + height)), dtype=numpy.float32) + + cos, sin = numpy.cos(self._rotate), numpy.sin(self._rotate) + vertices = numpy.ascontiguousarray(numpy.transpose(numpy.array(( + cos * vertices[:, 0] - sin * vertices[:, 1], + sin * vertices[:, 0] + cos * vertices[:, 1]), + dtype=numpy.float32))) + + return vertices + + def render(self, matrix): + if not self.text: + return + + prog = self._program + prog.use() + + texUnit = 0 + texture, offset = self._getTexture(self.text) + + gl.glUniform1i(prog.uniforms['texText'], texUnit) + + gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, + matrix * mat4Translate(int(self.x), int(self.y))) + + gl.glUniform4f(prog.uniforms['color'], *self.color) + if self.bgColor is not None: + bgColor = self.bgColor + else: + bgColor = self.color[0], self.color[1], self.color[2], 0. + gl.glUniform4f(prog.uniforms['bgColor'], *bgColor) + + vertices = self.getVertices(offset, texture.shape) + + posAttrib = prog.attributes['position'] + gl.glEnableVertexAttribArray(posAttrib) + gl.glVertexAttribPointer(posAttrib, + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + 0, + vertices) + + texAttrib = prog.attributes['texCoords'] + gl.glEnableVertexAttribArray(texAttrib) + gl.glVertexAttribPointer(texAttrib, + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + 0, + self._TEX_COORDS) + + with texture: + gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, 4) diff --git a/silx/gui/plot/backends/glutils/GLTexture.py b/silx/gui/plot/backends/glutils/GLTexture.py new file mode 100644 index 0000000..25dd9f1 --- /dev/null +++ b/silx/gui/plot/backends/glutils/GLTexture.py @@ -0,0 +1,239 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +"""This module provides classes wrapping OpenGL texture.""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "03/04/2017" + + +from ctypes import c_void_p +import logging + +import numpy + +from ...._glutils import gl, Texture, numpyToGLType + + +_logger = logging.getLogger(__name__) + + +def _checkTexture2D(internalFormat, shape, + format_=None, type_=gl.GL_FLOAT, border=0): + """Check if texture size with provided parameters is supported + + :rtype: bool + """ + height, width = shape + gl.glTexImage2D(gl.GL_PROXY_TEXTURE_2D, 0, internalFormat, + width, height, border, + format_ or internalFormat, + type_, c_void_p(0)) + width = gl.glGetTexLevelParameteriv( + gl.GL_PROXY_TEXTURE_2D, 0, gl.GL_TEXTURE_WIDTH) + return bool(width) + + +MIN_TEXTURE_SIZE = 64 + + +def _getMaxSquareTexture2DSize(internalFormat=gl.GL_RGBA, + format_=None, + type_=gl.GL_FLOAT, + border=0): + """Returns a supported size for a corresponding square texture + + :returns: GL_MAX_TEXTURE_SIZE or a smaller supported size (not optimal) + :rtype: int + """ + # Is this useful? + maxTexSize = gl.glGetIntegerv(gl.GL_MAX_TEXTURE_SIZE) + while maxTexSize > MIN_TEXTURE_SIZE and \ + not _checkTexture2D(internalFormat, (maxTexSize, maxTexSize), + format_, type_, border): + maxTexSize //= 2 + return max(MIN_TEXTURE_SIZE, maxTexSize) + + +class Image(object): + """Image of any size eventually using multiple textures or larger texture + """ + + _WRAP = (gl.GL_CLAMP_TO_EDGE, gl.GL_CLAMP_TO_EDGE) + _MIN_FILTER = gl.GL_NEAREST + _MAG_FILTER = gl.GL_NEAREST + + def __init__(self, internalFormat, data, format_=None, texUnit=0): + self.internalFormat = internalFormat + self.height, self.width = data.shape[0:2] + type_ = numpyToGLType(data.dtype) + + if _checkTexture2D(internalFormat, data.shape[0:2], format_, type_): + texture = Texture(internalFormat, + data, + format_, + texUnit=texUnit, + minFilter=self._MIN_FILTER, + magFilter=self._MAG_FILTER, + wrap=self._WRAP) + vertices = numpy.array(( + (0., 0., 0., 0.), + (self.width, 0., 1., 0.), + (0., self.height, 0., 1.), + (self.width, self.height, 1., 1.)), dtype=numpy.float32) + self.tiles = ((texture, vertices, + {'xOrigData': 0, 'yOrigData': 0, + 'wData': self.width, 'hData': self.height}),) + + else: + # Handle dimension too large: make tiles + maxTexSize = _getMaxSquareTexture2DSize(internalFormat, + format_, type_) + + nCols = (self.width+maxTexSize-1) // maxTexSize + colWidths = [self.width // nCols] * nCols + colWidths[-1] += self.width % nCols + + nRows = (self.height+maxTexSize-1) // maxTexSize + rowHeights = [self.height//nRows] * nRows + rowHeights[-1] += self.height % nRows + + tiles = [] + yOrig = 0 + for hData in rowHeights: + xOrig = 0 + for wData in colWidths: + if (hData < MIN_TEXTURE_SIZE or wData < MIN_TEXTURE_SIZE) \ + and not _checkTexture2D(internalFormat, + (hData, wData), + format_, + type_): + # Ensure texture size is at least MIN_TEXTURE_SIZE + tH = max(hData, MIN_TEXTURE_SIZE) + tW = max(wData, MIN_TEXTURE_SIZE) + + uMax, vMax = float(wData)/tW, float(hData)/tH + + # TODO issue with type_ and alignment + texture = Texture(internalFormat, + data=None, + format_=format_, + shape=(tH, tW), + texUnit=texUnit, + minFilter=self._MIN_FILTER, + magFilter=self._MAG_FILTER, + wrap=self._WRAP) + # TODO handle unpack + texture.update(format_, + data[yOrig:yOrig+hData, + xOrig:xOrig+wData]) + # texture.update(format_, type_, data, + # width=wData, height=hData, + # unpackRowLength=width, + # unpackSkipPixels=xOrig, + # unpackSkipRows=yOrig) + else: + uMax, vMax = 1, 1 + # TODO issue with type_ and unpacking tiles + # TODO idea to handle unpack: use array strides + # As it is now, it will make a copy + texture = Texture(internalFormat, + data[yOrig:yOrig+hData, + xOrig:xOrig+wData], + format_, + shape=(hData, wData), + texUnit=texUnit, + minFilter=self._MIN_FILTER, + magFilter=self._MAG_FILTER, + wrap=self._WRAP) + # TODO + # unpackRowLength=width, + # unpackSkipPixels=xOrig, + # unpackSkipRows=yOrig) + vertices = numpy.array(( + (xOrig, yOrig, 0., 0.), + (xOrig + wData, yOrig, uMax, 0.), + (xOrig, yOrig + hData, 0., vMax), + (xOrig + wData, yOrig + hData, uMax, vMax)), + dtype=numpy.float32) + tiles.append((texture, vertices, + {'xOrigData': xOrig, 'yOrigData': yOrig, + 'wData': wData, 'hData': hData})) + xOrig += wData + yOrig += hData + self.tiles = tuple(tiles) + + def discard(self): + for texture, vertices, _ in self.tiles: + texture.discard() + del self.tiles + + def updateAll(self, format_, data, texUnit=0): + if not hasattr(self, 'tiles'): + raise RuntimeError("No texture, discard has already been called") + + assert data.shape[:2] == (self.height, self.width) + if len(self.tiles) == 1: + self.tiles[0][0].update(format_, data, texUnit=texUnit) + else: + for texture, _, info in self.tiles: + yOrig, xOrig = info['yOrigData'], info['xOrigData'] + height, width = info['hData'], info['wData'] + texture.update(format_, + data[yOrig:yOrig+height, xOrig:xOrig+width], + texUnit=texUnit) + # TODO check + # width=info['wData'], height=info['hData'], + # texUnit=texUnit, unpackAlign=unpackAlign, + # unpackRowLength=self.width, + # unpackSkipPixels=info['xOrigData'], + # unpackSkipRows=info['yOrigData']) + + def render(self, posAttrib, texAttrib, texUnit=0): + try: + tiles = self.tiles + except AttributeError: + raise RuntimeError("No texture, discard has already been called") + + for texture, vertices, _ in tiles: + texture.bind(texUnit) + + stride = vertices.shape[-1] * vertices.itemsize + gl.glEnableVertexAttribArray(posAttrib) + gl.glVertexAttribPointer(posAttrib, + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + stride, vertices) + + texCoordsPtr = c_void_p(vertices.ctypes.data + + 2 * vertices.itemsize) + gl.glEnableVertexAttribArray(texAttrib) + gl.glVertexAttribPointer(texAttrib, + 2, + gl.GL_FLOAT, + gl.GL_FALSE, + stride, texCoordsPtr) + gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(vertices)) diff --git a/silx/gui/plot/backends/glutils/PlotImageFile.py b/silx/gui/plot/backends/glutils/PlotImageFile.py new file mode 100644 index 0000000..e4ebe24 --- /dev/null +++ b/silx/gui/plot/backends/glutils/PlotImageFile.py @@ -0,0 +1,149 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +"""Function to save an image to a file.""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "03/04/2017" + + +import base64 +import struct +import sys +import zlib + + +# Image writer ################################################################ + +def convertRGBDataToPNG(data): + """Convert a RGB bitmap to PNG. + + It only supports RGB bitmap with one byte per channel stored as a 3D array. + See `Definitive Guide <http://www.libpng.org/pub/png/book/>`_ and + `Specification <http://www.libpng.org/pub/png/spec/1.2/>`_ for details. + + :param data: A 3D array (h, w, rgb) storing an RGB image + :type data: numpy.ndarray of unsigned bytes + :returns: The PNG encoded data + :rtype: bytes + """ + height, width = data.shape[0], data.shape[1] + depth = 8 # 8 bit per channel + colorType = 2 # 'truecolor' = RGB + interlace = 0 # No + + IHDRdata = struct.pack(">ccccIIBBBBB", b'I', b'H', b'D', b'R', + width, height, depth, colorType, + 0, 0, interlace) + + # Add filter 'None' before each scanline + preparedData = b'\x00' + b'\x00'.join(line.tostring() for line in data) + compressedData = zlib.compress(preparedData, 8) + + IDATdata = struct.pack("cccc", b'I', b'D', b'A', b'T') + IDATdata += compressedData + + return b''.join([ + b'\x89PNG\r\n\x1a\n', # PNG signature + # IHDR chunk: Image Header + struct.pack(">I", 13), # length + IHDRdata, + struct.pack(">I", zlib.crc32(IHDRdata) & 0xffffffff), # CRC + # IDAT chunk: Payload + struct.pack(">I", len(compressedData)), + IDATdata, + struct.pack(">I", zlib.crc32(IDATdata) & 0xffffffff), # CRC + b'\x00\x00\x00\x00IEND\xaeB`\x82' # IEND chunk: footer + ]) + + +def saveImageToFile(data, fileNameOrObj, fileFormat): + """Save a RGB image to a file. + + :param data: A 3D array (h, w, 3) storing an RGB image. + :type data: numpy.ndarray with of unsigned bytes. + :param fileNameOrObj: Filename or object to use to write the image. + :type fileNameOrObj: A str or a 'file-like' object with a 'write' method. + :param str fileFormat: The type of the file in: 'png', 'ppm', 'svg', 'tiff'. + """ + assert len(data.shape) == 3 + assert data.shape[2] == 3 + assert fileFormat in ('png', 'ppm', 'svg', 'tiff') + + if not hasattr(fileNameOrObj, 'write'): + if sys.version < "3.0": + fileObj = open(fileNameOrObj, "wb") + else: + fileObj = open(fileNameOrObj, "w", newline='') + else: # Use as a file-like object + fileObj = fileNameOrObj + + if fileFormat == 'svg': + height, width = data.shape[:2] + base64Data = base64.b64encode(convertRGBDataToPNG(data)) + + fileObj.write( + '<?xml version="1.0" encoding="UTF-8" standalone="no"?>\n') + fileObj.write('<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"\n') + fileObj.write( + ' "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">\n') + fileObj.write('<svg xmlns:xlink="http://www.w3.org/1999/xlink"\n') + fileObj.write(' xmlns="http://www.w3.org/2000/svg"\n') + fileObj.write(' version="1.1"\n') + fileObj.write(' width="%d"\n' % width) + fileObj.write(' height="%d">\n' % height) + fileObj.write(' <image xlink:href="data:image/png;base64,') + fileObj.write(base64Data.decode('ascii')) + fileObj.write('"\n') + fileObj.write(' x="0"\n') + fileObj.write(' y="0"\n') + fileObj.write(' width="%d"\n' % width) + fileObj.write(' height="%d"\n' % height) + fileObj.write(' id="image" />\n') + fileObj.write('</svg>') + + elif fileFormat == 'ppm': + height, width = data.shape[:2] + + fileObj.write('P6\n') + fileObj.write('%d %d\n' % (width, height)) + fileObj.write('255\n') + fileObj.write(data.tostring()) + + elif fileFormat == 'png': + fileObj.write(convertRGBDataToPNG(data)) + + elif fileFormat == 'tiff': + if fileObj == fileNameOrObj: + raise NotImplementedError( + 'Save TIFF to a file-like object not implemented') + + from silx.third_party.TiffIO import TiffIO + + tif = TiffIO(fileNameOrObj, mode='wb+') + tif.writeImage(data, info={'Title': 'PyMCA GL Snapshot'}) + + if fileObj != fileNameOrObj: + fileObj.close() diff --git a/silx/gui/plot/backends/glutils/__init__.py b/silx/gui/plot/backends/glutils/__init__.py new file mode 100644 index 0000000..771de39 --- /dev/null +++ b/silx/gui/plot/backends/glutils/__init__.py @@ -0,0 +1,44 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +"""This module provides convenient classes for the OpenGL rendering backend. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "03/04/2017" + + +import logging + + +_logger = logging.getLogger(__name__) + + +from .GLPlotCurve import * # noqa +from .GLPlotFrame import * # noqa +from .GLPlotImage import * # noqa +from .GLSupport import * # noqa +from .GLText import * # noqa +from .GLTexture import * # noqa diff --git a/silx/gui/plot/items/__init__.py b/silx/gui/plot/items/__init__.py new file mode 100644 index 0000000..b16fe40 --- /dev/null +++ b/silx/gui/plot/items/__init__.py @@ -0,0 +1,43 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This package provides classes that describes :class:`.Plot` content. + +Instances of those classes are returned by :class:`.Plot` methods that give +access to its content such as :meth:`.Plot.getCurve`, :meth:`.Plot.getImage`. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "06/03/2017" + +from .core import (Item, LabelsMixIn, DraggableMixIn, ColormapMixIn, # noqa + SymbolMixIn, ColorMixIn, YAxisMixIn, FillMixIn, # noqa + AlphaMixIn, LineMixIn) # noqa +from .curve import Curve # noqa +from .histogram import Histogram # noqa +from .image import ImageBase, ImageData, ImageRgba # noqa +from .shape import Shape # noqa +from .scatter import Scatter # noqa +from .marker import Marker, XMarker, YMarker # noqa diff --git a/silx/gui/plot/items/core.py b/silx/gui/plot/items/core.py new file mode 100644 index 0000000..72bfd9a --- /dev/null +++ b/silx/gui/plot/items/core.py @@ -0,0 +1,839 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides the base class for items of the :class:`Plot`. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "26/04/2017" + +from copy import deepcopy +import logging +import weakref +import numpy +from silx.third_party import six + +from .. import Colors + + + +_logger = logging.getLogger(__name__) + + +class Item(object): + """Description of an item of the plot""" + + _DEFAULT_Z_LAYER = 0 + """Default layer for overlay rendering""" + + _DEFAULT_LEGEND = '' + """Default legend of items""" + + _DEFAULT_SELECTABLE = False + """Default selectable state of items""" + + def __init__(self): + self._dirty = True + self._plotRef = None + self._visible = True + self._legend = self._DEFAULT_LEGEND + self._selectable = self._DEFAULT_SELECTABLE + self._z = self._DEFAULT_Z_LAYER + self._info = None + self._xlabel = None + self._ylabel = None + + self._backendRenderer = None + + def getPlot(self): + """Returns Plot this item belongs to. + + :rtype: Plot or None + """ + return None if self._plotRef is None else self._plotRef() + + def _setPlot(self, plot): + """Set the plot this item belongs to. + + WARNING: This should only be called from the Plot. + + :param Plot plot: The Plot instance. + """ + if plot is not None and self._plotRef is not None: + raise RuntimeError('Trying to add a node at two places.') + self._plotRef = None if plot is None else weakref.ref(plot) + self._updated() + + def getBounds(self): # TODO return a Bounds object rather than a tuple + """Returns the bounding box of this item in data coordinates + + :returns: (xmin, xmax, ymin, ymax) or None + :rtype: 4-tuple of float or None + """ + return self._getBounds() + + def _getBounds(self): + """:meth:`getBounds` implementation to override by sub-class""" + return None + + def isVisible(self): + """True if item is visible, False otherwise + + :rtype: bool + """ + return self._visible + + def setVisible(self, visible): + """Set visibility of item. + + :param bool visible: True to display it, False otherwise + """ + visible = bool(visible) + if visible != self._visible: + self._visible = visible + # When visibility has changed, always mark as dirty + self._updated(checkVisibility=False) + + def isOverlay(self): + """Return true if item is drawn as an overlay. + + :rtype: bool + """ + return False + + def getLegend(self): + """Returns the legend of this item (str)""" + return self._legend + + def _setLegend(self, legend): + """Set the legend. + + This is private as it is used by the plot as an identifier + + :param str legend: Item legend + """ + legend = str(legend) if legend is not None else self._DEFAULT_LEGEND + self._legend = legend + + def isSelectable(self): + """Returns true if item is selectable (bool)""" + return self._selectable + + def _setSelectable(self, selectable): # TODO support update + """Set whether item is selectable or not. + + This is private for now as change is not handled. + + :param bool selectable: True to make item selectable + """ + self._selectable = bool(selectable) + + def getZValue(self): + """Returns the layer on which to draw this item (int)""" + return self._z + + def setZValue(self, z): + z = int(z) if z is not None else self._DEFAULT_Z_LAYER + if z != self._z: + self._z = z + self._updated() + + def getInfo(self, copy=True): + """Returns the info associated to this item + + :param bool copy: True to get a deepcopy, False otherwise. + """ + return deepcopy(self._info) if copy else self._info + + def setInfo(self, info, copy=True): + if copy: + info = deepcopy(info) + self._info = info + + def _updated(self, checkVisibility=True): + """Mark the item as dirty (i.e., needing update). + + This also triggers Plot.replot. + + :param bool checkVisibility: True to only mark as dirty if visible, + False to always mark as dirty. + """ + if not checkVisibility or self.isVisible(): + if not self._dirty: + self._dirty = True + # TODO: send event instead of explicit call + plot = self.getPlot() + if plot is not None: + plot._itemRequiresUpdate(self) + + def _update(self, backend): + """Called by Plot to update the backend for this item. + + This is meant to be called asynchronously from _updated. + This optimizes the number of call to _update. + + :param backend: The backend to update + """ + if self._dirty: + # Remove previous renderer from backend if any + self._removeBackendRenderer(backend) + + # If not visible, do not add renderer to backend + if self.isVisible(): + self._backendRenderer = self._addBackendRenderer(backend) + + self._dirty = False + + def _addBackendRenderer(self, backend): + """Override in subclass to add specific backend renderer. + + :param BackendBase backend: The backend to update + :return: The renderer handle to store or None if no renderer in backend + """ + return None + + def _removeBackendRenderer(self, backend): + """Override in subclass to remove specific backend renderer. + + :param BackendBase backend: The backend to update + """ + if self._backendRenderer is not None: + backend.remove(self._backendRenderer) + self._backendRenderer = None + + +# Mix-in classes ############################################################## + +class LabelsMixIn(object): + """Mix-in class for items with x and y labels + + Setters are private, otherwise it needs to check the plot + current active curve and access the internal current labels. + """ + + def __init__(self): + self._xlabel = None + self._ylabel = None + + def getXLabel(self): + """Return the X axis label associated to this curve + + :rtype: str or None + """ + return self._xlabel + + def _setXLabel(self, label): + """Set the X axis label associated with this curve + + :param str label: The X axis label + """ + self._xlabel = str(label) + + def getYLabel(self): + """Return the Y axis label associated to this curve + + :rtype: str or None + """ + return self._ylabel + + def _setYLabel(self, label): + """Set the Y axis label associated with this curve + + :param str label: The Y axis label + """ + self._ylabel = str(label) + + +class DraggableMixIn(object): + """Mix-in class for draggable items""" + + def __init__(self): + self._draggable = False + + def isDraggable(self): + """Returns true if image is draggable + + :rtype: bool + """ + return self._draggable + + def _setDraggable(self, draggable): # TODO support update + """Set if image is draggable or not. + + This is private for not as it does not support update. + + :param bool draggable: + """ + self._draggable = bool(draggable) + + +class ColormapMixIn(object): + """Mix-in class for items with colormap""" + + _DEFAULT_COLORMAP = {'name': 'gray', 'normalization': 'linear', + 'autoscale': True, 'vmin': 0.0, 'vmax': 1.0} + """Default colormap of the item""" + + def __init__(self): + self._colormap = self._DEFAULT_COLORMAP + + def getColormap(self): + """Return the used colormap""" + return self._colormap.copy() + + def setColormap(self, colormap): + """Set the colormap of this image + + :param dict colormap: colormap description + """ + self._colormap = colormap.copy() + # TODO colormap comparison + colormap object and events on modification + self._updated() + + +class SymbolMixIn(object): + """Mix-in class for items with symbol type""" + + _DEFAULT_SYMBOL = '' + """Default marker of the item""" + + _DEFAULT_SYMBOL_SIZE = 6.0 + """Default marker size of the item""" + + def __init__(self): + self._symbol = self._DEFAULT_SYMBOL + self._symbol_size = self._DEFAULT_SYMBOL_SIZE + + def getSymbol(self): + """Return the point marker type. + + Marker type:: + + - 'o' circle + - '.' point + - ',' pixel + - '+' cross + - 'x' x-cross + - 'd' diamond + - 's' square + + :rtype: str + """ + return self._symbol + + def setSymbol(self, symbol): + """Set the marker type + + See :meth:`getSymbol`. + + :param str symbol: Marker type + """ + assert symbol in ('o', '.', ',', '+', 'x', 'd', 's', '', None) + if symbol is None: + symbol = self._DEFAULT_SYMBOL + if symbol != self._symbol: + self._symbol = symbol + self._updated() + + def getSymbolSize(self): + """Return the point marker size in points. + + :rtype: float + """ + return self._symbol_size + + def setSymbolSize(self, size): + """Set the point marker size in points. + + See :meth:`getSymbolSize`. + + :param str symbol: Marker type + """ + if size is None: + size = self._DEFAULT_SYMBOL_SIZE + if size != self._symbol_size: + self._symbol_size = size + self._updated() + + +class LineMixIn(object): + """Mix-in class for item with line""" + + _DEFAULT_LINEWIDTH = 1. + """Default line width""" + + _DEFAULT_LINESTYLE = '-' + """Default line style""" + + def __init__(self): + self._linewidth = self._DEFAULT_LINEWIDTH + self._linestyle = self._DEFAULT_LINESTYLE + + def getLineWidth(self): + """Return the curve line width in pixels (int)""" + return self._linewidth + + def setLineWidth(self, width): + """Set the width in pixel of the curve line + + See :meth:`getLineWidth`. + + :param float width: Width in pixels + """ + width = float(width) + if width != self._linewidth: + self._linewidth = width + self._updated() + + def getLineStyle(self): + """Return the type of the line + + Type of line:: + + - ' ' no line + - '-' solid line + - '--' dashed line + - '-.' dash-dot line + - ':' dotted line + + :rtype: str + """ + return self._linestyle + + def setLineStyle(self, style): + """Set the style of the curve line. + + See :meth:`getLineStyle`. + + :param str style: Line style + """ + style = str(style) + assert style in ('', ' ', '-', '--', '-.', ':', None) + if style is None: + style = self._DEFAULT_LINESTYLE + if style != self._linestyle: + self._linestyle = style + self._updated() + + +class ColorMixIn(object): + """Mix-in class for item with color""" + + _DEFAULT_COLOR = (0., 0., 0., 1.) + """Default color of the item""" + + def __init__(self): + self._color = self._DEFAULT_COLOR + + def getColor(self): + """Returns the RGBA color of the item + + :rtype: 4-tuple of float in [0, 1] + """ + return self._color + + def setColor(self, color, copy=True): + """Set item color + + :param color: color(s) to be used + :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or + one of the predefined color names defined in Colors.py + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + """ + if isinstance(color, six.string_types): + color = Colors.rgba(color) + else: + color = numpy.array(color, copy=copy) + # TODO more checks + improve color array support + if color.ndim == 1: # Single RGBA color + color = Colors.rgba(color) + else: # Array of colors + assert color.ndim == 2 + + self._color = color + self._updated() + + +class YAxisMixIn(object): + """Mix-in class for item with yaxis""" + + _DEFAULT_YAXIS = 'left' + """Default Y axis the item belongs to""" + + def __init__(self): + self._yaxis = self._DEFAULT_YAXIS + + def getYAxis(self): + """Returns the Y axis this curve belongs to. + + Either 'left' or 'right'. + + :rtype: str + """ + return self._yaxis + + def setYAxis(self, yaxis): + """Set the Y axis this curve belongs to. + + :param str yaxis: 'left' or 'right' + """ + yaxis = str(yaxis) + assert yaxis in ('left', 'right') + if yaxis != self._yaxis: + self._yaxis = yaxis + self._updated() + + +class FillMixIn(object): + """Mix-in class for item with fill""" + + def __init__(self): + self._fill = False + + def isFill(self): + """Returns whether the item is filled or not. + + :rtype: bool + """ + return self._fill + + def setFill(self, fill): + """Set whether to fill the item or not. + + :param bool fill: + """ + fill = bool(fill) + if fill != self._fill: + self._fill = fill + self._updated() + + +class AlphaMixIn(object): + """Mix-in class for item with opacity""" + + def __init__(self): + self._alpha = 1. + + def getAlpha(self): + """Returns the opacity of the item + + :rtype: float in [0, 1.] + """ + return self._alpha + + def setAlpha(self, alpha): + """Set the opacity of the item + + .. note:: + + If the colormap already has some transparency, this alpha + adds additional transparency. The alpha channel of the colormap + is multiplied by this value. + + :param alpha: Opacity of the item, between 0 (full transparency) + and 1. (full opacity) + :type alpha: float + """ + alpha = float(alpha) + alpha = max(0., min(alpha, 1.)) # Clip alpha to [0., 1.] range + if alpha != self._alpha: + self._alpha = alpha + self._updated() + + +class Points(Item, SymbolMixIn, AlphaMixIn): + """Base class for :class:`Curve` and :class:`Scatter`""" + # note: _logFilterData must be overloaded if you overload + # getData to change its signature + + _DEFAULT_Z_LAYER = 1 + """Default overlay layer for points, + on top of images.""" + + def __init__(self): + Item.__init__(self) + SymbolMixIn.__init__(self) + AlphaMixIn.__init__(self) + self._x = () + self._y = () + self._xerror = None + self._yerror = None + + # Store filtered data for x > 0 and/or y > 0 + self._filteredCache = {} + self._clippedCache = {} + + # Store bounds depending on axes filtering >0: + # key is (isXPositiveFilter, isYPositiveFilter) + self._boundsCache = {} + + @staticmethod + def _logFilterError(value, error): + """Filter/convert error values if they go <= 0. + + Replace error leading to negative values by nan + + :param numpy.ndarray value: 1D array of values + :param numpy.ndarray error: + Array of errors: scalar, N, Nx1 or 2xN or None. + :return: Filtered error so error bars are never negative + """ + if error is not None: + # Convert Nx1 to N + if error.ndim == 2 and error.shape[1] == 1 and len(value) != 1: + error = numpy.ravel(error) + + # Supports error being scalar, N or 2xN array + errorClipped = (value - numpy.atleast_2d(error)[0]) <= 0 + + if numpy.any(errorClipped): # Need filtering + + # expand errorbars to 2xN + if error.size == 1: # Scalar + error = numpy.full( + (2, len(value)), error, dtype=numpy.float) + + elif error.ndim == 1: # N array + newError = numpy.empty((2, len(value)), + dtype=numpy.float) + newError[0, :] = error + newError[1, :] = error + error = newError + + elif error.size == 2 * len(value): # 2xN array + error = numpy.array( + error, copy=True, dtype=numpy.float) + + else: + _logger.error("Unhandled error array") + return error + + error[0, errorClipped] = numpy.nan + + return error + + def _getClippingBoolArray(self, xPositive, yPositive): + """Compute a boolean array to filter out points with negative + coordinates on log axes. + + :param bool xPositive: True to filter arrays according to X coords. + :param bool yPositive: True to filter arrays according to Y coords. + :rtype: boolean numpy.ndarray + """ + assert xPositive or yPositive + if (xPositive, yPositive) not in self._clippedCache: + x = self.getXData(copy=False) + y = self.getYData(copy=False) + xclipped = (x <= 0) if xPositive else False + yclipped = (y <= 0) if yPositive else False + self._clippedCache[(xPositive, yPositive)] = \ + numpy.logical_or(xclipped, yclipped) + return self._clippedCache[(xPositive, yPositive)] + + def _logFilterData(self, xPositive, yPositive): + """Filter out values with x or y <= 0 on log axes + + :param bool xPositive: True to filter arrays according to X coords. + :param bool yPositive: True to filter arrays according to Y coords. + :return: The filter arrays or unchanged object if filtering not needed + :rtype: (x, y, xerror, yerror) + """ + x = self.getXData(copy=False) + y = self.getYData(copy=False) + xerror = self.getXErrorData(copy=False) + yerror = self.getYErrorData(copy=False) + + if xPositive or yPositive: + clipped = self._getClippingBoolArray(xPositive, yPositive) + + if numpy.any(clipped): + # copy to keep original array and convert to float + x = numpy.array(x, copy=True, dtype=numpy.float) + x[clipped] = numpy.nan + y = numpy.array(y, copy=True, dtype=numpy.float) + y[clipped] = numpy.nan + + if xPositive and xerror is not None: + xerror = self._logFilterError(x, xerror) + + if yPositive and yerror is not None: + yerror = self._logFilterError(y, yerror) + + return x, y, xerror, yerror + + def _getBounds(self): + if self.getXData(copy=False).size == 0: # Empty data + return None + + plot = self.getPlot() + if plot is not None: + xPositive = plot.isXAxisLogarithmic() + yPositive = plot.isYAxisLogarithmic() + else: + xPositive = False + yPositive = False + + # TODO bounds do not take error bars into account + if (xPositive, yPositive) not in self._boundsCache: + # use the getData class method because instance method can be + # overloaded to return additional arrays + data = Points.getData(self, copy=False, + displayed=True) + if len(data) == 5: + # hack to avoid duplicating caching mechanism in Scatter + # (happens when cached data is used, caching done using + # Scatter._logFilterData) + x, y, xerror, yerror = data[0], data[1], data[3], data[4] + else: + x, y, xerror, yerror = data + + self._boundsCache[(xPositive, yPositive)] = ( + numpy.nanmin(x), + numpy.nanmax(x), + numpy.nanmin(y), + numpy.nanmax(y) + ) + return self._boundsCache[(xPositive, yPositive)] + + def _getCachedData(self): + """Return cached filtered data if applicable, + i.e. if any axis is in log scale. + Return None if caching is not applicable.""" + plot = self.getPlot() + if plot is not None: + xPositive = plot.isXAxisLogarithmic() + yPositive = plot.isYAxisLogarithmic() + if xPositive or yPositive: + # At least one axis has log scale, filter data + if (xPositive, yPositive) not in self._filteredCache: + self._filteredCache[(xPositive, yPositive)] = \ + self._logFilterData(xPositive, yPositive) + return self._filteredCache[(xPositive, yPositive)] + return None + + def getData(self, copy=True, displayed=False): + """Returns the x, y values of the curve points and xerror, yerror + + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :param bool displayed: True to only get curve points that are displayed + in the plot. Default: False + Note: If plot has log scale, negative points + are not displayed. + :returns: (x, y, xerror, yerror) + :rtype: 4-tuple of numpy.ndarray + """ + if displayed: # filter data according to plot state + cached_data = self._getCachedData() + if cached_data is not None: + return cached_data + + return (self.getXData(copy), + self.getYData(copy), + self.getXErrorData(copy), + self.getYErrorData(copy)) + + def getXData(self, copy=True): + """Returns the x coordinates of the data points + + :param copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :rtype: numpy.ndarray + """ + return numpy.array(self._x, copy=copy) + + def getYData(self, copy=True): + """Returns the y coordinates of the data points + + :param copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :rtype: numpy.ndarray + """ + return numpy.array(self._y, copy=copy) + + def getXErrorData(self, copy=True): + """Returns the x error of the points + + :param copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :rtype: numpy.ndarray or None + """ + if self._xerror is None: + return None + else: + return numpy.array(self._xerror, copy=copy) + + def getYErrorData(self, copy=True): + """Returns the y error of the points + + :param copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :rtype: numpy.ndarray or None + """ + if self._yerror is None: + return None + else: + return numpy.array(self._yerror, copy=copy) + + def setData(self, x, y, xerror=None, yerror=None, copy=True): + """Set the data of the curve. + + :param numpy.ndarray x: The data corresponding to the x coordinates. + :param numpy.ndarray y: The data corresponding to the y coordinates. + :param xerror: Values with the uncertainties on the x values + :type xerror: A float, or a numpy.ndarray of float32. + If it is an array, it can either be a 1D array of + same length as the data or a 2D array with 2 rows + of same length as the data: row 0 for positive errors, + row 1 for negative errors. + :param yerror: Values with the uncertainties on the y values. + :type yerror: A float, or a numpy.ndarray of float32. See xerror. + :param bool copy: True make a copy of the data (default), + False to use provided arrays. + """ + x = numpy.array(x, copy=copy) + y = numpy.array(y, copy=copy) + assert len(x) == len(y) + assert x.ndim == y.ndim == 1 + + if xerror is not None: + xerror = numpy.array(xerror, copy=copy) + if yerror is not None: + yerror = numpy.array(yerror, copy=copy) + # TODO checks on xerror, yerror + self._x, self._y = x, y + self._xerror, self._yerror = xerror, yerror + + self._boundsCache = {} # Reset cached bounds + self._filteredCache = {} # Reset cached filtered data + self._clippedCache = {} # Reset cached clipped bool array + + self._updated() + # TODO hackish data range implementation + if self.isVisible(): + plot = self.getPlot() + if plot is not None: + plot._invalidateDataRange() diff --git a/silx/gui/plot/items/curve.py b/silx/gui/plot/items/curve.py new file mode 100644 index 0000000..d25ae00 --- /dev/null +++ b/silx/gui/plot/items/curve.py @@ -0,0 +1,192 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides the :class:`Curve` item of the :class:`Plot`. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "06/03/2017" + + +import logging + +import numpy + +from .. import Colors +from .core import (Points, LabelsMixIn, SymbolMixIn, + ColorMixIn, YAxisMixIn, FillMixIn, LineMixIn) + + +_logger = logging.getLogger(__name__) + + +class Curve(Points, ColorMixIn, YAxisMixIn, FillMixIn, LabelsMixIn, LineMixIn): + """Description of a curve""" + + _DEFAULT_Z_LAYER = 1 + """Default overlay layer for curves""" + + _DEFAULT_SELECTABLE = True + """Default selectable state for curves""" + + _DEFAULT_LINEWIDTH = 1. + """Default line width of the curve""" + + _DEFAULT_LINESTYLE = '-' + """Default line style of the curve""" + + _DEFAULT_HIGHLIGHT_COLOR = (0, 0, 0, 255) + """Default highlight color of the item""" + + def __init__(self): + Points.__init__(self) + ColorMixIn.__init__(self) + YAxisMixIn.__init__(self) + FillMixIn.__init__(self) + LabelsMixIn.__init__(self) + LineMixIn.__init__(self) + + self._highlightColor = self._DEFAULT_HIGHLIGHT_COLOR + self._highlighted = False + + def _addBackendRenderer(self, backend): + """Update backend renderer""" + # Filter-out values <= 0 + xFiltered, yFiltered, xerror, yerror = self.getData( + copy=False, displayed=True) + + if len(xFiltered) == 0: + return None # No data to display, do not add renderer to backend + + return backend.addCurve(xFiltered, yFiltered, self.getLegend(), + color=self.getCurrentColor(), + symbol=self.getSymbol(), + linestyle=self.getLineStyle(), + linewidth=self.getLineWidth(), + yaxis=self.getYAxis(), + xerror=xerror, + yerror=yerror, + z=self.getZValue(), + selectable=self.isSelectable(), + fill=self.isFill(), + alpha=self.getAlpha(), + symbolsize=self.getSymbolSize()) + + def __getitem__(self, item): + """Compatibility with PyMca and silx <= 0.4.0""" + if isinstance(item, slice): + return [self[index] for index in range(*item.indices(5))] + elif item == 0: + return self.getXData(copy=False) + elif item == 1: + return self.getYData(copy=False) + elif item == 2: + return self.getLegend() + elif item == 3: + info = self.getInfo(copy=False) + return {} if info is None else info + elif item == 4: + params = { + 'info': self.getInfo(), + 'color': self.getColor(), + 'symbol': self.getSymbol(), + 'linewidth': self.getLineWidth(), + 'linestyle': self.getLineStyle(), + 'xlabel': self.getXLabel(), + 'ylabel': self.getYLabel(), + 'yaxis': self.getYAxis(), + 'xerror': self.getXErrorData(copy=False), + 'yerror': self.getYErrorData(copy=False), + 'z': self.getZValue(), + 'selectable': self.isSelectable(), + 'fill': self.isFill() + } + return params + else: + raise IndexError("Index out of range: %s", str(item)) + + def setVisible(self, visible): + """Set visibility of item. + + :param bool visible: True to display it, False otherwise + """ + visibleChanged = self.isVisible() != bool(visible) + super(Curve, self).setVisible(visible) + + # TODO hackish data range implementation + if visibleChanged: + plot = self.getPlot() + if plot is not None: + plot._invalidateDataRange() + + def isHighlighted(self): + """Returns True if curve is highlighted. + + :rtype: bool + """ + return self._highlighted + + def setHighlighted(self, highlighted): + """Set the highlight state of the curve + + :param bool highlighted: + """ + highlighted = bool(highlighted) + if highlighted != self._highlighted: + self._highlighted = highlighted + # TODO inefficient: better to use backend's setCurveColor + self._updated() + + def getHighlightedColor(self): + """Returns the RGBA highlight color of the item + + :rtype: 4-tuple of int in [0, 255] + """ + return self._highlightColor + + def setHighlightedColor(self, color): + """Set the color to use when highlighted + + :param color: color(s) to be used for highlight + :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or + one of the predefined color names defined in Colors.py + """ + color = Colors.rgba(color) + if color != self._highlightColor: + self._highlightColor = color + self._updated() + + def getCurrentColor(self): + """Returns the current color of the curve. + + This color is either the color of the curve or the highlighted color, + depending on the highlight state. + + :rtype: 4-tuple of int in [0, 255] + """ + if self.isHighlighted(): + return self.getHighlightedColor() + else: + return self.getColor() diff --git a/silx/gui/plot/items/histogram.py b/silx/gui/plot/items/histogram.py new file mode 100644 index 0000000..c3821bc --- /dev/null +++ b/silx/gui/plot/items/histogram.py @@ -0,0 +1,288 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides the :class:`Histogram` item of the :class:`Plot`. +""" + +__authors__ = ["H. Payno", "T. Vincent"] +__license__ = "MIT" +__date__ = "02/05/2017" + + +import logging + +import numpy + +from .core import (Item, AlphaMixIn, ColorMixIn, FillMixIn, + LineMixIn, YAxisMixIn) + + +_logger = logging.getLogger(__name__) + + +def _computeEdges(x, histogramType): + """Compute the edges from a set of xs and a rule to generate the edges + + :param x: the x value of the curve to transform into an histogram + :param histogramType: the type of histogram we wan't to generate. + This define the way to center the histogram values compared to the + curve value. Possible values can be:: + + - 'left' + - 'right' + - 'center' + + :return: the edges for the given x and the histogramType + """ + # for now we consider that the spaces between xs are constant + edges = x.copy() + if histogramType is 'left': + width = 1 + if len(x) > 1: + width = x[1] - x[0] + edges = numpy.append(x[0] - width, edges) + if histogramType is 'center': + edges = _computeEdges(edges, 'right') + widths = (edges[1:] - edges[0:-1]) / 2.0 + widths = numpy.append(widths, widths[-1]) + edges = edges - widths + if histogramType is 'right': + width = 1 + if len(x) > 1: + width = x[-1] - x[-2] + edges = numpy.append(edges, x[-1] + width) + + return edges + + +def _getHistogramCurve(histogram, edges): + """Returns the x and y value of a curve corresponding to the histogram + + :param numpy.ndarray histogram: The values of the histogram + :param numpy.ndarray edges: The bin edges of the histogram + :return: a tuple(x, y) which contains the value of the curve to use + to display the histogram + """ + assert len(histogram) + 1 == len(edges) + x = numpy.empty(len(histogram) * 2, dtype=edges.dtype) + y = numpy.empty(len(histogram) * 2, dtype=histogram.dtype) + # Make a curve with stairs + x[:-1:2] = edges[:-1] + x[1::2] = edges[1:] + y[:-1:2] = histogram + y[1::2] = histogram + + return x, y + + +# TODO: Yerror, test log scale +class Histogram(Item, AlphaMixIn, ColorMixIn, FillMixIn, + LineMixIn, YAxisMixIn): + """Description of an histogram""" + + _DEFAULT_Z_LAYER = 1 + """Default overlay layer for histograms""" + + _DEFAULT_SELECTABLE = False + """Default selectable state for histograms""" + + _DEFAULT_LINEWIDTH = 1. + """Default line width of the histogram""" + + _DEFAULT_LINESTYLE = '-' + """Default line style of the histogram""" + + def __init__(self): + Item.__init__(self) + AlphaMixIn.__init__(self) + ColorMixIn.__init__(self) + FillMixIn.__init__(self) + LineMixIn.__init__(self) + YAxisMixIn.__init__(self) + + self._histogram = () + self._edges = () + + def _addBackendRenderer(self, backend): + """Update backend renderer""" + values, edges = self.getData(copy=False) + + if values.size == 0: + return None # No data to display, do not add renderer + + if values.size == 0: + return None # No data to display, do not add renderer to backend + + x, y = _getHistogramCurve(values, edges) + + # Filter-out values <= 0 + plot = self.getPlot() + if plot is not None: + xPositive = plot.isXAxisLogarithmic() + yPositive = plot.isYAxisLogarithmic() + else: + xPositive = False + yPositive = False + + if xPositive or yPositive: + clipped = numpy.logical_or( + (x <= 0) if xPositive else False, + (y <= 0) if yPositive else False) + # Make a copy and replace negative points by NaN + x = numpy.array(x, dtype=numpy.float) + y = numpy.array(y, dtype=numpy.float) + x[clipped] = numpy.nan + y[clipped] = numpy.nan + + return backend.addCurve(x, y, self.getLegend(), + color=self.getColor(), + symbol='', + linestyle=self.getLineStyle(), + linewidth=self.getLineWidth(), + yaxis=self.getYAxis(), + xerror=None, + yerror=None, + z=self.getZValue(), + selectable=self.isSelectable(), + fill=self.isFill(), + alpha=self.getAlpha(), + symbolsize=1) + + def _getBounds(self): + values, edges = self.getData(copy=False) + + plot = self.getPlot() + if plot is not None: + xPositive = plot.isXAxisLogarithmic() + yPositive = plot.isYAxisLogarithmic() + else: + xPositive = False + yPositive = False + + if xPositive or yPositive: + values = numpy.array(values, copy=True, dtype=numpy.float) + + if xPositive: + # Replace edges <= 0 by NaN and corresponding values by NaN + clipped = (edges <= 0) + edges = numpy.array(edges, copy=True, dtype=numpy.float) + edges[clipped] = numpy.nan + values[numpy.logical_or(clipped[:-1], clipped[1:])] = numpy.nan + + if yPositive: + # Replace values <= 0 by NaN, do not modify edges + values[values <= 0] = numpy.nan + + if xPositive or yPositive: + return (numpy.nanmin(edges), + numpy.nanmax(edges), + numpy.nanmin(values), + numpy.nanmax(values)) + + else: # No log scale, include 0 in bounds + return (numpy.nanmin(edges), + numpy.nanmax(edges), + min(0, numpy.nanmin(values)), + max(0, numpy.nanmax(values))) + + def setVisible(self, visible): + """Set visibility of item. + + :param bool visible: True to display it, False otherwise + """ + visibleChanged = self.isVisible() != bool(visible) + super(Histogram, self).setVisible(visible) + + # TODO hackish data range implementation + if visibleChanged: + plot = self.getPlot() + if plot is not None: + plot._invalidateDataRange() + + def getValueData(self, copy=True): + """The values of the histogram + + :param copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :returns: The bin edges of the histogram + :rtype: numpy.ndarray + """ + return numpy.array(self._histogram, copy=copy) + + def getBinEdgesData(self, copy=True): + """The bin edges of the histogram (number of histogram values + 1) + + :param copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :returns: The bin edges of the histogram + :rtype: numpy.ndarray + """ + return numpy.array(self._edges, copy=copy) + + def getData(self, copy=True): + """Return the histogram values and the bin edges + + :param copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :returns: (N histogram value, N+1 bin edges) + :rtype: 2-tuple of numpy.nadarray + """ + return (self.getValueData(copy), self.getBinEdgesData(copy)) + + def setData(self, histogram, edges, align='center', copy=True): + """Set the histogram values and bin edges. + + :param numpy.ndarray histogram: The values of the histogram. + :param numpy.ndarray edges: + The bin edges of the histogram. + If histogram and edges have the same length, the bin edges + are computed according to the align parameter. + :param str align: + In case histogram values and edges have the same length N, + the N+1 bin edges are computed according to the alignment in: + 'center' (default), 'left', 'right'. + :param bool copy: True make a copy of the data (default), + False to use provided arrays. + """ + histogram = numpy.array(histogram, copy=copy) + edges = numpy.array(edges, copy=copy) + + assert histogram.ndim == 1 + assert edges.ndim == 1 + assert edges.size in (histogram.size, histogram.size + 1) + assert align in ('center', 'left', 'right') + + if histogram.size == 0: # No data + self._histogram = () + self._edges = () + else: + if edges.size == histogram.size: # Compute true bin edges + edges = _computeEdges(edges, align) + + # Check that bin edges are monotonic + edgesDiff = numpy.diff(edges) + assert numpy.all(edgesDiff >= 0) or numpy.all(edgesDiff <= 0) + + self._histogram = histogram + self._edges = edges diff --git a/silx/gui/plot/items/image.py b/silx/gui/plot/items/image.py new file mode 100644 index 0000000..7e1dd8b --- /dev/null +++ b/silx/gui/plot/items/image.py @@ -0,0 +1,385 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides the :class:`ImageData` and :class:`ImageRgba` items +of the :class:`Plot`. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "06/03/2017" + + +from collections import Sequence +import logging + +import numpy + +from .core import Item, LabelsMixIn, DraggableMixIn, ColormapMixIn, AlphaMixIn +from ..Colors import applyColormapToData + + +_logger = logging.getLogger(__name__) + + +def _convertImageToRgba32(image, copy=True): + """Convert an RGB or RGBA image to RGBA32. + + It converts from floats in [0, 1], bool, integer and uint in [0, 255] + + If the input image is already an RGBA32 image, + the returned image shares the same data. + + :param image: Image to convert to + :type image: numpy.ndarray with 3 dimensions: height, width, color channels + :param bool copy: True (Default) to get a copy, False, avoid copy if possible + :return: The image converted to RGBA32 with dimension: (height, width, 4) + :rtype: numpy.ndarray of uint8 + """ + assert image.ndim == 3 + assert image.shape[-1] in (3, 4) + + # Convert type to uint8 + if image.dtype.name != 'uin8': + if image.dtype.kind == 'f': # Float in [0, 1] + image = (numpy.clip(image, 0., 1.) * 255).astype(numpy.uint8) + elif image.dtype.kind == 'b': # boolean + image = image.astype(numpy.uint8) * 255 + elif image.dtype.kind in ('i', 'u'): # int, uint + image = numpy.clip(image, 0, 255).astype(numpy.uint8) + else: + raise ValueError('Unsupported image dtype: %s', image.dtype.name) + copy = False # A copy as already been done, avoid next one + + # Convert RGB to RGBA + if image.shape[-1] == 3: + new_image = numpy.empty((image.shape[0], image.shape[1], 4), + dtype=numpy.uint8) + new_image[:, :, :3] = image + new_image[:, :, 3] = 255 + return new_image # This is a copy anyway + else: + return numpy.array(image, copy=copy) + + +class ImageBase(Item, LabelsMixIn, DraggableMixIn, AlphaMixIn): + """Description of an image""" + + def __init__(self): + Item.__init__(self) + LabelsMixIn.__init__(self) + DraggableMixIn.__init__(self) + AlphaMixIn.__init__(self) + self._data = numpy.zeros((0, 0, 4), dtype=numpy.uint8) + + self._origin = (0., 0.) + self._scale = (1., 1.) + + def __getitem__(self, item): + """Compatibility with PyMca and silx <= 0.4.0""" + if isinstance(item, slice): + return [self[index] for index in range(*item.indices(5))] + elif item == 0: + return self.getData(copy=False) + elif item == 1: + return self.getLegend() + elif item == 2: + info = self.getInfo(copy=False) + return {} if info is None else info + elif item == 3: + return None + elif item == 4: + params = { + 'info': self.getInfo(), + 'origin': self.getOrigin(), + 'scale': self.getScale(), + 'z': self.getZValue(), + 'selectable': self.isSelectable(), + 'draggable': self.isDraggable(), + 'colormap': None, + 'xlabel': self.getXLabel(), + 'ylabel': self.getYLabel(), + } + return params + else: + raise IndexError("Index out of range: %s" % str(item)) + + def setVisible(self, visible): + """Set visibility of item. + + :param bool visible: True to display it, False otherwise + """ + visibleChanged = self.isVisible() != bool(visible) + super(ImageBase, self).setVisible(visible) + + # TODO hackish data range implementation + if visibleChanged: + plot = self.getPlot() + if plot is not None: + plot._invalidateDataRange() + + def _getBounds(self): + if self.getData(copy=False).size == 0: # Empty data + return None + + height, width = self.getData(copy=False).shape[:2] + origin = self.getOrigin() + scale = self.getScale() + # Taking care of scale might be < 0 + xmin, xmax = origin[0], origin[0] + width * scale[0] + if xmin > xmax: + xmin, xmax = xmax, xmin + # Taking care of scale might be < 0 + ymin, ymax = origin[1], origin[1] + height * scale[1] + if ymin > ymax: + ymin, ymax = ymax, ymin + + plot = self.getPlot() + if (plot is not None and + plot.isXAxisLogarithmic() or plot.isYAxisLogarithmic()): + return None + else: + return xmin, xmax, ymin, ymax + + def getData(self, copy=True): + """Returns the image data + + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :rtype: numpy.ndarray + """ + return numpy.array(self._data, copy=copy) + + def getRgbaImageData(self, copy=True): + """Get the displayed RGB(A) image + + :returns: numpy.ndarray of uint8 of shape (height, width, 4) + """ + raise NotImplementedError('This MUST be implemented in sub-class') + + def getOrigin(self): + """Returns the offset from origin at which to display the image. + + :rtype: 2-tuple of float + """ + return self._origin + + def setOrigin(self, origin): + """Set the offset from origin at which to display the image. + + :param origin: (ox, oy) Offset from origin + :type origin: float or 2-tuple of float + """ + if isinstance(origin, Sequence): + origin = float(origin[0]), float(origin[1]) + else: # single value origin + origin = float(origin), float(origin) + if origin != self._origin: + self._origin = origin + self._updated() + + # TODO hackish data range implementation + if self.isVisible(): + plot = self.getPlot() + if plot is not None: + plot._invalidateDataRange() + + def getScale(self): + """Returns the scale of the image in data coordinates. + + :rtype: 2-tuple of float + """ + return self._scale + + def setScale(self, scale): + """Set the scale of the image + + :param scale: (sx, sy) Scale of the image + :type scale: float or 2-tuple of float + """ + if isinstance(scale, Sequence): + scale = float(scale[0]), float(scale[1]) + else: # single value scale + scale = float(scale), float(scale) + if scale != self._scale: + self._scale = scale + self._updated() + + +class ImageData(ImageBase, ColormapMixIn): + """Description of a data image with a colormap""" + + def __init__(self): + ImageBase.__init__(self) + ColormapMixIn.__init__(self) + self._data = numpy.zeros((0, 0), dtype=numpy.float32) + self._alternativeImage = None + + def _addBackendRenderer(self, backend): + """Update backend renderer""" + plot = self.getPlot() + assert plot is not None + if plot.isXAxisLogarithmic() or plot.isYAxisLogarithmic(): + return None # Do not render with log scales + + if self.getAlternativeImageData(copy=False) is not None: + dataToUse = self.getAlternativeImageData(copy=False) + else: + dataToUse = self.getData(copy=False) + + if dataToUse.size == 0: + return None # No data to display + + return backend.addImage(dataToUse, + legend=self.getLegend(), + origin=self.getOrigin(), + scale=self.getScale(), + z=self.getZValue(), + selectable=self.isSelectable(), + draggable=self.isDraggable(), + colormap=self.getColormap(), + alpha=self.getAlpha()) + + def __getitem__(self, item): + """Compatibility with PyMca and silx <= 0.4.0""" + if item == 3: + return self.getAlternativeImageData(copy=False) + + params = ImageBase.__getitem__(self, item) + if item == 4: + params['colormap'] = self.getColormap() + + return params + + def getRgbaImageData(self, copy=True): + """Get the displayed RGB(A) image + + :returns: numpy.ndarray of uint8 of shape (height, width, 4) + """ + if self._alternativeImage is not None: + return _convertImageToRgba32( + self.getAlternativeImageData(copy=False), copy=copy) + else: + # Apply colormap, in this case an new array is always returned + colormap = self.getColormap() + image = applyColormapToData(self.getData(copy=False), + **colormap) + return image + + def getAlternativeImageData(self, copy=True): + """Get the optional RGBA image that is displayed instead of the data + + :param copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :returns: None or numpy.ndarray + :rtype: numpy.ndarray or None + """ + if self._alternativeImage is None: + return None + else: + return numpy.array(self._alternativeImage, copy=copy) + + def setData(self, data, alternative=None, copy=True): + """"Set the image data and optionally an alternative RGB(A) representation + + :param numpy.ndarray data: Data array with 2 dimensions (h, w) + :param alternative: RGB(A) image to display instead of data, + shape: (h, w, 3 or 4) + :type alternative: None or numpy.ndarray + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + """ + data = numpy.array(data, copy=copy) + assert data.ndim == 2 + self._data = data + + if alternative is not None: + alternative = numpy.array(alternative, copy=copy) + assert alternative.ndim == 3 + assert alternative.shape[2] in (3, 4) + assert alternative.shape[:2] == data.shape[:2] + self._alternativeImage = alternative + self._updated() + + # TODO hackish data range implementation + if self.isVisible(): + plot = self.getPlot() + if plot is not None: + plot._invalidateDataRange() + + +class ImageRgba(ImageBase): + """Description of an RGB(A) image""" + + def __init__(self): + ImageBase.__init__(self) + + def _addBackendRenderer(self, backend): + """Update backend renderer""" + plot = self.getPlot() + assert plot is not None + if plot.isXAxisLogarithmic() or plot.isYAxisLogarithmic(): + return None # Do not render with log scales + + data = self.getData(copy=False) + + if data.size == 0: + return None # No data to display + + return backend.addImage(data, + legend=self.getLegend(), + origin=self.getOrigin(), + scale=self.getScale(), + z=self.getZValue(), + selectable=self.isSelectable(), + draggable=self.isDraggable(), + colormap=None, + alpha=self.getAlpha()) + + def getRgbaImageData(self, copy=True): + """Get the displayed RGB(A) image + + :returns: numpy.ndarray of uint8 of shape (height, width, 4) + """ + return _convertImageToRgba32(self.getData(copy=False), copy=copy) + + def setData(self, data, copy=True): + """Set the image data + + :param data: RGB(A) image data to set + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + """ + data = numpy.array(data, copy=copy) + assert data.ndim == 3 + assert data.shape[-1] in (3, 4) + self._data = data + + self._updated() + + # TODO hackish data range implementation + if self.isVisible(): + plot = self.getPlot() + if plot is not None: + plot._invalidateDataRange() diff --git a/silx/gui/plot/items/marker.py b/silx/gui/plot/items/marker.py new file mode 100644 index 0000000..c05558b --- /dev/null +++ b/silx/gui/plot/items/marker.py @@ -0,0 +1,241 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides markers item of the :class:`Plot`. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "06/03/2017" + + +import logging + +from .core import Item, DraggableMixIn, ColorMixIn, SymbolMixIn + + +_logger = logging.getLogger(__name__) + + +class _BaseMarker(Item, DraggableMixIn, ColorMixIn): + """Base class for markers""" + + _DEFAULT_COLOR = (0., 0., 0., 1.) + """Default color of the markers""" + + def __init__(self): + Item.__init__(self) + DraggableMixIn.__init__(self) + ColorMixIn.__init__(self) + + self._text = '' + self._x = None + self._y = None + self._constraint = self._defaultConstraint + + def _addBackendRenderer(self, backend): + """Update backend renderer""" + # TODO not very nice way to do it, but simple + symbol = self.getSymbol() if isinstance(self, Marker) else None + + return backend.addMarker( + x=self.getXPosition(), + y=self.getYPosition(), + legend=self.getLegend(), + text=self.getText(), + color=self.getColor(), + selectable=self.isSelectable(), + draggable=self.isDraggable(), + symbol=symbol, + constraint=self.getConstraint(), + overlay=self.isOverlay()) + + def isOverlay(self): + """Return true if marker is drawn as an overlay. + + A marker is an overlay if it is draggable. + + :rtype: bool + """ + return self.isDraggable() + + def getText(self): + """Returns marker text. + + :rtype: str + """ + return self._text + + def setText(self, text): + """Set the text of the marker. + + :param str text: The text to use + """ + text = str(text) + if text != self._text: + self._text = text + self._updated() + + def getXPosition(self): + """Returns the X position of the marker line in data coordinates + + :rtype: float or None + """ + return self._x + + def getYPosition(self): + """Returns the Y position of the marker line in data coordinates + + :rtype: float or None + """ + return self._y + + def getPosition(self): + """Returns the (x, y) position of the marker in data coordinates + + :rtype: 2-tuple of float or None + """ + return self._x, self._y + + def setPosition(self, x, y): + """Set marker position in data coordinates + + Constraint are applied if any. + + :param float x: X coordinates in data frame + :param float y: Y coordinates in data frame + """ + x, y = self.getConstraint()(x, y) + x, y = float(x), float(y) + if x != self._x or y != self._y: + self._x, self._y = x, y + self._updated() + + def getConstraint(self): + """Returns the dragging constraint of this item""" + return self._constraint + + def _setConstraint(self, constraint): # TODO support update + """Set the constraint. + + This is private for now as update is not handled. + + :param callable constraint: + :param constraint: A function filtering item displacement by + dragging operations or None for no filter. + This function is called each time the item is + moved. + This is only used if isDraggable returns True. + :type constraint: None or a callable that takes the coordinates of + the current cursor position in the plot as input + and that returns the filtered coordinates. + """ + if constraint is None: + constraint = self._defaultConstraint + assert callable(constraint) + self._constraint = constraint + + @staticmethod + def _defaultConstraint(*args): + """Default constraint not doing anything""" + return args + + +class Marker(_BaseMarker, SymbolMixIn): + """Description of a marker""" + + _DEFAULT_SYMBOL = '+' + """Default symbol of the marker""" + + def __init__(self): + _BaseMarker.__init__(self) + SymbolMixIn.__init__(self) + + self._x = 0. + self._y = 0. + + def _setConstraint(self, constraint): + """Set the constraint function of the marker drag. + + It also supports 'horizontal' and 'vertical' str as constraint. + + :param constraint: The constraint of the dragging of this marker + :type: constraint: callable or str + """ + if constraint == 'horizontal': + constraint = self._horizontalConstraint + elif constraint == 'vertical': + constraint = self._verticalConstraint + + super(Marker, self)._setConstraint(constraint) + + def _horizontalConstraint(self, _, y): + return self.getXPosition(), y + + def _verticalConstraint(self, x, _): + return x, self.getYPosition() + + +class XMarker(_BaseMarker): + """Description of a marker""" + + def __init__(self): + _BaseMarker.__init__(self) + self._x = 0. + + def setPosition(self, x, y): + """Set marker line position in data coordinates + + Constraint are applied if any. + + :param float x: X coordinates in data frame + :param float y: Y coordinates in data frame + """ + x, _ = self.getConstraint()(x, y) + x = float(x) + if x != self._x: + self._x = x + self._updated() + + +class YMarker(_BaseMarker): + """Description of a marker""" + + def __init__(self): + _BaseMarker.__init__(self) + self._y = 0. + + def setPosition(self, x, y): + """Set marker line position in data coordinates + + Constraint are applied if any. + + :param float x: X coordinates in data frame + :param float y: Y coordinates in data frame + """ + _, y = self.getConstraint()(x, y) + y = float(y) + if y != self._y: + self._y = y + self._updated() diff --git a/silx/gui/plot/items/scatter.py b/silx/gui/plot/items/scatter.py new file mode 100644 index 0000000..3897dc1 --- /dev/null +++ b/silx/gui/plot/items/scatter.py @@ -0,0 +1,169 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides the :class:`Scatter` item of the :class:`Plot`. +""" + +__authors__ = ["T. Vincent", "P. Knobel"] +__license__ = "MIT" +__date__ = "29/03/2017" + + +import logging + +import numpy + +from .core import Points, ColormapMixIn +from silx.gui.plot.Colors import applyColormapToData # TODO: cherry-pick commit or wait for PR merge + +_logger = logging.getLogger(__name__) + + +class Scatter(Points, ColormapMixIn): + """Description of a scatter""" + _DEFAULT_SYMBOL = 'o' + """Default symbol of the scatter plots""" + + def __init__(self): + Points.__init__(self) + ColormapMixIn.__init__(self) + self._value = () + + def _addBackendRenderer(self, backend): + """Update backend renderer""" + # Filter-out values <= 0 + xFiltered, yFiltered, valueFiltered, xerror, yerror = self.getData( + copy=False, displayed=True) + + if len(xFiltered) == 0: + return None # No data to display, do not add renderer to backend + + cmap = self.getColormap() + rgbacolors = applyColormapToData(self._value, + cmap["name"], + cmap["normalization"], + cmap["autoscale"], + cmap["vmin"], + cmap["vmax"], + cmap.get("colors")) + + return backend.addCurve(xFiltered, yFiltered, self.getLegend(), + color=rgbacolors, + symbol=self.getSymbol(), + linewidth=0, + linestyle="", + yaxis='left', + xerror=xerror, + yerror=yerror, + z=self.getZValue(), + selectable=self.isSelectable(), + fill=False, + alpha=self.getAlpha(), + symbolsize=self.getSymbolSize()) + + def _logFilterData(self, xPositive, yPositive): + """Filter out values with x or y <= 0 on log axes + + :param bool xPositive: True to filter arrays according to X coords. + :param bool yPositive: True to filter arrays according to Y coords. + :return: The filtered arrays or unchanged object if not filtering needed + :rtype: (x, y, value, xerror, yerror) + """ + # overloaded from Points to filter also value. + value = self.getValueData(copy=False) + + if xPositive or yPositive: + clipped = self._getClippingBoolArray(xPositive, yPositive) + + if numpy.any(clipped): + # copy to keep original array and convert to float + value = numpy.array(value, copy=True, dtype=numpy.float) + value[clipped] = numpy.nan + + x, y, xerror, yerror = Points._logFilterData(self, xPositive, yPositive) + + return x, y, value, xerror, yerror + + def getValueData(self, copy=True): + """Returns the value assigned to the scatter data points. + + :param copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :rtype: numpy.ndarray + """ + return numpy.array(self._value, copy=copy) + + def getData(self, copy=True, displayed=False): + """Returns the x, y coordinates and the value of the data points + + :param copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :param bool displayed: True to only get curve points that are displayed + in the plot. Default: False. + Note: If plot has log scale, negative points + are not displayed. + :returns: (x, y, value, xerror, yerror) + :rtype: 5-tuple of numpy.ndarray + """ + if displayed: + data = self._getCachedData() + if data is not None: + assert len(data) == 5 + return data + + return (self.getXData(copy), + self.getYData(copy), + self.getValueData(copy), + self.getXErrorData(copy), + self.getYErrorData(copy)) + + # reimplemented from Points to handle `value` + def setData(self, x, y, value, xerror=None, yerror=None, copy=True): + """Set the data of the scatter. + + :param numpy.ndarray x: The data corresponding to the x coordinates. + :param numpy.ndarray y: The data corresponding to the y coordinates. + :param numpy.ndarray value: The data corresponding to the value of + the data points. + :param xerror: Values with the uncertainties on the x values + :type xerror: A float, or a numpy.ndarray of float32. + If it is an array, it can either be a 1D array of + same length as the data or a 2D array with 2 rows + of same length as the data: row 0 for positive errors, + row 1 for negative errors. + :param yerror: Values with the uncertainties on the y values + :type yerror: A float, or a numpy.ndarray of float32. See xerror. + :param bool copy: True make a copy of the data (default), + False to use provided arrays. + """ + value = numpy.array(value, copy=copy) + assert value.ndim == 1 + assert len(x) == len(value) + + self._value = value + + # set x, y, xerror, yerror + + # call self._updated + plot._invalidateDataRange() + Points.setData(self, x, y, xerror, yerror, copy) diff --git a/silx/gui/plot/items/shape.py b/silx/gui/plot/items/shape.py new file mode 100644 index 0000000..b663989 --- /dev/null +++ b/silx/gui/plot/items/shape.py @@ -0,0 +1,121 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides the :class:`Shape` item of the :class:`Plot`. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "06/03/2017" + + +import logging + +import numpy + +from .core import Item, ColorMixIn, FillMixIn + + +_logger = logging.getLogger(__name__) + + +# TODO probably make one class for each kind of shape +# TODO check fill:polygon/polyline + fill = duplicated +class Shape(Item, ColorMixIn, FillMixIn): + """Description of a shape item + + :param str type_: The type of shape in: + 'hline', 'polygon', 'rectangle', 'vline', 'polyline' + """ + + def __init__(self, type_): + Item.__init__(self) + ColorMixIn.__init__(self) + FillMixIn.__init__(self) + self._overlay = False + assert type_ in ('hline', 'polygon', 'rectangle', 'vline', 'polyline') + self._type = type_ + self._points = () + + self._handle = None + + def _addBackendRenderer(self, backend): + """Update backend renderer""" + points = self.getPoints(copy=False) + x, y = points.T[0], points.T[1] + return backend.addItem(x, + y, + legend=self.getLegend(), + shape=self.getType(), + color=self.getColor(), + fill=self.isFill(), + overlay=self.isOverlay(), + z=self.getZValue()) + + def isOverlay(self): + """Return true if shape is drawn as an overlay + + :rtype: bool + """ + return self._overlay + + def setOverlay(self, overlay): + """Set the overlay state of the shape + + :param bool overlay: True to make it an overlay + """ + overlay = bool(overlay) + if overlay != self._overlay: + self._overlay = overlay + self._updated() + + def getType(self): + """Returns the type of shape to draw. + + One of: 'hline', 'polygon', 'rectangle', 'vline', 'polyline' + + :rtype: str + """ + return self._type + + def getPoints(self, copy=True): + """Get the control points of the shape. + + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :return: Array of point coordinates + :rtype: numpy.ndarray with 2 dimensions + """ + return numpy.array(self._points, copy=copy) + + def setPoints(self, points, copy=True): + """Set the point coordinates + + :param numpy.ndarray points: Array of point coordinates + :param bool copy: True (Default) to get a copy, + False to use internal representation (do not modify!) + :return: + """ + self._points = numpy.array(points, copy=copy) + self._updated() diff --git a/silx/gui/plot/setup.py b/silx/gui/plot/setup.py new file mode 100644 index 0000000..6408113 --- /dev/null +++ b/silx/gui/plot/setup.py @@ -0,0 +1,47 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "16/02/2016" + + +from numpy.distutils.misc_util import Configuration + + +def configuration(parent_package='', top_path=None): + config = Configuration('plot', parent_package, top_path) + config.add_subpackage('_utils') + config.add_subpackage('backends') + config.add_subpackage('backends.glutils') + config.add_subpackage('items') + config.add_subpackage('test') + + return config + + +if __name__ == "__main__": + from numpy.distutils.core import setup + + setup(configuration=configuration) diff --git a/silx/gui/plot/test/__init__.py b/silx/gui/plot/test/__init__.py new file mode 100644 index 0000000..b4378c7 --- /dev/null +++ b/silx/gui/plot/test/__init__.py @@ -0,0 +1,71 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "18/02/2016" + + +import unittest + +from .._utils.test import suite as testUtilsSuite +from .testColorBar import suite as testColorBarSuite +from .testColormapDialog import suite as testColormapDialogSuite +from .testColors import suite as testColorsSuite +from .testCurvesROIWidget import suite as testCurvesROIWidgetSuite +from .testAlphaSlider import suite as testAlphaSliderSuite +from .testInteraction import suite as testInteractionSuite +from .testLegendSelector import suite as testLegendSelectorSuite +from .testMaskToolsWidget import suite as testMaskToolsWidgetSuite +from .testScatterMaskToolsWidget import suite as testScatterMaskToolsWidgetSuite +from .testPlotInteraction import suite as testPlotInteractionSuite +from .testPlotTools import suite as testPlotToolsSuite +from .testPlotWidget import suite as testPlotWidgetSuite +from .testPlotWindow import suite as testPlotWindowSuite +from .testPlot import suite as testPlotSuite +from .testProfile import suite as testProfileSuite +from .testStackView import suite as testStackViewSuite + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTests( + [testUtilsSuite(), + testColorBarSuite(), + testColorsSuite(), + testColormapDialogSuite(), + testCurvesROIWidgetSuite(), + testAlphaSliderSuite(), + testInteractionSuite(), + testLegendSelectorSuite(), + testMaskToolsWidgetSuite(), + testScatterMaskToolsWidgetSuite(), + testPlotInteractionSuite(), + testPlotSuite(), + testPlotToolsSuite(), + testPlotWidgetSuite(), + testPlotWindowSuite(), + testProfileSuite(), + testStackViewSuite()]) + return test_suite diff --git a/silx/gui/plot/test/testAlphaSlider.py b/silx/gui/plot/test/testAlphaSlider.py new file mode 100644 index 0000000..304a562 --- /dev/null +++ b/silx/gui/plot/test/testAlphaSlider.py @@ -0,0 +1,221 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Tests for ImageAlphaSlider""" + + +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "28/03/2017" + +import numpy +import unittest + +from silx.gui import qt +from silx.gui.test.utils import TestCaseQt +from silx.gui.plot import PlotWidget +from silx.gui.plot import AlphaSlider + +# Makes sure a QApplication exists +_qapp = qt.QApplication.instance() or qt.QApplication([]) + + +class TestActiveImageAlphaSlider(TestCaseQt): + def setUp(self): + super(TestActiveImageAlphaSlider, self).setUp() + self.plot = PlotWidget() + self.aslider = AlphaSlider.ActiveImageAlphaSlider(plot=self.plot) + self.aslider.setOrientation(qt.Qt.Horizontal) + + toolbar = qt.QToolBar("plot", self.plot) + toolbar.addWidget(self.aslider) + self.plot.addToolBar(toolbar) + + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + self.mouseMove(self.plot) # Move to center + self.qapp.processEvents() + + def tearDown(self): + self.qapp.processEvents() + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + del self.plot + del self.aslider + + super(TestActiveImageAlphaSlider, self).tearDown() + + def testWidgetEnabled(self): + # no active image initially, slider must be deactivate + self.assertFalse(self.aslider.isEnabled()) + + self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]])) + # now we have an active image + self.assertTrue(self.aslider.isEnabled()) + + self.plot.setActiveImage(None) + self.assertFalse(self.aslider.isEnabled()) + + def testGetImage(self): + self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]])) + self.assertEqual(self.plot.getActiveImage(), + self.aslider.getItem()) + + self.plot.addImage(numpy.array([[0, 1, 3], [2, 4, 6]]), legend="2") + self.plot.setActiveImage("2") + self.assertEqual(self.plot.getImage("2"), + self.aslider.getItem()) + + def testGetAlpha(self): + self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1") + self.aslider.setValue(137) + self.assertAlmostEqual(self.aslider.getAlpha(), + 137. / 255) + + +class TestNamedImageAlphaSlider(TestCaseQt): + def setUp(self): + super(TestNamedImageAlphaSlider, self).setUp() + self.plot = PlotWidget() + self.aslider = AlphaSlider.NamedImageAlphaSlider(plot=self.plot) + self.aslider.setOrientation(qt.Qt.Horizontal) + + toolbar = qt.QToolBar("plot", self.plot) + toolbar.addWidget(self.aslider) + self.plot.addToolBar(toolbar) + + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + self.mouseMove(self.plot) # Move to center + self.qapp.processEvents() + + def tearDown(self): + self.qapp.processEvents() + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + del self.plot + del self.aslider + + super(TestNamedImageAlphaSlider, self).tearDown() + + def testWidgetEnabled(self): + # no image set initially, slider must be deactivate + self.assertFalse(self.aslider.isEnabled()) + + self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1") + self.aslider.setLegend("1") + # now we have an image set + self.assertTrue(self.aslider.isEnabled()) + + def testGetImage(self): + self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1") + self.plot.addImage(numpy.array([[0, 1, 3], [2, 4, 6]]), legend="2") + self.aslider.setLegend("1") + self.assertEqual(self.plot.getImage("1"), + self.aslider.getItem()) + + self.aslider.setLegend("2") + self.assertEqual(self.plot.getImage("2"), + self.aslider.getItem()) + + def testGetAlpha(self): + self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1") + self.aslider.setLegend("1") + self.aslider.setValue(128) + self.assertAlmostEqual(self.aslider.getAlpha(), + 128. / 255) + + +class TestNamedScatterAlphaSlider(TestCaseQt): + def setUp(self): + super(TestNamedScatterAlphaSlider, self).setUp() + self.plot = PlotWidget() + self.aslider = AlphaSlider.NamedScatterAlphaSlider(plot=self.plot) + self.aslider.setOrientation(qt.Qt.Horizontal) + + toolbar = qt.QToolBar("plot", self.plot) + toolbar.addWidget(self.aslider) + self.plot.addToolBar(toolbar) + + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + self.mouseMove(self.plot) # Move to center + self.qapp.processEvents() + + def tearDown(self): + self.qapp.processEvents() + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + del self.plot + del self.aslider + + super(TestNamedScatterAlphaSlider, self).tearDown() + + def testWidgetEnabled(self): + # no Scatter set initially, slider must be deactivate + self.assertFalse(self.aslider.isEnabled()) + + self.plot.addScatter([0, 1, 2], [2, 3, 4], [5, 6, 7], + legend="1") + self.aslider.setLegend("1") + # now we have an image set + self.assertTrue(self.aslider.isEnabled()) + + def testGetScatter(self): + self.plot.addScatter([0, 1, 2], [2, 3, 4], [5, 6, 7], + legend="1") + self.plot.addScatter([0, 10, 20], [20, 30, 40], [50, 60, 70], + legend="2") + self.aslider.setLegend("1") + self.assertEqual(self.plot.getScatter("1"), + self.aslider.getItem()) + + self.aslider.setLegend("2") + self.assertEqual(self.plot.getScatter("2"), + self.aslider.getItem()) + + def testGetAlpha(self): + self.plot.addScatter([0, 10, 20], [20, 30, 40], [50, 60, 70], + legend="1") + self.aslider.setLegend("1") + self.aslider.setValue(128) + self.assertAlmostEqual(self.aslider.getAlpha(), + 128. / 255) + + +def suite(): + test_suite = unittest.TestSuite() + # test_suite.addTest(positionInfoTestSuite) + for testClass in (TestActiveImageAlphaSlider, TestNamedImageAlphaSlider, + TestNamedScatterAlphaSlider): + test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase( + testClass)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot/test/testColorBar.py b/silx/gui/plot/test/testColorBar.py new file mode 100644 index 0000000..797ff03 --- /dev/null +++ b/silx/gui/plot/test/testColorBar.py @@ -0,0 +1,240 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for ColorBar featues and sub widgets of Colorbar module""" + +__authors__ = ["H. Payno"] +__license__ = "MIT" +__date__ = "11/04/2017" + +import unittest +from silx.gui.test.utils import TestCaseQt +from silx.gui.plot.ColorBar import _ColorScale +from silx.gui.plot.ColorBar import ColorBarWidget +from silx.gui.plot import Plot2D +import numpy + + +class TestColorScale(unittest.TestCase): + """Test that interaction with the colorScale is correct""" + def setUp(self): + self.colorScaleWidget = _ColorScale(colormap=None, parent=None) + + def tearDown(self): + self.colorScaleWidget.deleteLater() + self.colorScaleWidget = None + + def testRelativePositionLinear(self): + self.colorMapLin1 = { 'name': 'gray', 'normalization': 'linear', + 'autoscale': False, 'vmin': 0.0, 'vmax': 1.0 } + self.colorScaleWidget.setColormap(self.colorMapLin1) + + self.assertTrue( + self.colorScaleWidget.getValueFromRelativePosition(0.25) == 0.25) + self.assertTrue( + self.colorScaleWidget.getValueFromRelativePosition(0.5) == 0.5) + self.assertTrue( + self.colorScaleWidget.getValueFromRelativePosition(1.0) == 1.0) + + self.colorMapLin2 = { 'name': 'viridis', 'normalization': 'linear', + 'autoscale': False, 'vmin': -10, 'vmax': 0 } + self.colorScaleWidget.setColormap(self.colorMapLin2) + + self.assertTrue( + self.colorScaleWidget.getValueFromRelativePosition(0.25) == -7.5) + self.assertTrue( + self.colorScaleWidget.getValueFromRelativePosition(0.5) == -5.0) + self.assertTrue( + self.colorScaleWidget.getValueFromRelativePosition(1.0) == 0.0) + + def testRelativePositionLog(self): + self.colorMapLog1 = { 'name': 'temperature', 'normalization': 'log', + 'autoscale': False, 'vmin': 1.0, 'vmax': 100.0 } + + self.colorScaleWidget.setColormap(self.colorMapLog1) + + val = self.colorScaleWidget.getValueFromRelativePosition(1.0) + self.assertTrue(val == 100.0) + + val = self.colorScaleWidget.getValueFromRelativePosition(0.5) + self.assertTrue(val == 10.0) + + val = self.colorScaleWidget.getValueFromRelativePosition(0.0) + self.assertTrue(val == 1.0) + + def testNegativeLogMin(self): + colormap = { 'name': 'gray', 'normalization': 'log', + 'autoscale': False, 'vmin': -1.0, 'vmax': 1.0 } + + with self.assertRaises(ValueError): + self.colorScaleWidget.setColormap(colormap) + + def testNegativeLogMax(self): + colormap = { 'name': 'gray', 'normalization': 'log', + 'autoscale': False, 'vmin': 1.0, 'vmax': -1.0 } + + with self.assertRaises(ValueError): + self.colorScaleWidget.setColormap(colormap) + +class TestNoAutoscale(unittest.TestCase): + """Test that ticks and color displayed are correct in the case of a colormap + with no autoscale + """ + + def setUp(self): + self.plot = Plot2D() + self.colorBar = ColorBarWidget(parent=None, plot=self.plot) + self.tickBar = self.colorBar.getColorScaleBar().getTickBar() + self.colorScale = self.colorBar.getColorScaleBar().getColorScale() + + def tearDown(self): + self.tickBar = None + self.colorScale = None + del self.colorBar + self.plot.close() + del self.plot + + def testLogNormNoAutoscale(self): + colormapLog = { 'name': 'gray', 'normalization': 'log', + 'autoscale': False, 'vmin': 1.0, 'vmax': 100.0 } + + data = numpy.linspace(10, 1e10, 9).reshape(3, 3) + self.plot.addImage(data=data, colormap=colormapLog, legend='toto') + self.plot.setActiveImage('toto') + + # test Ticks + self.tickBar.setTicksNumber(10) + self.tickBar.computeTicks() + + ticksTh = numpy.linspace(1.0, 100.0, 10) + ticksTh = 10**ticksTh + numpy.array_equal(self.tickBar.ticks, ticksTh) + + # test ColorScale + val = self.colorScale.getValueFromRelativePosition(1.0) + self.assertTrue(val == 100.0) + + val = self.colorScale.getValueFromRelativePosition(0.0) + self.assertTrue(val == 1.0) + + def testLinearNormNoAutoscale(self): + colormapLog = { 'name': 'gray', 'normalization': 'linear', + 'autoscale': False, 'vmin': -4, 'vmax': 5 } + + data = numpy.linspace(1, 9, 9).reshape(3, 3) + self.plot.addImage(data=data, colormap=colormapLog, legend='toto') + self.plot.setActiveImage('toto') + + # test Ticks + self.tickBar.setTicksNumber(10) + self.tickBar.computeTicks() + + numpy.array_equal(self.tickBar.ticks, numpy.linspace(-4, 5, 10)) + + # test ColorScale + val = self.colorScale.getValueFromRelativePosition(1.0) + self.assertTrue(val == 5.0) + + val = self.colorScale.getValueFromRelativePosition(0.0) + self.assertTrue(val == -4.0) + +class TestColorbarWidget(TestCaseQt): + """Test interaction with the ColorScaleBar""" + + def setUp(self): + super(TestColorbarWidget, self).setUp() + self.plot = Plot2D() + self.colorBar = ColorBarWidget(parent=None, plot=self.plot) + + def tearDown(self): + del self.colorBar + self.plot.close() + del self.plot + + super(TestColorbarWidget, self).tearDown() + + def testEmptyColorBar(self): + colorBar = ColorBarWidget(parent=None) + colorBar.show() + self.qWaitForWindowExposed(colorBar) + + def testNegativeColormaps(self): + """test the behavior of the ColorBarWidget in the case of negative + values + + Note : colorbar is modified by the Plot directly not ColorBarWidget + """ + colormapLog = { 'name': 'gray', 'normalization': 'log', + 'autoscale': True, 'vmin': -1.0, 'vmax': 1.0 } + + colormapLog2 = { 'name': 'gray', 'normalization': 'log', + 'autoscale': False, 'vmin': -1.0, 'vmax': 1.0 } + + data = numpy.array([-5, -4, 0, 2, 3, 5, 10, 20, 30]) + data = data.reshape(3, 3) + self.plot.addImage(data=data, colormap=colormapLog, legend='toto') + self.plot.setActiveImage('toto') + + # default behavior when autoscale : set to minmal positive value + data[data<1] = data.max() + self.assertTrue(self.colorBar._colormap['vmin'] == data.min()) + self.assertTrue(self.colorBar._colormap['vmax'] == data.max()) + + data = numpy.linspace(-9, -2, 100).reshape(10, 10) + + self.plot.addImage(data=data, colormap=colormapLog2, legend='toto') + self.plot.setActiveImage('toto') + # if negative values, changing bounds for default : 1, 10 + self.assertTrue(self.colorBar._colormap['vmin'] == 1) + self.assertTrue(self.colorBar._colormap['vmax'] == 10) + + def testPlotAssocation(self): + """Make sure the ColorBarWidget is proparly connected with the plot""" + colormap = { 'name': 'gray', 'normalization': 'linear', + 'autoscale': True, 'vmin': -1.0, 'vmax': 1.0 } + + # make sure that default settings are the same + self.assertTrue( + self.colorBar.getColormap() == self.plot.getDefaultColormap()) + + data = numpy.linspace(0, 10, 100).reshape(10, 10) + self.plot.addImage(data=data, colormap=colormap, legend='toto') + self.plot.setActiveImage('toto') + + # make sure the modification of the colormap has been done + self.assertFalse( + self.colorBar.getColormap() == self.plot.getDefaultColormap()) + + +def suite(): + test_suite = unittest.TestSuite() + for ui in (TestColorScale, TestNoAutoscale, TestColorbarWidget): + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(ui)) + + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite')
\ No newline at end of file diff --git a/silx/gui/plot/test/testColormapDialog.py b/silx/gui/plot/test/testColormapDialog.py new file mode 100644 index 0000000..d016548 --- /dev/null +++ b/silx/gui/plot/test/testColormapDialog.py @@ -0,0 +1,68 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for ColormapDialog""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "05/12/2016" + + +import doctest +import unittest + +from silx.gui.test.utils import qWaitForWindowExposedAndActivate +from silx.gui import qt +from silx.gui.plot import ColormapDialog + + +# Makes sure a QApplication exists +_qapp = qt.QApplication.instance() or qt.QApplication([]) + + +def _tearDownQt(docTest): + """Tear down to use for test from docstring. + + Checks that dialog widget is displayed + """ + dialogWidget = docTest.globs['dialog'] + qWaitForWindowExposedAndActivate(dialogWidget) + dialogWidget.setAttribute(qt.Qt.WA_DeleteOnClose) + dialogWidget.close() + del dialogWidget + _qapp.processEvents() + + +cmapDocTestSuite = doctest.DocTestSuite(ColormapDialog, tearDown=_tearDownQt) +"""Test suite of tests from the module's docstrings.""" + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTest(cmapDocTestSuite) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot/test/testColors.py b/silx/gui/plot/test/testColors.py new file mode 100644 index 0000000..94c22f3 --- /dev/null +++ b/silx/gui/plot/test/testColors.py @@ -0,0 +1,94 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for Colors""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "05/12/2016" + + +import numpy + +import unittest +from silx.test.utils import ParametricTestCase + +from silx.gui.plot import Colors + + +class TestRGBA(ParametricTestCase): + """Basic tests of rgba function""" + + def testRGBA(self): + """"Test rgba function with accepted values""" + tests = { # name: (colors, expected values) + 'blue': ('blue', (0., 0., 1., 1.)), + '#010203': ('#010203', (1. / 255., 2. / 255., 3. / 255., 1.)), + '#01020304': ('#01020304', (1. / 255., 2. / 255., 3. / 255., 4. / 255.)), + '3 x uint8': (numpy.array((1, 255, 0), dtype=numpy.uint8), + (1 / 255., 1., 0., 1.)), + '4 x uint8': (numpy.array((1, 255, 0, 1), dtype=numpy.uint8), + (1 / 255., 1., 0., 1 / 255.)), + '3 x float overflow': ((3., 0.5, 1.), (1., 0.5, 1., 1.)), + } + + for name, test in tests.items(): + color, expected = test + with self.subTest(msg=name): + result = Colors.rgba(color) + self.assertEqual(result, expected) + + +class TestApplyColormapToData(ParametricTestCase): + """Tests of applyColormapToData function""" + + def testApplyColormapToData(self): + """Simple test of applyColormapToData function""" + colormap = dict(name='gray', normalization='linear', + autoscale=False, vmin=0, vmax=255) + + size = 10 + expected = numpy.empty((size, 4), dtype='uint8') + expected[:, 0] = numpy.arange(size, dtype='uint8') + expected[:, 1] = expected[:, 0] + expected[:, 2] = expected[:, 0] + expected[:, 3] = 255 + + for dtype in ('uint8', 'int32', 'float32', 'float64'): + with self.subTest(dtype=dtype): + array = numpy.arange(size, dtype=dtype) + result = Colors.applyColormapToData(array, **colormap) + self.assertTrue(numpy.all(numpy.equal(result, expected))) + + +def suite(): + test_suite = unittest.TestSuite() + for testClass in (TestRGBA, TestApplyColormapToData): + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(testClass)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot/test/testCurvesROIWidget.py b/silx/gui/plot/test/testCurvesROIWidget.py new file mode 100644 index 0000000..3c6f2ba --- /dev/null +++ b/silx/gui/plot/test/testCurvesROIWidget.py @@ -0,0 +1,153 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for CurvesROIWidget""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "05/12/2016" + + +import logging +import os.path +import unittest + +import numpy + +from silx.gui import qt +from silx.test.utils import temp_dir +from silx.gui.test.utils import TestCaseQt +from silx.gui.plot import PlotWindow, CurvesROIWidget + + +logging.basicConfig() +_logger = logging.getLogger(__name__) + + +class TestCurvesROIWidget(TestCaseQt): + """Basic test for CurvesROIWidget""" + + def setUp(self): + super(TestCurvesROIWidget, self).setUp() + self.plot = PlotWindow() + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + self.widget = CurvesROIWidget.CurvesROIDockWidget(plot=self.plot, name='TEST') + self.widget.show() + self.qWaitForWindowExposed(self.widget) + + def tearDown(self): + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + del self.plot + + self.widget.setAttribute(qt.Qt.WA_DeleteOnClose) + self.widget.close() + del self.widget + + super(TestCurvesROIWidget, self).tearDown() + + def testEmptyPlot(self): + """Empty plot, display ROI widget""" + pass + + def testWithCurves(self): + """Plot with curves: test all ROI widget buttons""" + for offset in range(2): + self.plot.addCurve(numpy.arange(1000), + offset + numpy.random.random(1000), + legend=str(offset)) + + # Add two ROI + self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton) + self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton) + + # Change active curve + self.plot.setActiveCurve(str(1)) + + # Delete a ROI + self.mouseClick(self.widget.roiWidget.delButton, qt.Qt.LeftButton) + + with temp_dir() as tmpDir: + self.tmpFile = os.path.join(tmpDir, 'test.ini') + + # Save ROIs + self.widget.roiWidget.save(self.tmpFile) + self.assertTrue(os.path.isfile(self.tmpFile)) + + # Reset ROIs + self.mouseClick(self.widget.roiWidget.resetButton, + qt.Qt.LeftButton) + + # Load ROIs + self.widget.roiWidget.load(self.tmpFile) + + del self.tmpFile + + def testCalculation(self): + x = numpy.arange(100.) + y = numpy.arange(100.) + + # Add two curves + self.plot.addCurve(x, y, legend="positive") + self.plot.addCurve(-x, y, legend="negative") + + # Make sure there is an active curve and it is the positive one + self.plot.setActiveCurve("positive") + + # Add two ROIs + ddict = {} + ddict["positive"] = {"from": 10, "to": 20, "type":"X"} + ddict["negative"] = {"from": -20, "to": -10, "type":"X"} + self.widget.roiWidget.setRois(ddict) + + # And calculate the expected output + self.widget.calculateROIs() + + output = self.widget.roiWidget.getRois() + self.assertEqual(output["positive"]["rawcounts"], + y[ddict["positive"]["from"]:ddict["positive"]["to"]+1].sum(), + "Calculation failed on positive X coordinates") + + # Set the curve with negative X coordinates as active + self.plot.setActiveCurve("negative") + + # the ROIs should have been automatically updated + output = self.widget.roiWidget.getRois() + selection = numpy.nonzero((-x >= output["negative"]["from"]) & \ + (-x <= output["negative"]["to"]))[0] + self.assertEqual(output["negative"]["rawcounts"], + y[selection].sum(), "Calculation failed on negative X coordinates") + +def suite(): + test_suite = unittest.TestSuite() + for TestClass in (TestCurvesROIWidget,): + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestClass)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot/test/testInteraction.py b/silx/gui/plot/test/testInteraction.py new file mode 100644 index 0000000..074a7cd --- /dev/null +++ b/silx/gui/plot/test/testInteraction.py @@ -0,0 +1,89 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Tests from interaction state machines""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "18/02/2016" + + +import unittest + +from silx.gui.plot import Interaction + + +class TestInteraction(unittest.TestCase): + def testClickOrDrag(self): + """Minimalistic test for click or drag state machine.""" + events = [] + + class TestClickOrDrag(Interaction.ClickOrDrag): + def click(self, x, y, btn): + events.append(('click', x, y, btn)) + + def beginDrag(self, x, y): + events.append(('beginDrag', x, y)) + + def drag(self, x, y): + events.append(('drag', x, y)) + + def endDrag(self, x, y): + events.append(('endDrag', x, y)) + + clickOrDrag = TestClickOrDrag() + + # click + clickOrDrag.handleEvent('press', 10, 10, Interaction.LEFT_BTN) + self.assertEqual(len(events), 0) + + clickOrDrag.handleEvent('release', 10, 10, Interaction.LEFT_BTN) + self.assertEqual(len(events), 1) + self.assertEqual(events[0], ('click', 10, 10, Interaction.LEFT_BTN)) + + # drag + events = [] + clickOrDrag.handleEvent('press', 10, 10, Interaction.LEFT_BTN) + self.assertEqual(len(events), 0) + clickOrDrag.handleEvent('move', 15, 10) + self.assertEqual(len(events), 2) # Received beginDrag and drag + self.assertEqual(events[0], ('beginDrag', 10, 10)) + self.assertEqual(events[1], ('drag', 15, 10)) + clickOrDrag.handleEvent('move', 20, 10) + self.assertEqual(len(events), 3) + self.assertEqual(events[-1], ('drag', 20, 10)) + clickOrDrag.handleEvent('release', 20, 10, Interaction.LEFT_BTN) + self.assertEqual(len(events), 4) + self.assertEqual(events[-1], ('endDrag', (10, 10), (20, 10))) + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestInteraction)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot/test/testLegendSelector.py b/silx/gui/plot/test/testLegendSelector.py new file mode 100644 index 0000000..371197f --- /dev/null +++ b/silx/gui/plot/test/testLegendSelector.py @@ -0,0 +1,143 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for PlotWidget""" + +__authors__ = ["T. Rueter", "T. Vincent"] +__license__ = "MIT" +__date__ = "05/12/2016" + + +import logging +import unittest + +from silx.gui import qt +from silx.gui.test.utils import TestCaseQt +from silx.gui.plot import LegendSelector + + +logging.basicConfig() +_logger = logging.getLogger(__name__) + + +class TestLegendSelector(TestCaseQt): + """Basic test for LegendSelector""" + + def testLegendSelector(self): + """Test copied from __main__ of LegendSelector in PyMca""" + class Notifier(qt.QObject): + def __init__(self): + qt.QObject.__init__(self) + self.chk = True + + def signalReceived(self, **kw): + obj = self.sender() + _logger.info('NOTIFIER -- signal received\n\tsender: %s', + str(obj)) + + notifier = Notifier() + + legends = ['Legend0', + 'Legend1', + 'Long Legend 2', + 'Foo Legend 3', + 'Even Longer Legend 4', + 'Short Leg 5', + 'Dot symbol 6', + 'Comma symbol 7'] + colors = [qt.Qt.darkRed, qt.Qt.green, qt.Qt.yellow, qt.Qt.darkCyan, + qt.Qt.blue, qt.Qt.darkBlue, qt.Qt.red, qt.Qt.darkYellow] + symbols = ['o', 't', '+', 'x', 's', 'd', '.', ','] + + win = LegendSelector.LegendListView() + # win = LegendListContextMenu() + # win = qt.QWidget() + # layout = qt.QVBoxLayout() + # layout.setContentsMargins(0,0,0,0) + llist = [] + + for _idx, (l, c, s) in enumerate(zip(legends, colors, symbols)): + ddict = { + 'color': qt.QColor(c), + 'linewidth': 4, + 'symbol': s, + } + legend = l + llist.append((legend, ddict)) + # item = qt.QListWidgetItem(win) + # legendWidget = LegendListItemWidget(l) + # legendWidget.icon.setSymbol(s) + # legendWidget.icon.setColor(qt.QColor(c)) + # layout.addWidget(legendWidget) + # win.setItemWidget(item, legendWidget) + + # win = LegendListItemWidget('Some Legend 1') + # print(llist) + model = LegendSelector.LegendModel(legendList=llist) + win.setModel(model) + win.setSelectionModel(qt.QItemSelectionModel(model)) + win.setContextMenu() + # print('Edit triggers: %d'%win.editTriggers()) + + # win = LegendListWidget(None, legends) + # win[0].updateItem(ddict) + # win.setLayout(layout) + win.sigLegendSignal.connect(notifier.signalReceived) + win.show() + + win.clear() + win.setLegendList(llist) + + self.qWaitForWindowExposed(win) + + +class TestRenameCurveDialog(TestCaseQt): + """Basic test for RenameCurveDialog""" + + def testDialog(self): + """Create dialog, change name and press OK""" + self.dialog = LegendSelector.RenameCurveDialog( + None, 'curve1', ['curve1', 'curve2', 'curve3']) + self.dialog.open() + self.qWaitForWindowExposed(self.dialog) + self.keyClicks(self.dialog.lineEdit, 'changed') + self.mouseClick(self.dialog.okButton, qt.Qt.LeftButton) + self.qapp.processEvents() + ret = self.dialog.result() + self.assertEqual(ret, qt.QDialog.Accepted) + newName = self.dialog.getText() + self.assertEqual(newName, 'curve1changed') + del self.dialog + + +def suite(): + test_suite = unittest.TestSuite() + for TestClass in (TestLegendSelector, TestRenameCurveDialog): + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestClass)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot/test/testMaskToolsWidget.py b/silx/gui/plot/test/testMaskToolsWidget.py new file mode 100644 index 0000000..0c11928 --- /dev/null +++ b/silx/gui/plot/test/testMaskToolsWidget.py @@ -0,0 +1,295 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for MaskToolsWidget""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "24/01/2017" + + +import logging +import os.path +import unittest + +import numpy + +from silx.gui import qt +from silx.test.utils import temp_dir, ParametricTestCase +from silx.gui.test.utils import TestCaseQt, getQToolButtonFromAction +from silx.gui.plot import PlotWindow, MaskToolsWidget + +try: + import fabio +except ImportError: + fabio = None + + +logging.basicConfig() +_logger = logging.getLogger(__name__) + + +class TestMaskToolsWidget(TestCaseQt, ParametricTestCase): + """Basic test for MaskToolsWidget""" + + def setUp(self): + super(TestMaskToolsWidget, self).setUp() + self.plot = PlotWindow() + + self.widget = MaskToolsWidget.MaskToolsDockWidget(plot=self.plot, name='TEST') + self.plot.addDockWidget(qt.Qt.BottomDockWidgetArea, self.widget) + + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + self.maskWidget = self.widget.widget() + + def tearDown(self): + del self.maskWidget + del self.widget + + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + del self.plot + + super(TestMaskToolsWidget, self).tearDown() + + def testEmptyPlot(self): + """Empty plot, display MaskToolsDockWidget, toggle multiple masks""" + self.maskWidget.setMultipleMasks('single') + self.qapp.processEvents() + + self.maskWidget.setMultipleMasks('exclusive') + self.qapp.processEvents() + + def _drag(self): + """Drag from plot center to offset position""" + plot = self.plot.centralWidget() + xCenter, yCenter = plot.width() // 2, plot.height() // 2 + offset = min(plot.width(), plot.height()) // 10 + + pos0 = xCenter, yCenter + pos1 = xCenter + offset, yCenter + offset + + self.mouseMove(plot, pos=pos0) + self.mousePress(plot, qt.Qt.LeftButton, pos=pos0) + self.mouseMove(plot, pos=pos1) + self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos1) + + def _drawPolygon(self): + """Draw a star polygon in the plot""" + plot = self.plot.centralWidget() + x, y = plot.width() // 2, plot.height() // 2 + offset = min(plot.width(), plot.height()) // 10 + + star = [(x, y + offset), + (x - offset, y - offset), + (x + offset, y), + (x - offset, y), + (x + offset, y - offset)] + + for pos in star: + self.mouseMove(plot, pos=pos) + btn = qt.Qt.LeftButton if pos != star[-1] else qt.Qt.RightButton + self.mouseClick(plot, btn, pos=pos) + + def _drawPencil(self): + """Draw a star polygon in the plot""" + plot = self.plot.centralWidget() + x, y = plot.width() // 2, plot.height() // 2 + offset = min(plot.width(), plot.height()) // 10 + + star = [(x, y + offset), + (x - offset, y - offset), + (x + offset, y), + (x - offset, y), + (x + offset, y - offset)] + + self.mouseMove(plot, pos=star[0]) + self.mousePress(plot, qt.Qt.LeftButton, pos=star[0]) + for pos in star: + self.mouseMove(plot, pos=pos) + self.mouseRelease( + plot, qt.Qt.LeftButton, pos=star[-1]) + + def testWithAnImage(self): + """Plot with an image: test MaskToolsWidget interactions""" + + # Add and remove a image (this should enable/disable GUI + change mask) + self.plot.addImage(numpy.random.random(1024**2).reshape(1024, 1024), + legend='test') + self.qapp.processEvents() + + self.plot.remove('test', kind='image') + self.qapp.processEvents() + + tests = [((0, 0), (1, 1)), + ((1000, 1000), (1, 1)), + ((0, 0), (-1, -1)), + ((1000, 1000), (-1, -1))] + + for origin, scale in tests: + with self.subTest(origin=origin, scale=scale): + self.plot.addImage(numpy.arange(1024**2).reshape(1024, 1024), + legend='test', + origin=origin, + scale=scale) + self.qapp.processEvents() + + # Test draw rectangle # + toolButton = getQToolButtonFromAction(self.maskWidget.rectAction) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + + # mask + self.maskWidget.maskStateGroup.button(1).click() + self.qapp.processEvents() + self._drag() + self.assertFalse( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + + # unmask same region + self.maskWidget.maskStateGroup.button(0).click() + self.qapp.processEvents() + self._drag() + self.assertTrue( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + + # Test draw polygon # + toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + + # mask + self.maskWidget.maskStateGroup.button(1).click() + self.qapp.processEvents() + self._drawPolygon() + self.assertFalse( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + + # unmask same region + self.maskWidget.maskStateGroup.button(0).click() + self.qapp.processEvents() + self._drawPolygon() + self.assertTrue( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + + # Test draw pencil # + toolButton = getQToolButtonFromAction(self.maskWidget.pencilAction) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + + self.maskWidget.pencilSpinBox.setValue(10) + self.qapp.processEvents() + + # mask + self.maskWidget.maskStateGroup.button(1).click() + self.qapp.processEvents() + self._drawPencil() + self.assertFalse( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + + # unmask same region + self.maskWidget.maskStateGroup.button(0).click() + self.qapp.processEvents() + self._drawPencil() + self.assertTrue( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + + # Test no draw tool # + toolButton = getQToolButtonFromAction(self.maskWidget.browseAction) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + + self.plot.clear() + + def __loadSave(self, file_format): + """Plot with an image: test MaskToolsWidget operations""" + self.plot.addImage(numpy.arange(1024**2).reshape(1024, 1024), + legend='test') + self.qapp.processEvents() + + # Draw a polygon mask + toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + self._drawPolygon() + + ref_mask = self.maskWidget.getSelectionMask() + self.assertFalse(numpy.all(numpy.equal(ref_mask, 0))) + + with temp_dir() as tmp: + mask_filename = os.path.join(tmp, 'mask.' + file_format) + self.maskWidget.save(mask_filename, file_format) + + self.maskWidget.resetSelectionMask() + self.assertTrue( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + + self.maskWidget.load(mask_filename) + self.assertTrue(numpy.all(numpy.equal( + self.maskWidget.getSelectionMask(), ref_mask))) + + def testLoadSaveNpy(self): + self.__loadSave("npy") + + def testLoadSaveFit2D(self): + if fabio is None: + self.skipTest("Fabio is missing") + self.__loadSave("msk") + + def testSigMaskChangedEmitted(self): + self.plot.addImage(numpy.arange(512**2).reshape(512, 512), + legend='test') + self.plot.resetZoom() + self.qapp.processEvents() + + l = [] + + def slot(): + l.append(1) + + self.maskWidget.sigMaskChanged.connect(slot) + + # rectangle mask + toolButton = getQToolButtonFromAction(self.maskWidget.rectAction) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + self.maskWidget.maskStateGroup.button(1).click() + self.qapp.processEvents() + self._drag() + + self.assertGreater(len(l), 0) + + +def suite(): + test_suite = unittest.TestSuite() + for TestClass in (TestMaskToolsWidget,): + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestClass)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot/test/testPlot.py b/silx/gui/plot/test/testPlot.py new file mode 100644 index 0000000..25e7511 --- /dev/null +++ b/silx/gui/plot/test/testPlot.py @@ -0,0 +1,633 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for Plot""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "05/12/2016" + + +import unittest +from functools import reduce +from silx.test.utils import ParametricTestCase + +import numpy + +from silx.gui.plot.Plot import Plot +from silx.gui.plot.items.histogram import _getHistogramCurve, _computeEdges + + +class TestPlot(unittest.TestCase): + """Basic tests of Plot without backend""" + + def testPlotTitleLabels(self): + """Create a Plot and set the labels""" + + plot = Plot(backend='none') + + title, xlabel, ylabel = 'the title', 'x label', 'y label' + plot.setGraphTitle(title) + plot.setGraphXLabel(xlabel) + plot.setGraphYLabel(ylabel) + + self.assertEqual(plot.getGraphTitle(), title) + self.assertEqual(plot.getGraphXLabel(), xlabel) + self.assertEqual(plot.getGraphYLabel(), ylabel) + + def testAddNoRemove(self): + """add objects to the Plot""" + + plot = Plot(backend='none') + plot.addCurve(x=(1, 2, 3), y=(3, 2, 1)) + plot.addImage(numpy.arange(100.).reshape(10, -1)) + plot.addItem( + numpy.array((1., 10.)), numpy.array((10., 10.)), shape="rectangle") + plot.addXMarker(10.) + + +class TestPlotRanges(ParametricTestCase): + """Basic tests of Plot data ranges without backend""" + + _getValidValues = {True: lambda ar: ar > 0, + False: lambda ar: numpy.ones(shape=ar.shape, + dtype=bool)} + + @staticmethod + def _getRanges(arrays, are_logs): + gen = (TestPlotRanges._getValidValues[is_log](ar) + for (ar, is_log) in zip(arrays, are_logs)) + indices = numpy.where(reduce(numpy.logical_and, gen))[0] + if len(indices) > 0: + ranges = [(ar[indices[0]], ar[indices[-1]]) for ar in arrays] + else: + ranges = [None] * len(arrays) + + return ranges + + @staticmethod + def _getRangesMinmax(ranges): + # TODO : error if None in ranges. + rangeMin = numpy.min([rng[0] for rng in ranges]) + rangeMax = numpy.max([rng[1] for rng in ranges]) + return rangeMin, rangeMax + + def testDataRangeNoPlot(self): + """empty plot data range""" + + plot = Plot(backend='none') + + for logX, logY in ((False, False), + (True, False), + (True, True), + (False, True), + (False, False)): + with self.subTest(logX=logX, logY=logY): + plot.setXAxisLogarithmic(logX) + plot.setYAxisLogarithmic(logY) + dataRange = plot.getDataRange() + self.assertIsNone(dataRange.x) + self.assertIsNone(dataRange.y) + self.assertIsNone(dataRange.yright) + + def testDataRangeLeft(self): + """left axis range""" + + plot = Plot(backend='none') + + xData = numpy.arange(10) - 4.9 # range : -4.9 , 4.1 + yData = numpy.arange(10) - 6.9 # range : -6.9 , 2.1 + + plot.addCurve(x=xData, + y=yData, + legend='plot_0', + yaxis='left') + + for logX, logY in ((False, False), + (True, False), + (True, True), + (False, True), + (False, False)): + with self.subTest(logX=logX, logY=logY): + plot.setXAxisLogarithmic(logX) + plot.setYAxisLogarithmic(logY) + dataRange = plot.getDataRange() + xRange, yRange = self._getRanges([xData, yData], + [logX, logY]) + self.assertSequenceEqual(dataRange.x, xRange) + self.assertSequenceEqual(dataRange.y, yRange) + self.assertIsNone(dataRange.yright) + + def testDataRangeRight(self): + """right axis range""" + + plot = Plot(backend='none') + xData = numpy.arange(10) - 4.9 # range : -4.9 , 4.1 + yData = numpy.arange(10) - 6.9 # range : -6.9 , 2.1 + plot.addCurve(x=xData, + y=yData, + legend='plot_0', + yaxis='right') + + for logX, logY in ((False, False), + (True, False), + (True, True), + (False, True), + (False, False)): + with self.subTest(logX=logX, logY=logY): + plot.setXAxisLogarithmic(logX) + plot.setYAxisLogarithmic(logY) + dataRange = plot.getDataRange() + xRange, yRange = self._getRanges([xData, yData], + [logX, logY]) + self.assertSequenceEqual(dataRange.x, xRange) + self.assertIsNone(dataRange.y) + self.assertSequenceEqual(dataRange.yright, yRange) + + def testDataRangeImage(self): + """image data range""" + + origin = (-10, 25) + scale = (3., 8.) + image = numpy.arange(100.).reshape(20, 5) + + plot = Plot(backend='none') + plot.addImage(image, + origin=origin, scale=scale) + + xRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0] + yRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1] + + ranges = {(False, False): (xRange, yRange), + (True, False): (None, None), + (True, True): (None, None), + (False, True): (None, None)} + + for logX, logY in ((False, False), + (True, False), + (True, True), + (False, True), + (False, False)): + with self.subTest(logX=logX, logY=logY): + plot.setXAxisLogarithmic(logX) + plot.setYAxisLogarithmic(logY) + dataRange = plot.getDataRange() + xRange, yRange = ranges[logX, logY] + self.assertTrue(numpy.array_equal(dataRange.x, xRange), + msg='{0} != {1}'.format(dataRange.x, xRange)) + self.assertTrue(numpy.array_equal(dataRange.y, yRange), + msg='{0} != {1}'.format(dataRange.y, yRange)) + self.assertIsNone(dataRange.yright) + + def testDataRangeLeftRight(self): + """right+left axis range""" + + plot = Plot(backend='none') + + xData_l = numpy.arange(10) - 0.9 # range : -0.9 , 8.1 + yData_l = numpy.arange(10) - 1.9 # range : -1.9 , 7.1 + plot.addCurve(x=xData_l, + y=yData_l, + legend='plot_l', + yaxis='left') + + xData_r = numpy.arange(10) - 4.9 # range : -4.9 , 4.1 + yData_r = numpy.arange(10) - 6.9 # range : -6.9 , 2.1 + plot.addCurve(x=xData_r, + y=yData_r, + legend='plot_r', + yaxis='right') + + for logX, logY in ((False, False), + (True, False), + (True, True), + (False, True), + (False, False)): + with self.subTest(logX=logX, logY=logY): + plot.setXAxisLogarithmic(logX) + plot.setYAxisLogarithmic(logY) + dataRange = plot.getDataRange() + xRangeL, yRangeL = self._getRanges([xData_l, yData_l], + [logX, logY]) + xRangeR, yRangeR = self._getRanges([xData_r, yData_r], + [logX, logY]) + xRangeLR = self._getRangesMinmax([xRangeL, xRangeR]) + self.assertSequenceEqual(dataRange.x, xRangeLR) + self.assertSequenceEqual(dataRange.y, yRangeL) + self.assertSequenceEqual(dataRange.yright, yRangeR) + + def testDataRangeCurveImage(self): + """right+left+image axis range""" + + # overlapping ranges : + # image sets x min and y max + # plot_left sets y min + # plot_right sets x max (and yright) + plot = Plot(backend='none') + + origin = (-10, 5) + scale = (3., 8.) + image = numpy.arange(100.).reshape(20, 5) + + plot.addImage(image, + origin=origin, scale=scale, legend='image') + + xData_l = numpy.arange(10) - 0.9 # range : -0.9 , 8.1 + yData_l = numpy.arange(10) - 1.9 # range : -1.9 , 7.1 + plot.addCurve(x=xData_l, + y=yData_l, + legend='plot_l', + yaxis='left') + + xData_r = numpy.arange(10) + 4.1 # range : 4.1 , 13.1 + yData_r = numpy.arange(10) - 0.9 # range : -0.9 , 8.1 + plot.addCurve(x=xData_r, + y=yData_r, + legend='plot_r', + yaxis='right') + + imgXRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0] + imgYRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1] + + for logX, logY in ((False, False), + (True, False), + (True, True), + (False, True), + (False, False)): + with self.subTest(logX=logX, logY=logY): + plot.setXAxisLogarithmic(logX) + plot.setYAxisLogarithmic(logY) + dataRange = plot.getDataRange() + xRangeL, yRangeL = self._getRanges([xData_l, yData_l], + [logX, logY]) + xRangeR, yRangeR = self._getRanges([xData_r, yData_r], + [logX, logY]) + if logX or logY: + xRangeLR = self._getRangesMinmax([xRangeL, xRangeR]) + else: + xRangeLR = self._getRangesMinmax([xRangeL, + xRangeR, + imgXRange]) + yRangeL = self._getRangesMinmax([yRangeL, imgYRange]) + self.assertSequenceEqual(dataRange.x, xRangeLR) + self.assertSequenceEqual(dataRange.y, yRangeL) + self.assertSequenceEqual(dataRange.yright, yRangeR) + + def testDataRangeImageNegativeScaleX(self): + """image data range, negative scale""" + + origin = (-10, 25) + scale = (-3., 8.) + image = numpy.arange(100.).reshape(20, 5) + + plot = Plot(backend='none') + plot.addImage(image, + origin=origin, scale=scale) + + xRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0] + xRange.sort() # negative scale! + yRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1] + + ranges = {(False, False): (xRange, yRange), + (True, False): (None, None), + (True, True): (None, None), + (False, True): (None, None)} + + for logX, logY in ((False, False), + (True, False), + (True, True), + (False, True), + (False, False)): + with self.subTest(logX=logX, logY=logY): + plot.setXAxisLogarithmic(logX) + plot.setYAxisLogarithmic(logY) + dataRange = plot.getDataRange() + xRange, yRange = ranges[logX, logY] + self.assertTrue(numpy.array_equal(dataRange.x, xRange), + msg='{0} != {1}'.format(dataRange.x, xRange)) + self.assertTrue(numpy.array_equal(dataRange.y, yRange), + msg='{0} != {1}'.format(dataRange.y, yRange)) + self.assertIsNone(dataRange.yright) + + def testDataRangeImageNegativeScaleY(self): + """image data range, negative scale""" + + origin = (-10, 25) + scale = (3., -8.) + image = numpy.arange(100.).reshape(20, 5) + + plot = Plot(backend='none') + plot.addImage(image, + origin=origin, scale=scale) + + xRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0] + yRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1] + yRange.sort() # negative scale! + + ranges = {(False, False): (xRange, yRange), + (True, False): (None, None), + (True, True): (None, None), + (False, True): (None, None)} + + for logX, logY in ((False, False), + (True, False), + (True, True), + (False, True), + (False, False)): + with self.subTest(logX=logX, logY=logY): + plot.setXAxisLogarithmic(logX) + plot.setYAxisLogarithmic(logY) + dataRange = plot.getDataRange() + xRange, yRange = ranges[logX, logY] + self.assertTrue(numpy.array_equal(dataRange.x, xRange), + msg='{0} != {1}'.format(dataRange.x, xRange)) + self.assertTrue(numpy.array_equal(dataRange.y, yRange), + msg='{0} != {1}'.format(dataRange.y, yRange)) + self.assertIsNone(dataRange.yright) + + def testDataRangeHiddenCurve(self): + """curves with a hidden curve""" + plot = Plot(backend='none') + plot.addCurve((0, 1), (0, 1), legend='shown') + plot.addCurve((0, 1, 2), (5, 5, 5), legend='hidden') + range1 = plot.getDataRange() + self.assertEqual(range1.x, (0, 2)) + self.assertEqual(range1.y, (0, 5)) + plot.hideCurve('hidden') + range2 = plot.getDataRange() + self.assertEqual(range2.x, (0, 1)) + self.assertEqual(range2.y, (0, 1)) + + +class TestPlotGetCurveImage(unittest.TestCase): + """Test of plot getCurve and getImage methods""" + + def testGetCurve(self): + """Plot.getCurve and Plot.getActiveCurve tests""" + + plot = Plot(backend='none') + + # No curve + curve = plot.getCurve() + self.assertIsNone(curve) # No curve + + plot.setActiveCurveHandling(True) + plot.addCurve(x=(0, 1), y=(0, 1), legend='curve 0') + plot.addCurve(x=(0, 1), y=(0, 1), legend='curve 1') + plot.addCurve(x=(0, 1), y=(0, 1), legend='curve 2') + plot.setActiveCurve('curve 0') + + # Active curve + active = plot.getActiveCurve() + self.assertEqual(active.getLegend(), 'curve 0') + curve = plot.getCurve() + self.assertEqual(curve.getLegend(), 'curve 0') + + # No active curve and curves + plot.setActiveCurveHandling(False) + active = plot.getActiveCurve() + self.assertIsNone(active) # No active curve + curve = plot.getCurve() + self.assertEqual(curve.getLegend(), 'curve 2') # Last added curve + + # Last curve hidden + plot.hideCurve('curve 2', True) + curve = plot.getCurve() + self.assertEqual(curve.getLegend(), 'curve 1') # Last added curve + + # All curves hidden + plot.hideCurve('curve 1', True) + plot.hideCurve('curve 0', True) + curve = plot.getCurve() + self.assertIsNone(curve) + + def testGetCurveOldApi(self): + """old API Plot.getCurve and Plot.getActiveCurve tests""" + + plot = Plot(backend='none') + + # No curve + curve = plot.getCurve() + self.assertIsNone(curve) # No curve + + plot.setActiveCurveHandling(True) + x = numpy.arange(10.).astype(numpy.float32) + y = x * x; + plot.addCurve(x=x, y=y, legend='curve 0', info=["whatever"]) + plot.addCurve(x=x, y=2*x, legend='curve 1', info="anything") + plot.setActiveCurve('curve 0') + + # Active curve (4 elements) + xOut, yOut, legend, info = plot.getActiveCurve()[:4] + self.assertEqual(legend, 'curve 0') + self.assertTrue(numpy.allclose(xOut, x), 'curve 0 wrong x data') + self.assertTrue(numpy.allclose(yOut, y), 'curve 0 wrong y data') + + # Active curve (5 elements) + xOut, yOut, legend, info, params = plot.getCurve("curve 1") + self.assertEqual(legend, 'curve 1') + self.assertEqual(info, 'anything') + self.assertTrue(numpy.allclose(xOut, x), 'curve 1 wrong x data') + self.assertTrue(numpy.allclose(yOut, 2*x), 'curve 1 wrong y data') + + def testGetImage(self): + """Plot.getImage and Plot.getActiveImage tests""" + + plot = Plot(backend='none') + + # No image + image = plot.getImage() + self.assertIsNone(image) + + plot.addImage(((0, 1), (2, 3)), legend='image 0', replace=False) + plot.addImage(((0, 1), (2, 3)), legend='image 1', replace=False) + + # Active image + active = plot.getActiveImage() + self.assertEqual(active.getLegend(), 'image 0') + image = plot.getImage() + self.assertEqual(image.getLegend(), 'image 0') + + # No active image + plot.addImage(((0, 1), (2, 3)), legend='image 2', replace=False) + plot.setActiveImage(None) + active = plot.getActiveImage() + self.assertIsNone(active) + image = plot.getImage() + self.assertEqual(image.getLegend(), 'image 2') + + # Active image + plot.setActiveImage('image 1') + active = plot.getActiveImage() + self.assertEqual(active.getLegend(), 'image 1') + image = plot.getImage() + self.assertEqual(image.getLegend(), 'image 1') + + def testGetImageOldApi(self): + """Plot.getImage and Plot.getActiveImage old API tests""" + + plot = Plot(backend='none') + + # No image + image = plot.getImage() + self.assertIsNone(image) + + image = numpy.arange(10).astype(numpy.float32) + image.shape = 5, 2 + + plot.addImage(image, legend='image 0', info=["Hi!"], replace=False) + + # Active image + data, legend, info, something, params = plot.getActiveImage() + self.assertEqual(legend, 'image 0') + self.assertEqual(info, ["Hi!"]) + self.assertTrue(numpy.allclose(data, image), "image 0 data not correct") + + def testGetAllImages(self): + """Plot.getAllImages test""" + + plot = Plot(backend='none') + + # No image + images = plot.getAllImages() + self.assertEqual(len(images), 0) + + # 2 images + data = numpy.arange(100).reshape(10, 10) + plot.addImage(data, legend='1', replace=False) + plot.addImage(data, origin=(10, 10), legend='2', replace=False) + images = plot.getAllImages(just_legend=True) + self.assertEqual(list(images), ['1', '2']) + images = plot.getAllImages(just_legend=False) + self.assertEqual(len(images), 2) + self.assertEqual(images[0].getLegend(), '1') + self.assertEqual(images[1].getLegend(), '2') + + +class TestPlotAddScatter(unittest.TestCase): + """Test of plot addScatter""" + + def testAddGetScatter(self): + + plot = Plot(backend='none') + + # No curve + scatter = plot._getItem(kind="scatter") + self.assertIsNone(scatter) # No curve + + plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 0') + plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 1') + plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 2') + plot._setActiveItem('scatter', 'scatter 0') + + # Active scatter + active = plot._getActiveItem(kind='scatter') + self.assertEqual(active.getLegend(), 'scatter 0') + + # check default values + self.assertAlmostEqual(active.getSymbolSize(), active._DEFAULT_SYMBOL_SIZE) + self.assertEqual(active.getSymbol(), "o") + self.assertAlmostEqual(active.getAlpha(), 1.0) + + # modify parameters + active.setSymbolSize(20.5) + active.setSymbol("d") + active.setAlpha(0.777) + + s0 = plot.getScatter("scatter 0") + + self.assertAlmostEqual(s0.getSymbolSize(), 20.5) + self.assertEqual(s0.getSymbol(), "d") + self.assertAlmostEqual(s0.getAlpha(), 0.777) + + scatter1 = plot._getItem(kind='scatter', legend='scatter 1') + self.assertEqual(scatter1.getLegend(), 'scatter 1') + + def testGetAllScatters(self): + """Plot.getAllImages test""" + + plot = Plot(backend='none') + + scatters = plot._getItems(kind='scatter') + self.assertEqual(len(scatters), 0) + + plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 0') + plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 1') + plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 2') + + scatters = plot._getItems(kind='scatter') + self.assertEqual(len(scatters), 3) + self.assertEqual(scatters[0].getLegend(), 'scatter 0') + self.assertEqual(scatters[2].getLegend(), 'scatter 2') + + scatters = plot._getItems(kind='scatter', just_legend=True) + self.assertEqual(len(scatters), 3) + self.assertEqual(list(scatters), ['scatter 0', 'scatter 1', 'scatter 2']) + + +class TestPlotHistogram(unittest.TestCase): + """Basic tests for histogram.""" + + def testEdges(self): + x = numpy.array([0, 1, 2]) + edgesRight = numpy.array([0, 1, 2, 3]) + edgesLeft = numpy.array([-1, 0, 1, 2]) + edgesCenter = numpy.array([-0.5, 0.5, 1.5, 2.5]) + + # testing x values for right + edges = _computeEdges(x, 'right') + numpy.testing.assert_array_equal(edges, edgesRight) + + edges = _computeEdges(x, 'center') + numpy.testing.assert_array_equal(edges, edgesCenter) + + edges = _computeEdges(x, 'left') + numpy.testing.assert_array_equal(edges, edgesLeft) + + def testHistogramCurve(self): + y = numpy.array([3, 2, 5]) + edges = numpy.array([0, 1, 2, 3]) + + xHisto, yHisto = _getHistogramCurve(y, edges) + numpy.testing.assert_array_equal( + yHisto, numpy.array([3, 3, 2, 2, 5, 5])) + + y = numpy.array([-3, 2, 5, 0]) + edges = numpy.array([-2, -1, 0, 1, 2]) + xHisto, yHisto = _getHistogramCurve(y, edges) + numpy.testing.assert_array_equal( + yHisto, numpy.array([-3, -3, 2, 2, 5, 5, 0, 0])) + + +def suite(): + test_suite = unittest.TestSuite() + for TestClass in (TestPlot, TestPlotRanges, TestPlotGetCurveImage, + TestPlotHistogram, TestPlotAddScatter): + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestClass)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot/test/testPlotInteraction.py b/silx/gui/plot/test/testPlotInteraction.py new file mode 100644 index 0000000..25f57a9 --- /dev/null +++ b/silx/gui/plot/test/testPlotInteraction.py @@ -0,0 +1,167 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Tests of plot interaction, through a PlotWidget""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "13/10/2016" + + +import unittest +from silx.gui import qt +from silx.gui.plot.test.testPlotWidget import _PlotWidgetTest + + +class _SignalDump(object): + """Callable object that store passed arguments in a list""" + + def __init__(self): + self._received = [] + + def __call__(self, *args): + self._received.append(args) + + @property + def received(self): + """Return a shallow copy of the list of received arguments""" + return list(self._received) + + +class TestSelectPolygon(_PlotWidgetTest): + """Test polygon selection interaction""" + + def _interactionModeChanged(self, source): + """Check that source received in event is the correct one""" + self.assertEqual(source, self) + + def _draw(self, polygon): + """Draw a polygon in the plot + + :param polygon: List of points (x, y) of the polygon (not closed) + """ + plot = self.plot.centralWidget() + + dump = _SignalDump() + self.plot.sigPlotSignal.connect(dump) + + for pos in polygon: + self.mouseMove(plot, pos=pos) + btn = qt.Qt.LeftButton if pos != polygon[-1] else qt.Qt.RightButton + self.mouseClick(plot, btn, pos=pos) + + self.plot.sigPlotSignal.disconnect(dump) + return [args[0] for args in dump.received] + + def test(self): + """Test draw polygons + events""" + self.plot.sigInteractiveModeChanged.connect( + self._interactionModeChanged) + + self.plot.setInteractiveMode( + 'draw', shape='polygon', label='test', source=self) + interaction = self.plot.getInteractiveMode() + + self.assertEqual(interaction['mode'], 'draw') + self.assertEqual(interaction['shape'], 'polygon') + + self.plot.sigInteractiveModeChanged.disconnect( + self._interactionModeChanged) + + plot = self.plot.centralWidget() + xCenter, yCenter = plot.width() // 2, plot.height() // 2 + offset = min(plot.width(), plot.height()) // 10 + + # Star polygon + star = [(xCenter, yCenter + offset), + (xCenter - offset, yCenter - offset), + (xCenter + offset, yCenter), + (xCenter - offset, yCenter), + (xCenter + offset, yCenter - offset)] + + # Draw while dumping signals + events = self._draw(star) + + # Test last event + drawEvents = [event for event in events + if event['event'].startswith('drawing')] + self.assertEqual(drawEvents[-1]['event'], 'drawingFinished') + self.assertEqual(len(drawEvents[-1]['points']), 6) + + # Large square + largeSquare = [(xCenter - offset, yCenter - offset), + (xCenter + offset, yCenter - offset), + (xCenter + offset, yCenter + offset), + (xCenter - offset, yCenter + offset)] + + # Draw while dumping signals + events = self._draw(largeSquare) + + # Test last event + drawEvents = [event for event in events + if event['event'].startswith('drawing')] + self.assertEqual(drawEvents[-1]['event'], 'drawingFinished') + self.assertEqual(len(drawEvents[-1]['points']), 5) + + # Rectangle too thin along X: Some points are ignored + thinRectX = [(xCenter, yCenter - offset), + (xCenter, yCenter + offset), + (xCenter + 1, yCenter + offset), + (xCenter + 1, yCenter - offset)] + + # Draw while dumping signals + events = self._draw(thinRectX) + + # Test last event + drawEvents = [event for event in events + if event['event'].startswith('drawing')] + self.assertEqual(drawEvents[-1]['event'], 'drawingFinished') + self.assertEqual(len(drawEvents[-1]['points']), 3) + + # Rectangle too thin along Y: Some points are ignored + thinRectY = [(xCenter - offset, yCenter), + (xCenter + offset, yCenter), + (xCenter + offset, yCenter + 1), + (xCenter - offset, yCenter + 1)] + + # Draw while dumping signals + events = self._draw(thinRectY) + + # Test last event + drawEvents = [event for event in events + if event['event'].startswith('drawing')] + self.assertEqual(drawEvents[-1]['event'], 'drawingFinished') + self.assertEqual(len(drawEvents[-1]['points']), 3) + + +def suite(): + test_suite = unittest.TestSuite() + for TestClass in (TestSelectPolygon,): + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestClass)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot/test/testPlotTools.py b/silx/gui/plot/test/testPlotTools.py new file mode 100644 index 0000000..1d5e148 --- /dev/null +++ b/silx/gui/plot/test/testPlotTools.py @@ -0,0 +1,203 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for PlotTools""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "05/12/2016" + + +import numpy +import unittest + +from silx.test.utils import ParametricTestCase, TestLogging +from silx.gui.test.utils import ( + qWaitForWindowExposedAndActivate, TestCaseQt, getQToolButtonFromAction) +from silx.gui import qt +from silx.gui.plot import Plot2D, PlotWindow, PlotTools + + +# Makes sure a QApplication exists +_qapp = qt.QApplication.instance() or qt.QApplication([]) + + +def _tearDownDocTest(docTest): + """Tear down to use for test from docstring. + + Checks that plot widget is displayed + """ + plot = docTest.globs['plot'] + qWaitForWindowExposedAndActivate(plot) + plot.setAttribute(qt.Qt.WA_DeleteOnClose) + plot.close() + del plot + +# Disable doctest because of +# "NameError: name 'numpy' is not defined" +# +# import doctest +# positionInfoTestSuite = doctest.DocTestSuite( +# PlotTools, tearDown=_tearDownDocTest, +# optionflags=doctest.ELLIPSIS) +# """Test suite of tests from PlotTools docstrings. +# +# Test PositionInfo and ProfileToolBar docstrings. +# """ + + +class TestPositionInfo(TestCaseQt): + """Tests for PositionInfo widget.""" + + def setUp(self): + super(TestPositionInfo, self).setUp() + self.plot = PlotWindow() + self.plot.show() + self.qWaitForWindowExposed(self.plot) + self.mouseMove(self.plot, pos=(1, 1)) + self.qapp.processEvents() + self.qWait(100) + + def tearDown(self): + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + del self.plot + + super(TestPositionInfo, self).tearDown() + + def _test(self, positionWidget, converterNames, **kwargs): + """General test of PositionInfo. + + - Add it to a toolbar and + - Move mouse around the center of the PlotWindow. + """ + toolBar = qt.QToolBar() + self.plot.addToolBar(qt.Qt.BottomToolBarArea, toolBar) + + toolBar.addWidget(positionWidget) + + converters = positionWidget.getConverters() + self.assertEqual(len(converters), len(converterNames)) + for index, name in enumerate(converterNames): + self.assertEqual(converters[index][0], name) + + with TestLogging(PlotTools.__name__, **kwargs): + # Move mouse to center + self.mouseMove(self.plot) + self.mouseMove(self.plot, pos=(1, 1)) + self.qapp.processEvents() + self.qWait(100) + + def testDefaultConverters(self): + """Test PositionInfo with default converters""" + positionWidget = PlotTools.PositionInfo(plot=self.plot) + self._test(positionWidget, ('X', 'Y')) + + def testCustomConverters(self): + """Test PositionInfo with custom converters""" + converters = [ + ('Coords', lambda x, y: (int(x), int(y))), + ('Radius', lambda x, y: numpy.sqrt(x * x + y * y)), + ('Angle', lambda x, y: numpy.degrees(numpy.arctan2(y, x))) + ] + positionWidget = PlotTools.PositionInfo(plot=self.plot, + converters=converters) + self._test(positionWidget, ('Coords', 'Radius', 'Angle')) + + def testFailingConverters(self): + """Test PositionInfo with failing custom converters""" + def raiseException(x, y): + raise RuntimeError() + + positionWidget = PlotTools.PositionInfo( + plot=self.plot, + converters=[('Exception', raiseException)]) + self._test(positionWidget, ['Exception'], error=2) + + +class TestPixelIntensitiesHisto(TestCaseQt, ParametricTestCase): + """Tests for ProfileToolBar widget.""" + + def setUp(self): + super(TestPixelIntensitiesHisto, self).setUp() + self.image = numpy.random.rand(100, 100) + self.plotImage = Plot2D() + self.plotImage.getIntensityHistogramAction().setVisible(True) + + def tearDown(self): + del self.plotImage + super(TestPixelIntensitiesHisto, self).tearDown() + + def testShowAndHide(self): + """Simple test that the plot is showing and hiding when activating the + action""" + self.plotImage.addImage(self.image, origin=(0, 0), legend='sino') + self.plotImage.show() + + histoAction = self.plotImage.getIntensityHistogramAction() + + # test the pixel intensity diagram is showing + button = getQToolButtonFromAction(histoAction) + self.assertIsNot(button, None) + self.mouseMove(button) + self.mouseClick(button, qt.Qt.LeftButton) + self.qapp.processEvents() + self.assertTrue(histoAction.getHistogramPlotWidget().isVisible()) + + # test the pixel intensity diagram is hiding + self.qapp.setActiveWindow(self.plotImage) + self.qapp.processEvents() + self.mouseMove(button) + self.mouseClick(button, qt.Qt.LeftButton) + self.qapp.processEvents() + self.assertFalse(histoAction.getHistogramPlotWidget().isVisible()) + + def testImageFormatInput(self): + """Test multiple type as image input""" + typesToTest = [numpy.uint8, numpy.int8, numpy.int16, numpy.int32, + numpy.float32, numpy.float64] + self.plotImage.addImage(self.image, origin=(0, 0), legend='sino') + self.plotImage.show() + button = getQToolButtonFromAction( + self.plotImage.getIntensityHistogramAction()) + self.mouseMove(button) + self.mouseClick(button, qt.Qt.LeftButton) + self.qapp.processEvents() + for typeToTest in typesToTest: + with self.subTest(typeToTest=typeToTest): + self.plotImage.addImage(self.image.astype(typeToTest), + origin=(0, 0), legend='sino') + + +def suite(): + test_suite = unittest.TestSuite() + # test_suite.addTest(positionInfoTestSuite) + for testClass in (TestPositionInfo, TestPixelIntensitiesHisto): + test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase( + testClass)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot/test/testPlotWidget.py b/silx/gui/plot/test/testPlotWidget.py new file mode 100644 index 0000000..2de18a8 --- /dev/null +++ b/silx/gui/plot/test/testPlotWidget.py @@ -0,0 +1,967 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for PlotWidget""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "05/12/2016" + + +import unittest + +import numpy + +from silx.test.utils import ParametricTestCase +from silx.gui.test.utils import TestCaseQt + +from silx.gui import qt +from silx.gui.plot import PlotWidget + + +SIZE = 1024 +"""Size of the test image""" + +DATA_2D = numpy.arange(SIZE ** 2).reshape(SIZE, SIZE) +"""Image data set""" + + +class _PlotWidgetTest(TestCaseQt): + """Base class for tests of PlotWidget, not a TestCase in itself. + + plot attribute is the PlotWidget created for the test. + """ + + def setUp(self): + super(_PlotWidgetTest, self).setUp() + self.plot = PlotWidget() + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + def tearDown(self): + self.qapp.processEvents() + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + del self.plot + super(_PlotWidgetTest, self).tearDown() + + +class TestPlotWidget(_PlotWidgetTest, ParametricTestCase): + """Basic tests for PlotWidget""" + + def testShow(self): + """Most basic test""" + pass + + def testSetTitleLabels(self): + """Set title and axes labels""" + + title, xlabel, ylabel = 'the title', 'x label', 'y label' + self.plot.setGraphTitle(title) + self.plot.setGraphXLabel(xlabel) + self.plot.setGraphYLabel(ylabel) + self.qapp.processEvents() + + self.assertEqual(self.plot.getGraphTitle(), title) + self.assertEqual(self.plot.getGraphXLabel(), xlabel) + self.assertEqual(self.plot.getGraphYLabel(), ylabel) + + def testChangeLimitsWithAspectRatio(self): + def checkLimits(expectedXLim=None, expectedYLim=None, + expectedRatio=None): + xlim = self.plot.getGraphXLimits() + ylim = self.plot.getGraphYLimits() + ratio = abs(xlim[1] - xlim[0]) / abs(ylim[1] - ylim[0]) + + if expectedXLim is not None: + self.assertEqual(expectedXLim, xlim) + + if expectedYLim is not None: + self.assertEqual(expectedYLim, ylim) + + if expectedRatio is not None: + self.assertTrue( + numpy.allclose(expectedRatio, ratio, atol=0.01)) + + self.plot.setKeepDataAspectRatio() + self.qapp.processEvents() + xlim = self.plot.getGraphXLimits() + ylim = self.plot.getGraphYLimits() + defaultRatio = abs(xlim[1] - xlim[0]) / abs(ylim[1] - ylim[0]) + + self.plot.setGraphXLimits(1., 10.) + checkLimits(expectedXLim=(1., 10.), expectedRatio=defaultRatio) + self.qapp.processEvents() + checkLimits(expectedXLim=(1., 10.), expectedRatio=defaultRatio) + + self.plot.setGraphYLimits(1., 10.) + checkLimits(expectedYLim=(1., 10.), expectedRatio=defaultRatio) + self.qapp.processEvents() + checkLimits(expectedYLim=(1., 10.), expectedRatio=defaultRatio) + + +class TestPlotImage(_PlotWidgetTest, ParametricTestCase): + """Basic tests for addImage""" + + def setUp(self): + super(TestPlotImage, self).setUp() + + self.plot.setGraphYLabel('Rows') + self.plot.setGraphXLabel('Columns') + + def testPlotColormapTemperature(self): + self.plot.setGraphTitle('Temp. Linear') + + colormap = {'name': 'temperature', 'normalization': 'linear', + 'autoscale': True, 'vmin': 0.0, 'vmax': 1.0} + self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap) + + def testPlotColormapGray(self): + self.plot.setKeepDataAspectRatio(False) + self.plot.setGraphTitle('Gray Linear') + + colormap = {'name': 'gray', 'normalization': 'linear', + 'autoscale': True, 'vmin': 0.0, 'vmax': 1.0} + self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap) + + def testPlotColormapTemperatureLog(self): + self.plot.setGraphTitle('Temp. Log') + + colormap = {'name': 'temperature', 'normalization': 'log', + 'autoscale': True, 'vmin': 0.0, 'vmax': 1.0} + self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap) + + def testPlotRgbRgba(self): + self.plot.setKeepDataAspectRatio(False) + self.plot.setGraphTitle('RGB + RGBA') + + rgb = numpy.array( + (((0, 0, 0), (128, 0, 0), (255, 0, 0)), + ((0, 128, 0), (0, 128, 128), (0, 128, 256))), + dtype=numpy.uint8) + + self.plot.addImage(rgb, legend="rgb", + origin=(0, 0), scale=(10, 10), + replace=False, resetzoom=False) + + rgba = numpy.array( + (((0, 0, 0, .5), (.5, 0, 0, 1), (1, 0, 0, .5)), + ((0, .5, 0, 1), (0, .5, .5, 1), (0, 1, 1, .5))), + dtype=numpy.float32) + + self.plot.addImage(rgba, legend="rgba", + origin=(5, 5), scale=(10, 10), + replace=False, resetzoom=False) + + self.plot.resetZoom() + + def testPlotColormapCustom(self): + self.plot.setKeepDataAspectRatio(False) + self.plot.setGraphTitle('Custom colormap') + + colormap = {'name': None, 'normalization': 'linear', + 'autoscale': True, 'vmin': 0.0, 'vmax': 1.0, + 'colors': ((0., 0., 0.), (1., 0., 0.), + (0., 1., 0.), (0., 0., 1.))} + self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap, + replace=False, resetzoom=False) + + colormap = {'name': None, 'normalization': 'linear', + 'autoscale': True, 'vmin': 0.0, 'vmax': 1.0, + 'colors': numpy.array( + ((0, 0, 0, 0), (0, 0, 0, 128), + (128, 128, 128, 128), (255, 255, 255, 255)), + dtype=numpy.uint8)} + self.plot.addImage(DATA_2D, legend="image 2", colormap=colormap, + origin=(DATA_2D.shape[0], 0), + replace=False, resetzoom=False) + self.plot.resetZoom() + + def testImageOriginScale(self): + """Test of image with different origin and scale""" + self.plot.setGraphTitle('origin and scale') + + tests = [ # (origin, scale) + ((10, 20), (1, 1)), + ((10, 20), (-1, -1)), + ((-10, 20), (2, 1)), + ((10, -20), (-1, -2)), + (100, 2), + (-100, (1, 1)), + ((10, 20), 2), + ] + + for origin, scale in tests: + with self.subTest(origin=origin, scale=scale): + self.plot.addImage(DATA_2D, origin=origin, scale=scale) + + try: + ox, oy = origin + except TypeError: + ox, oy = origin, origin + try: + sx, sy = scale + except TypeError: + sx, sy = scale, scale + xbounds = ox, ox + DATA_2D.shape[1] * sx + ybounds = oy, oy + DATA_2D.shape[0] * sy + + # Check limits without aspect ratio + xmin, xmax = self.plot.getGraphXLimits() + ymin, ymax = self.plot.getGraphYLimits() + self.assertEqual(xmin, min(xbounds)) + self.assertEqual(xmax, max(xbounds)) + self.assertEqual(ymin, min(ybounds)) + self.assertEqual(ymax, max(ybounds)) + + # Check limits with aspect ratio + self.plot.setKeepDataAspectRatio(True) + xmin, xmax = self.plot.getGraphXLimits() + ymin, ymax = self.plot.getGraphYLimits() + self.assertTrue(xmin <= min(xbounds)) + self.assertTrue(xmax >= max(xbounds)) + self.assertTrue(ymin <= min(ybounds)) + self.assertTrue(ymax >= max(ybounds)) + + self.plot.setKeepDataAspectRatio(False) # Reset aspect ratio + self.plot.clear() + self.plot.resetZoom() + + +class TestPlotCurve(_PlotWidgetTest): + """Basic tests for addCurve.""" + + # Test data sets + xData = numpy.arange(1000) + yData = -500 + 100 * numpy.sin(xData) + xData2 = xData + 1000 + yData2 = xData - 1000 + 200 * numpy.random.random(1000) + + def setUp(self): + super(TestPlotCurve, self).setUp() + self.plot.setGraphTitle('Curve') + self.plot.setGraphYLabel('Rows') + self.plot.setGraphXLabel('Columns') + + self.plot.setActiveCurveHandling(False) + + def testPlotCurveColorFloat(self): + color = numpy.array(numpy.random.random(3 * 1000), + dtype=numpy.float32).reshape(1000, 3) + + self.plot.addCurve(self.xData, self.yData, + legend="curve 1", + replace=False, resetzoom=False, + color=color, + linestyle="", symbol="s") + self.plot.addCurve(self.xData2, self.yData2, + legend="curve 2", + replace=False, resetzoom=False, + color='green', linestyle="-", symbol='o') + self.plot.resetZoom() + + def testPlotCurveColorByte(self): + color = numpy.array(255 * numpy.random.random(3 * 1000), + dtype=numpy.uint8).reshape(1000, 3) + + self.plot.addCurve(self.xData, self.yData, + legend="curve 1", + replace=False, resetzoom=False, + color=color, + linestyle="", symbol="s") + self.plot.addCurve(self.xData2, self.yData2, + legend="curve 2", + replace=False, resetzoom=False, + color='green', linestyle="-", symbol='o') + self.plot.resetZoom() + + def testPlotCurveColors(self): + color = numpy.array(numpy.random.random(3 * 1000), + dtype=numpy.float32).reshape(1000, 3) + + self.plot.addCurve(self.xData, self.yData, + legend="curve 2", + replace=False, resetzoom=False, + color=color, linestyle="-", symbol='o') + self.plot.resetZoom() + + +class TestPlotMarker(_PlotWidgetTest): + """Basic tests for add*Marker""" + + def setUp(self): + super(TestPlotMarker, self).setUp() + self.plot.setGraphYLabel('Rows') + self.plot.setGraphXLabel('Columns') + + self.plot.setXAxisAutoScale(False) + self.plot.setYAxisAutoScale(False) + self.plot.setKeepDataAspectRatio(False) + self.plot.setLimits(0., 100., -100., 100.) + + def testPlotMarkerX(self): + self.plot.setGraphTitle('Markers X') + + markers = [ + (10., 'blue', False, False), + (20., 'red', False, False), + (40., 'green', True, False), + (60., 'gray', True, True), + (80., 'black', False, True), + ] + + for x, color, select, drag in markers: + name = str(x) + if select: + name += " sel." + if drag: + name += " drag" + self.plot.addXMarker(x, name, name, color, select, drag) + self.plot.resetZoom() + + def testPlotMarkerY(self): + self.plot.setGraphTitle('Markers Y') + + markers = [ + (-50., 'blue', False, False), + (-30., 'red', False, False), + (0., 'green', True, False), + (10., 'gray', True, True), + (80., 'black', False, True), + ] + + for y, color, select, drag in markers: + name = str(y) + if select: + name += " sel." + if drag: + name += " drag" + self.plot.addYMarker(y, name, name, color, select, drag) + self.plot.resetZoom() + + def testPlotMarkerPt(self): + self.plot.setGraphTitle('Markers Pt') + + markers = [ + (10., -50., 'blue', False, False), + (40., -30., 'red', False, False), + (50., 0., 'green', True, False), + (50., 20., 'gray', True, True), + (70., 50., 'black', False, True), + ] + for x, y, color, select, drag in markers: + name = "{0},{1}".format(x, y) + if select: + name += " sel." + if drag: + name += " drag" + self.plot.addMarker(x, y, name, name, color, select, drag) + + self.plot.resetZoom() + + def testPlotMarkerWithoutLegend(self): + self.plot.setGraphTitle('Markers without legend') + self.plot.setYAxisInverted(True) + + # Markers without legend + self.plot.addMarker(10, 10) + self.plot.addMarker(10, 20) + self.plot.addMarker(40, 50, text='test', symbol=None) + self.plot.addMarker(40, 50, text='test', symbol='+') + self.plot.addXMarker(25) + self.plot.addXMarker(35) + self.plot.addXMarker(45, text='test') + self.plot.addYMarker(55) + self.plot.addYMarker(65) + self.plot.addYMarker(75, text='test') + + self.plot.resetZoom() + + +# TestPlotItem ################################################################ + +class TestPlotItem(_PlotWidgetTest): + """Basic tests for addItem.""" + + # Polygon coordinates and color + polygons = [ # legend, x coords, y coords, color + ('triangle', numpy.array((10, 30, 50)), + numpy.array((55, 70, 55)), 'red'), + ('square', numpy.array((10, 10, 50, 50)), + numpy.array((10, 50, 50, 10)), 'green'), + ('star', numpy.array((60, 70, 80, 60, 80)), + numpy.array((25, 50, 25, 40, 40)), 'blue'), + ] + + # Rectangle coordinantes and color + rectangles = [ # legend, x coords, y coords, color + ('square 1', numpy.array((1., 10.)), + numpy.array((1., 10.)), 'red'), + ('square 2', numpy.array((10., 20.)), + numpy.array((10., 20.)), 'green'), + ('square 3', numpy.array((20., 30.)), + numpy.array((20., 30.)), 'blue'), + ('rect 1', numpy.array((1., 30.)), + numpy.array((35., 40.)), 'black'), + ('line h', numpy.array((1., 30.)), + numpy.array((45., 45.)), 'darkRed'), + ] + + def setUp(self): + super(TestPlotItem, self).setUp() + + self.plot.setGraphYLabel('Rows') + self.plot.setGraphXLabel('Columns') + self.plot.setXAxisAutoScale(False) + self.plot.setYAxisAutoScale(False) + self.plot.setKeepDataAspectRatio(False) + self.plot.setLimits(0., 100., -100., 100.) + + def testPlotItemPolygonFill(self): + self.plot.setGraphTitle('Item Fill') + + for legend, xList, yList, color in self.polygons: + self.plot.addItem(xList, yList, legend=legend, + replace=False, + shape="polygon", fill=True, color=color) + self.plot.resetZoom() + + def testPlotItemPolygonNoFill(self): + self.plot.setGraphTitle('Item No Fill') + + for legend, xList, yList, color in self.polygons: + self.plot.addItem(xList, yList, legend=legend, + replace=False, + shape="polygon", fill=False, color=color) + self.plot.resetZoom() + + def testPlotItemRectangleFill(self): + self.plot.setGraphTitle('Rectangle Fill') + + for legend, xList, yList, color in self.rectangles: + self.plot.addItem(xList, yList, legend=legend, + replace=False, + shape="rectangle", fill=True, color=color) + self.plot.resetZoom() + + def testPlotItemRectangleNoFill(self): + self.plot.setGraphTitle('Rectangle No Fill') + + for legend, xList, yList, color in self.rectangles: + self.plot.addItem(xList, yList, legend=legend, + replace=False, + shape="rectangle", fill=False, color=color) + self.plot.resetZoom() + + +class TestPlotActiveCurveImage(_PlotWidgetTest): + """Basic tests for active image handling""" + + def testActiveCurveAndLabels(self): + # Active curve handling off, no label change + self.plot.setActiveCurveHandling(False) + self.plot.setGraphXLabel('XLabel') + self.plot.setGraphYLabel('YLabel') + self.plot.addCurve((1, 2), (1, 2)) + self.assertEqual(self.plot.getGraphXLabel(), 'XLabel') + self.assertEqual(self.plot.getGraphYLabel(), 'YLabel') + + self.plot.addCurve((1, 2), (2, 3), xlabel='x1', ylabel='y1') + self.assertEqual(self.plot.getGraphXLabel(), 'XLabel') + self.assertEqual(self.plot.getGraphYLabel(), 'YLabel') + + self.plot.clear() + self.assertEqual(self.plot.getGraphXLabel(), 'XLabel') + self.assertEqual(self.plot.getGraphYLabel(), 'YLabel') + + # Active curve handling on, label changes + self.plot.setActiveCurveHandling(True) + self.plot.setGraphXLabel('XLabel') + self.plot.setGraphYLabel('YLabel') + + # labels changed as active curve + self.plot.addCurve((1, 2), (1, 2), legend='1', + xlabel='x1', ylabel='y1') + self.assertEqual(self.plot.getGraphXLabel(), 'x1') + self.assertEqual(self.plot.getGraphYLabel(), 'y1') + + # labels not changed as not active curve + self.plot.addCurve((1, 2), (2, 3), legend='2') + self.assertEqual(self.plot.getGraphXLabel(), 'x1') + self.assertEqual(self.plot.getGraphYLabel(), 'y1') + + # labels changed + self.plot.setActiveCurve('2') + self.assertEqual(self.plot.getGraphXLabel(), 'XLabel') + self.assertEqual(self.plot.getGraphYLabel(), 'YLabel') + + self.plot.setActiveCurve('1') + self.assertEqual(self.plot.getGraphXLabel(), 'x1') + self.assertEqual(self.plot.getGraphYLabel(), 'y1') + + self.plot.clear() + self.assertEqual(self.plot.getGraphXLabel(), 'XLabel') + self.assertEqual(self.plot.getGraphYLabel(), 'YLabel') + + def testActiveImageAndLabels(self): + # Active image handling always on, no API for toggling it + self.plot.setGraphXLabel('XLabel') + self.plot.setGraphYLabel('YLabel') + + # labels changed as active curve + self.plot.addImage(numpy.arange(100).reshape(10, 10), replace=False, + legend='1', xlabel='x1', ylabel='y1') + self.assertEqual(self.plot.getGraphXLabel(), 'x1') + self.assertEqual(self.plot.getGraphYLabel(), 'y1') + + # labels not changed as not active curve + self.plot.addImage(numpy.arange(100).reshape(10, 10), replace=False, + legend='2') + self.assertEqual(self.plot.getGraphXLabel(), 'x1') + self.assertEqual(self.plot.getGraphYLabel(), 'y1') + + # labels changed + self.plot.setActiveImage('2') + self.assertEqual(self.plot.getGraphXLabel(), 'XLabel') + self.assertEqual(self.plot.getGraphYLabel(), 'YLabel') + + self.plot.setActiveImage('1') + self.assertEqual(self.plot.getGraphXLabel(), 'x1') + self.assertEqual(self.plot.getGraphYLabel(), 'y1') + + self.plot.clear() + self.assertEqual(self.plot.getGraphXLabel(), 'XLabel') + self.assertEqual(self.plot.getGraphYLabel(), 'YLabel') + + +############################################################################## +# Log +############################################################################## + +class TestPlotEmptyLog(_PlotWidgetTest): + """Basic tests for log plot""" + def testEmptyPlotTitleLabelsLog(self): + self.plot.setGraphTitle('Empty Log Log') + self.plot.setGraphXLabel('X') + self.plot.setGraphYLabel('Y') + self.plot.setXAxisLogarithmic(True) + self.plot.setYAxisLogarithmic(True) + self.plot.resetZoom() + + +class TestPlotCurveLog(_PlotWidgetTest, ParametricTestCase): + """Basic tests for addCurve with log scale axes""" + + # Test data + xData = numpy.arange(1000) + 1 + yData = xData ** 2 + + def _setLabels(self): + self.plot.setGraphXLabel('X') + self.plot.setGraphYLabel('X * X') + + def testPlotCurveLogX(self): + self._setLabels() + self.plot.setXAxisLogarithmic(True) + self.plot.setGraphTitle('Curve X: Log Y: Linear') + + self.plot.addCurve(self.xData, self.yData, + legend="curve", + replace=False, resetzoom=True, + color='green', linestyle="-", symbol='o') + + def testPlotCurveLogY(self): + self._setLabels() + self.plot.setYAxisLogarithmic(True) + + self.plot.setGraphTitle('Curve X: Linear Y: Log') + + self.plot.addCurve(self.xData, self.yData, + legend="curve", + replace=False, resetzoom=True, + color='green', linestyle="-", symbol='o') + + def testPlotCurveLogXY(self): + self._setLabels() + self.plot.setXAxisLogarithmic(True) + self.plot.setYAxisLogarithmic(True) + + self.plot.setGraphTitle('Curve X: Log Y: Log') + + self.plot.addCurve(self.xData, self.yData, + legend="curve", + replace=False, resetzoom=True, + color='green', linestyle="-", symbol='o') + + def testPlotCurveErrorLogXY(self): + self.plot.setXAxisLogarithmic(True) + self.plot.setYAxisLogarithmic(True) + + # Every second error leads to negative number + errors = numpy.ones_like(self.xData) + errors[::2] = self.xData[::2] + 1 + + tests = [ # name, xerror, yerror + ('xerror=3', 3, None), + ('xerror=N array', errors, None), + ('xerror=Nx1 array', errors.reshape(len(errors), 1), None), + ('xerror=2xN array', numpy.array((errors, errors)), None), + ('yerror=6', None, 6), + ('yerror=N array', None, errors ** 2), + ('yerror=Nx1 array', None, (errors ** 2).reshape(len(errors), 1)), + ('yerror=2xN array', None, numpy.array((errors, errors)) ** 2), + ] + + for name, xError, yError in tests: + with self.subTest(name): + self.plot.setGraphTitle(name) + self.plot.addCurve(self.xData, self.yData, + legend=name, + xerror=xError, yerror=yError, + replace=False, resetzoom=True, + color='green', linestyle="-", symbol='o') + + self.qapp.processEvents() + + self.plot.clear() + self.plot.resetZoom() + self.qapp.processEvents() + + def testPlotCurveToggleLog(self): + """Add a curve with negative data and toggle log axis""" + arange = numpy.arange(1000) + 1 + tests = [ # name, xData, yData + ('x>0, some negative y', arange, arange - 500), + ('x>0, y<0', arange, -arange), + ('some negative x, y>0', arange - 500, arange), + ('x<0, y>0', -arange, arange), + ('some negative x and y', arange - 500, arange - 500), + ('x<0, y<0', -arange, -arange), + ] + + for name, xData, yData in tests: + with self.subTest(name): + self.plot.addCurve(xData, yData, resetzoom=True) + self.qapp.processEvents() + + # no log axis + xLim = self.plot.getGraphXLimits() + self.assertEqual(xLim, (min(xData), max(xData))) + yLim = self.plot.getGraphYLimits() + self.assertEqual(yLim, (min(yData), max(yData))) + + # x axis log + self.plot.setXAxisLogarithmic(True) + self.qapp.processEvents() + + xLim = self.plot.getGraphXLimits() + yLim = self.plot.getGraphYLimits() + positives = xData > 0 + if numpy.any(positives): + self.assertTrue(numpy.allclose( + xLim, (min(xData[positives]), max(xData[positives])))) + self.assertEqual( + yLim, (min(yData[positives]), max(yData[positives]))) + else: # No positive x in the curve + self.assertEqual(xLim, (1., 100.)) + self.assertEqual(yLim, (1., 100.)) + + # x axis and y axis log + self.plot.setYAxisLogarithmic(True) + self.qapp.processEvents() + + xLim = self.plot.getGraphXLimits() + yLim = self.plot.getGraphYLimits() + positives = numpy.logical_and(xData > 0, yData > 0) + if numpy.any(positives): + self.assertTrue(numpy.allclose( + xLim, (min(xData[positives]), max(xData[positives])))) + self.assertTrue(numpy.allclose( + yLim, (min(yData[positives]), max(yData[positives])))) + else: # No positive x and y in the curve + self.assertEqual(xLim, (1., 100.)) + self.assertEqual(yLim, (1., 100.)) + + # y axis log + self.plot.setXAxisLogarithmic(False) + self.qapp.processEvents() + + xLim = self.plot.getGraphXLimits() + yLim = self.plot.getGraphYLimits() + positives = yData > 0 + if numpy.any(positives): + self.assertEqual( + xLim, (min(xData[positives]), max(xData[positives]))) + self.assertTrue(numpy.allclose( + yLim, (min(yData[positives]), max(yData[positives])))) + else: # No positive y in the curve + self.assertEqual(xLim, (1., 100.)) + self.assertEqual(yLim, (1., 100.)) + + # no log axis + self.plot.setYAxisLogarithmic(False) + self.qapp.processEvents() + + xLim = self.plot.getGraphXLimits() + self.assertEqual(xLim, (min(xData), max(xData))) + yLim = self.plot.getGraphYLimits() + self.assertEqual(yLim, (min(yData), max(yData))) + + self.plot.clear() + self.plot.resetZoom() + self.qapp.processEvents() + + +class TestPlotImageLog(_PlotWidgetTest): + """Basic tests for addImage with log scale axes.""" + + def setUp(self): + super(TestPlotImageLog, self).setUp() + + self.plot.setGraphXLabel('Columns') + self.plot.setGraphYLabel('Rows') + + def testPlotColormapGrayLogX(self): + self.plot.setXAxisLogarithmic(True) + self.plot.setGraphTitle('CMap X: Log Y: Linear') + + colormap = {'name': 'gray', 'normalization': 'linear', + 'autoscale': True, 'vmin': 0.0, 'vmax': 1.0} + self.plot.addImage(DATA_2D, legend="image 1", + origin=(1., 1.), scale=(1., 1.), + replace=False, resetzoom=False, colormap=colormap) + self.plot.resetZoom() + + def testPlotColormapGrayLogY(self): + self.plot.setYAxisLogarithmic(True) + self.plot.setGraphTitle('CMap X: Linear Y: Log') + + colormap = {'name': 'gray', 'normalization': 'linear', + 'autoscale': True, 'vmin': 0.0, 'vmax': 1.0} + self.plot.addImage(DATA_2D, legend="image 1", + origin=(1., 1.), scale=(1., 1.), + replace=False, resetzoom=False, colormap=colormap) + self.plot.resetZoom() + + def testPlotColormapGrayLogXY(self): + self.plot.setXAxisLogarithmic(True) + self.plot.setYAxisLogarithmic(True) + self.plot.setGraphTitle('CMap X: Log Y: Log') + + colormap = {'name': 'gray', 'normalization': 'linear', + 'autoscale': True, 'vmin': 0.0, 'vmax': 1.0} + self.plot.addImage(DATA_2D, legend="image 1", + origin=(1., 1.), scale=(1., 1.), + replace=False, resetzoom=False, colormap=colormap) + self.plot.resetZoom() + + def testPlotRgbRgbaLogXY(self): + self.plot.setXAxisLogarithmic(True) + self.plot.setYAxisLogarithmic(True) + self.plot.setGraphTitle('RGB + RGBA X: Log Y: Log') + + rgb = numpy.array( + (((0, 0, 0), (128, 0, 0), (255, 0, 0)), + ((0, 128, 0), (0, 128, 128), (0, 128, 256))), + dtype=numpy.uint8) + + self.plot.addImage(rgb, legend="rgb", + origin=(1, 1), scale=(10, 10), + replace=False, resetzoom=False) + + rgba = numpy.array( + (((0, 0, 0, .5), (.5, 0, 0, 1), (1, 0, 0, .5)), + ((0, .5, 0, 1), (0, .5, .5, 1), (0, 1, 1, .5))), + dtype=numpy.float32) + + self.plot.addImage(rgba, legend="rgba", + origin=(5., 5.), scale=(10., 10.), + replace=False, resetzoom=False) + self.plot.resetZoom() + + +class TestPlotMarkerLog(_PlotWidgetTest): + """Basic tests for markers on log scales""" + + # Test marker parameters + markers = [ # x, y, color, selectable, draggable + (10., 10., 'blue', False, False), + (20., 20., 'red', False, False), + (40., 100., 'green', True, False), + (40., 500., 'gray', True, True), + (60., 800., 'black', False, True), + ] + + def setUp(self): + super(TestPlotMarkerLog, self).setUp() + + self.plot.setGraphYLabel('Rows') + self.plot.setGraphXLabel('Columns') + self.plot.setXAxisAutoScale(False) + self.plot.setYAxisAutoScale(False) + self.plot.setKeepDataAspectRatio(False) + self.plot.setLimits(1., 100., 1., 1000.) + self.plot.setXAxisLogarithmic(True) + self.plot.setYAxisLogarithmic(True) + + def testPlotMarkerXLog(self): + self.plot.setGraphTitle('Markers X, Log axes') + + for x, _, color, select, drag in self.markers: + name = str(x) + if select: + name += " sel." + if drag: + name += " drag" + self.plot.addXMarker(x, name, name, color, select, drag) + self.plot.resetZoom() + + def testPlotMarkerYLog(self): + self.plot.setGraphTitle('Markers Y, Log axes') + + for _, y, color, select, drag in self.markers: + name = str(y) + if select: + name += " sel." + if drag: + name += " drag" + self.plot.addYMarker(y, name, name, color, select, drag) + self.plot.resetZoom() + + def testPlotMarkerPtLog(self): + self.plot.setGraphTitle('Markers Pt, Log axes') + + for x, y, color, select, drag in self.markers: + name = "{0},{1}".format(x, y) + if select: + name += " sel." + if drag: + name += " drag" + self.plot.addMarker(x, y, name, name, color, select, drag) + self.plot.resetZoom() + + +class TestPlotItemLog(_PlotWidgetTest): + """Basic tests for items with log scale axes""" + + # Polygon coordinates and color + polygons = [ # legend, x coords, y coords, color + ('triangle', numpy.array((10, 30, 50)), + numpy.array((55, 70, 55)), 'red'), + ('square', numpy.array((10, 10, 50, 50)), + numpy.array((10, 50, 50, 10)), 'green'), + ('star', numpy.array((60, 70, 80, 60, 80)), + numpy.array((25, 50, 25, 40, 40)), 'blue'), + ] + + # Rectangle coordinantes and color + rectangles = [ # legend, x coords, y coords, color + ('square 1', numpy.array((1., 10.)), + numpy.array((1., 10.)), 'red'), + ('square 2', numpy.array((10., 20.)), + numpy.array((10., 20.)), 'green'), + ('square 3', numpy.array((20., 30.)), + numpy.array((20., 30.)), 'blue'), + ('rect 1', numpy.array((1., 30.)), + numpy.array((35., 40.)), 'black'), + ('line h', numpy.array((1., 30.)), + numpy.array((45., 45.)), 'darkRed'), + ] + + def setUp(self): + super(TestPlotItemLog, self).setUp() + + self.plot.setGraphYLabel('Rows') + self.plot.setGraphXLabel('Columns') + self.plot.setXAxisAutoScale(False) + self.plot.setYAxisAutoScale(False) + self.plot.setKeepDataAspectRatio(False) + self.plot.setLimits(1., 100., 1., 100.) + self.plot.setXAxisLogarithmic(True) + self.plot.setYAxisLogarithmic(True) + + def testPlotItemPolygonLogFill(self): + self.plot.setGraphTitle('Item Fill Log') + + for legend, xList, yList, color in self.polygons: + self.plot.addItem(xList, yList, legend=legend, + replace=False, + shape="polygon", fill=True, color=color) + self.plot.resetZoom() + + def testPlotItemPolygonLogNoFill(self): + self.plot.setGraphTitle('Item No Fill Log') + + for legend, xList, yList, color in self.polygons: + self.plot.addItem(xList, yList, legend=legend, + replace=False, + shape="polygon", fill=False, color=color) + self.plot.resetZoom() + + def testPlotItemRectangleLogFill(self): + self.plot.setGraphTitle('Rectangle Fill Log') + + for legend, xList, yList, color in self.rectangles: + self.plot.addItem(xList, yList, legend=legend, + replace=False, + shape="rectangle", fill=True, color=color) + self.plot.resetZoom() + + def testPlotItemRectangleLogNoFill(self): + self.plot.setGraphTitle('Rectangle No Fill Log') + + for legend, xList, yList, color in self.rectangles: + self.plot.addItem(xList, yList, legend=legend, + replace=False, + shape="rectangle", fill=False, color=color) + self.plot.resetZoom() + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestPlotWidget)) + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestPlotImage)) + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestPlotCurve)) + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestPlotMarker)) + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestPlotItem)) + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestPlotEmptyLog)) + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestPlotCurveLog)) + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestPlotImageLog)) + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestPlotMarkerLog)) + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestPlotItemLog)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot/test/testPlotWindow.py b/silx/gui/plot/test/testPlotWindow.py new file mode 100644 index 0000000..5afd53a --- /dev/null +++ b/silx/gui/plot/test/testPlotWindow.py @@ -0,0 +1,138 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for PlotWindow""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "05/12/2016" + + +import doctest +import unittest + +from silx.gui.test.utils import TestCaseQt, getQToolButtonFromAction + +from silx.gui import qt +from silx.gui.plot import PlotWindow + + +# Test of the docstrings # + +# Makes sure a QApplication exists +_qapp = qt.QApplication.instance() or qt.QApplication([]) + + +def _tearDownQt(docTest): + """Tear down to use for test from docstring. + + Checks that plt widget is displayed + """ + _qapp.processEvents() + for obj in docTest.globs.values(): + if isinstance(obj, PlotWindow): + # Commented out as it takes too long + # qWaitForWindowExposedAndActivate(obj) + obj.setAttribute(qt.Qt.WA_DeleteOnClose) + obj.close() + del obj + + +plotWindowDocTestSuite = doctest.DocTestSuite('silx.gui.plot.PlotWindow', + tearDown=_tearDownQt) +"""Test suite of tests from the module's docstrings.""" + + +class TestPlotWindow(TestCaseQt): + """Base class for tests of PlotWindow.""" + + def setUp(self): + super(TestPlotWindow, self).setUp() + self.plot = PlotWindow() + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + def tearDown(self): + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + del self.plot + super(TestPlotWindow, self).tearDown() + + def testActions(self): + """Test the actions QToolButtons""" + self.plot.setLimits(1, 100, 1, 100) + + checkList = [ # QAction, Plot state getter + (self.plot.xAxisAutoScaleAction, self.plot.isXAxisAutoScale), + (self.plot.yAxisAutoScaleAction, self.plot.isYAxisAutoScale), + (self.plot.xAxisLogarithmicAction, self.plot.isXAxisLogarithmic), + (self.plot.yAxisLogarithmicAction, self.plot.isYAxisLogarithmic), + (self.plot.gridAction, self.plot.getGraphGrid), + ] + + for action, getter in checkList: + self.mouseMove(self.plot) + initialState = getter() + toolButton = getQToolButtonFromAction(action) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + self.assertNotEqual(getter(), initialState, + msg='"%s" state not changed' % action.text()) + + self.mouseClick(toolButton, qt.Qt.LeftButton) + self.assertEqual(getter(), initialState, + msg='"%s" state not changed' % action.text()) + + # Trigger a zoom reset + self.mouseMove(self.plot) + resetZoomAction = self.plot.resetZoomAction + toolButton = getQToolButtonFromAction(resetZoomAction) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + + def testToolAspectRatio(self): + self.plot.toolBar() + self.plot.keepDataAspectRatioButton.keepDataAspectRatio() + self.assertTrue(self.plot.isKeepDataAspectRatio()) + self.plot.keepDataAspectRatioButton.dontKeepDataAspectRatio() + self.assertFalse(self.plot.isKeepDataAspectRatio()) + + def testToolYAxisOrigin(self): + self.plot.toolBar() + self.plot.yAxisInvertedButton.setYAxisUpward() + self.assertFalse(self.plot.isYAxisInverted()) + self.plot.yAxisInvertedButton.setYAxisDownward() + self.assertTrue(self.plot.isYAxisInverted()) + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTest(plotWindowDocTestSuite) + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestPlotWindow)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot/test/testProfile.py b/silx/gui/plot/test/testProfile.py new file mode 100644 index 0000000..43d3329 --- /dev/null +++ b/silx/gui/plot/test/testProfile.py @@ -0,0 +1,183 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for Profile""" + +__authors__ = ["T. Vincent", "P. Knobel"] +__license__ = "MIT" +__date__ = "23/02/2017" + +import numpy +import unittest + +from silx.test.utils import ParametricTestCase +from silx.gui.test.utils import ( + TestCaseQt, getQToolButtonFromAction) +from silx.gui import qt +from silx.gui.plot import PlotWindow, Plot1D, Plot2D, Profile +from silx.gui.plot.StackView import StackView + + +# Makes sure a QApplication exists +_qapp = qt.QApplication.instance() or qt.QApplication([]) + + +class TestProfileToolBar(TestCaseQt, ParametricTestCase): + """Tests for ProfileToolBar widget.""" + + def setUp(self): + super(TestProfileToolBar, self).setUp() + profileWindow = PlotWindow() + self.plot = PlotWindow() + self.toolBar = Profile.ProfileToolBar( + plot=self.plot, profileWindow=profileWindow) + self.plot.addToolBar(self.toolBar) + + self.plot.show() + self.qWaitForWindowExposed(self.plot) + profileWindow.show() + self.qWaitForWindowExposed(profileWindow) + + self.mouseMove(self.plot) # Move to center + self.qapp.processEvents() + + def tearDown(self): + self.qapp.processEvents() + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + del self.plot + del self.toolBar + + super(TestProfileToolBar, self).tearDown() + + def testAlignedProfile(self): + """Test horizontal and vertical profile, without and with image""" + # Use Plot backend widget to submit mouse events + widget = self.plot.getWidgetHandle() + + # 2 positions to use for mouse events + pos1 = widget.width() * 0.4, widget.height() * 0.4 + pos2 = widget.width() * 0.6, widget.height() * 0.6 + + for action in (self.toolBar.hLineAction, self.toolBar.vLineAction): + with self.subTest(mode=action.text()): + # Trigger tool button for mode + toolButton = getQToolButtonFromAction(action) + self.assertIsNot(toolButton, None) + self.mouseMove(toolButton) + self.mouseClick(toolButton, qt.Qt.LeftButton) + + # Without image + self.mouseMove(widget, pos=pos1) + self.mouseClick(widget, qt.Qt.LeftButton, pos=pos1) + + # with image + self.plot.addImage(numpy.arange(100 * 100).reshape(100, -1)) + self.mousePress(widget, qt.Qt.LeftButton, pos=pos1) + self.mouseMove(widget, pos=pos2) + self.mouseRelease(widget, qt.Qt.LeftButton, pos=pos2) + + self.mouseMove(widget) + self.mouseClick(widget, qt.Qt.LeftButton) + + def testDiagonalProfile(self): + """Test diagonal profile, without and with image""" + # Use Plot backend widget to submit mouse events + widget = self.plot.getWidgetHandle() + + # 2 positions to use for mouse events + pos1 = widget.width() * 0.4, widget.height() * 0.4 + pos2 = widget.width() * 0.6, widget.height() * 0.6 + + # Trigger tool button for diagonal profile mode + toolButton = getQToolButtonFromAction(self.toolBar.lineAction) + self.assertIsNot(toolButton, None) + self.mouseMove(toolButton) + self.mouseClick(toolButton, qt.Qt.LeftButton) + + for image in (False, True): + with self.subTest(image=image): + if image: + self.plot.addImage(numpy.arange(100 * 100).reshape(100, -1)) + + self.mouseMove(widget, pos=pos1) + self.mousePress(widget, qt.Qt.LeftButton, pos=pos1) + self.mouseMove(widget, pos=pos2) + self.mouseRelease(widget, qt.Qt.LeftButton, pos=pos2) + + self.plot.clear() + + +class TestGetProfilePlot(TestCaseQt): + + def testProfile1D(self): + plot = Plot2D() + plot.show() + self.qWaitForWindowExposed(plot) + plot.addImage([[0, 1], [2, 3]]) + self.assertIsInstance(plot.getProfileToolbar().getProfileMainWindow(), + qt.QMainWindow) + self.assertIsInstance(plot.getProfilePlot(), + Plot1D) + plot.setAttribute(qt.Qt.WA_DeleteOnClose) + plot.close() + del plot + + def testProfile2D(self): + """Test that the profile plot associated to a stack view is either a + Plot1D or a plot 2D instance.""" + plot = StackView() + plot.show() + self.qWaitForWindowExposed(plot) + + plot.setStack(numpy.array([[[0, 1], [2, 3]], + [[4, 5], [6, 7]]])) + + self.assertIsInstance(plot.getProfileToolbar().getProfileMainWindow(), + qt.QMainWindow) + + # plot.getProfileToolbar().profile3dAction.computeProfileIn2D() # default + + self.assertIsInstance(plot.getProfileToolbar().getProfilePlot(), + Plot2D) + plot.getProfileToolbar().profile3dAction.computeProfileIn1D() + self.assertIsInstance(plot.getProfileToolbar().getProfilePlot(), + Plot1D) + + plot.setAttribute(qt.Qt.WA_DeleteOnClose) + plot.close() + del plot + + +def suite(): + test_suite = unittest.TestSuite() + # test_suite.addTest(positionInfoTestSuite) + for testClass in (TestProfileToolBar, TestGetProfilePlot): + test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase( + testClass)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot/test/testScatterMaskToolsWidget.py b/silx/gui/plot/test/testScatterMaskToolsWidget.py new file mode 100644 index 0000000..8b5f2ad --- /dev/null +++ b/silx/gui/plot/test/testScatterMaskToolsWidget.py @@ -0,0 +1,313 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for MaskToolsWidget""" + +__authors__ = ["T. Vincent", "P. Knobel"] +__license__ = "MIT" +__date__ = "10/07/2017" + + +import logging +import os.path +import unittest + +import numpy + +from silx.gui import qt +from silx.test.utils import temp_dir, ParametricTestCase +from silx.gui.test.utils import TestCaseQt, getQToolButtonFromAction +from silx.gui.plot import PlotWindow, ScatterMaskToolsWidget + +try: + import fabio +except ImportError: + fabio = None + + +logging.basicConfig() +_logger = logging.getLogger(__name__) + + +class TestScatterMaskToolsWidget(TestCaseQt, ParametricTestCase): + """Basic test for MaskToolsWidget""" + + def setUp(self): + super(TestScatterMaskToolsWidget, self).setUp() + self.plot = PlotWindow() + + self.widget = ScatterMaskToolsWidget.ScatterMaskToolsDockWidget( + plot=self.plot, name='TEST') + self.plot.addDockWidget(qt.Qt.BottomDockWidgetArea, self.widget) + + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + self.maskWidget = self.widget.widget() + + def tearDown(self): + del self.maskWidget + del self.widget + + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plot.close() + del self.plot + + super(TestScatterMaskToolsWidget, self).tearDown() + + def testEmptyPlot(self): + """Empty plot, display MaskToolsDockWidget, toggle multiple masks""" + self.maskWidget.setMultipleMasks('single') + self.qapp.processEvents() + + self.maskWidget.setMultipleMasks('exclusive') + self.qapp.processEvents() + + def _drag(self): + """Drag from plot center to offset position""" + plot = self.plot.centralWidget() + xCenter, yCenter = plot.width() // 2, plot.height() // 2 + offset = min(plot.width(), plot.height()) // 10 + + pos0 = xCenter, yCenter + pos1 = xCenter + offset, yCenter + offset + + self.mouseMove(plot, pos=pos0) + self.mousePress(plot, qt.Qt.LeftButton, pos=pos0) + self.mouseMove(plot, pos=pos1) + self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos1) + + def _drawPolygon(self): + """Draw a star polygon in the plot""" + plot = self.plot.centralWidget() + x, y = plot.width() // 2, plot.height() // 2 + offset = min(plot.width(), plot.height()) // 10 + + star = [(x, y + offset), + (x - offset, y - offset), + (x + offset, y), + (x - offset, y), + (x + offset, y - offset)] + + for pos in star: + self.mouseMove(plot, pos=pos) + btn = qt.Qt.LeftButton if pos != star[-1] else qt.Qt.RightButton + self.mouseClick(plot, btn, pos=pos) + + def _drawPencil(self): + """Draw a star polygon in the plot""" + plot = self.plot.centralWidget() + x, y = plot.width() // 2, plot.height() // 2 + offset = min(plot.width(), plot.height()) // 10 + + star = [(x, y + offset), + (x - offset, y - offset), + (x + offset, y), + (x - offset, y), + (x + offset, y - offset)] + + self.mouseMove(plot, pos=star[0]) + self.mousePress(plot, qt.Qt.LeftButton, pos=star[0]) + for pos in star: + self.mouseMove(plot, pos=pos) + self.mouseRelease( + plot, qt.Qt.LeftButton, pos=star[-1]) + + def testWithAScatter(self): + """Plot with a Scatter: test MaskToolsWidget interactions""" + + # Add and remove a scatter (this should enable/disable GUI + change mask) + self.plot.addScatter( + x=numpy.arange(256), + y=numpy.arange(256), + value=numpy.random.random(256), + legend='test') + self.plot._setActiveItem(kind="scatter", legend="test") + self.qapp.processEvents() + + self.plot.remove('test', kind='scatter') + self.qapp.processEvents() + + self.plot.addScatter( + x=numpy.arange(1000), + y=1000 * (numpy.arange(1000) % 20), + value=numpy.random.random(1000), + legend='test') + self.plot._setActiveItem(kind="scatter", legend="test") + self.plot.resetZoom() + self.qapp.processEvents() + + # Test draw rectangle # + toolButton = getQToolButtonFromAction(self.maskWidget.rectAction) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + + # mask + self.maskWidget.maskStateGroup.button(1).click() + self.qapp.processEvents() + self._drag() + + self.assertFalse( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + + # unmask same region + self.maskWidget.maskStateGroup.button(0).click() + self.qapp.processEvents() + self._drag() + self.assertTrue( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + + # Test draw polygon # + toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + + # mask + self.maskWidget.maskStateGroup.button(1).click() + self.qapp.processEvents() + self._drawPolygon() + self.assertFalse( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + + # unmask same region + self.maskWidget.maskStateGroup.button(0).click() + self.qapp.processEvents() + self._drawPolygon() + self.assertTrue( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + + # Test draw pencil # + toolButton = getQToolButtonFromAction(self.maskWidget.pencilAction) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + + self.maskWidget.pencilSpinBox.setValue(10) + self.qapp.processEvents() + + # mask + self.maskWidget.maskStateGroup.button(1).click() + self.qapp.processEvents() + self._drawPencil() + self.assertFalse( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + + # unmask same region + self.maskWidget.maskStateGroup.button(0).click() + self.qapp.processEvents() + self._drawPencil() + self.assertTrue( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + + # Test no draw tool # + toolButton = getQToolButtonFromAction(self.maskWidget.browseAction) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + + self.plot.clear() + + def __loadSave(self, file_format): + self.plot.addScatter( + x=numpy.arange(256), + y=25 * (numpy.arange(256) % 10), + value=numpy.random.random(256), + legend='test') + self.plot._setActiveItem(kind="scatter", legend="test") + self.plot.resetZoom() + self.qapp.processEvents() + + # Draw a polygon mask + toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + self._drawPolygon() + + ref_mask = self.maskWidget.getSelectionMask() + self.assertFalse(numpy.all(numpy.equal(ref_mask, 0))) + + with temp_dir() as tmp: + mask_filename = os.path.join(tmp, 'mask.' + file_format) + self.maskWidget.save(mask_filename, file_format) + + self.maskWidget.resetSelectionMask() + self.assertTrue( + numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))) + + self.maskWidget.load(mask_filename) + self.assertTrue(numpy.all(numpy.equal( + self.maskWidget.getSelectionMask(), ref_mask))) + + def testLoadSaveNpy(self): + self.__loadSave("npy") + + def testLoadSaveCsv(self): + self.__loadSave("csv") + + def testSigMaskChangedEmitted(self): + self.qapp.processEvents() + self.plot.addScatter( + x=numpy.arange(1000), + y=1000 * (numpy.arange(1000) % 20), + value=numpy.ones((1000,)), + legend='test') + self.plot._setActiveItem(kind="scatter", legend="test") + self.plot.resetZoom() + self.qapp.processEvents() + + self.plot.remove('test', kind='scatter') + self.qapp.processEvents() + + self.plot.addScatter( + x=numpy.arange(1000), + y=1000 * (numpy.arange(1000) % 20), + value=numpy.random.random(1000), + legend='test') + + l = [] + + def slot(): + l.append(1) + + self.maskWidget.sigMaskChanged.connect(slot) + + # rectangle mask + toolButton = getQToolButtonFromAction(self.maskWidget.rectAction) + self.assertIsNot(toolButton, None) + self.mouseClick(toolButton, qt.Qt.LeftButton) + self.maskWidget.maskStateGroup.button(1).click() + self.qapp.processEvents() + self._drag() + + self.assertGreater(len(l), 0) + + +def suite(): + test_suite = unittest.TestSuite() + for TestClass in (TestScatterMaskToolsWidget,): + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestClass)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot/test/testStackView.py b/silx/gui/plot/test/testStackView.py new file mode 100644 index 0000000..69584cd --- /dev/null +++ b/silx/gui/plot/test/testStackView.py @@ -0,0 +1,209 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for StackView""" + +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "20/03/2017" + + +import unittest +import numpy + +from silx.gui.test.utils import TestCaseQt + +from silx.gui import qt +from silx.gui.plot import StackView +from silx.gui.plot.StackView import StackViewMainWindow + +from silx.utils.array_like import ListOfImages + + +# Makes sure a QApplication exists +_qapp = qt.QApplication.instance() or qt.QApplication([]) + + +class TestStackView(TestCaseQt): + """Base class for tests of StackView.""" + + def setUp(self): + super(TestStackView, self).setUp() + self.stackview = StackView() + self.stackview.show() + self.qWaitForWindowExposed(self.stackview) + self.mystack = numpy.fromfunction( + lambda i, j, k: numpy.sin(i/15.) + numpy.cos(j/4.) + 2 * numpy.sin(k/6.), + (10, 20, 30) + ) + + def tearDown(self): + self.stackview.setAttribute(qt.Qt.WA_DeleteOnClose) + self.stackview.close() + del self.stackview + super(TestStackView, self).tearDown() + + def testSetStack(self): + self.stackview.setStack(self.mystack) + self.stackview.setColormap("viridis", autoscale=True) + my_trans_stack, params = self.stackview.getStack() + self.assertEqual(my_trans_stack.shape, self.mystack.shape) + self.assertTrue(numpy.array_equal(self.mystack, + my_trans_stack)) + self.assertEqual(params["colormap"]["name"], + "viridis") + + def testSetStackPerspective(self): + self.stackview.setStack(self.mystack, perspective=1) + # my_orig_stack, params = self.stackview.getStack() + my_trans_stack, params = self.stackview.getCurrentView() + + # get stack returns the transposed data, depending on the perspective + self.assertEqual(my_trans_stack.shape, + (self.mystack.shape[1], self.mystack.shape[0], self.mystack.shape[2])) + self.assertTrue(numpy.array_equal(numpy.transpose(self.mystack, axes=(1, 0, 2)), + my_trans_stack)) + + def testSetStackListOfImages(self): + loi = [self.mystack[i] for i in range(self.mystack.shape[0])] + + self.stackview.setStack(loi) + my_orig_stack, params = self.stackview.getStack(returnNumpyArray=True) + my_trans_stack, params = self.stackview.getStack(returnNumpyArray=True) + self.assertEqual(my_trans_stack.shape, self.mystack.shape) + self.assertTrue(numpy.array_equal(self.mystack, + my_trans_stack)) + self.assertTrue(numpy.array_equal(self.mystack, + my_orig_stack)) + self.assertIsInstance(my_trans_stack, numpy.ndarray) + + self.stackview.setStack(loi, perspective=2) + my_orig_stack, params = self.stackview.getStack(copy=False) + my_trans_stack, params = self.stackview.getCurrentView(copy=False) + # getStack(copy=False) must return the object set in setStack + self.assertIs(my_orig_stack, loi) + # getCurrentView(copy=False) returns a ListOfImages whose .images + # attr is the original data + self.assertEqual(my_trans_stack.shape, + (self.mystack.shape[2], self.mystack.shape[0], self.mystack.shape[1])) + self.assertTrue(numpy.array_equal(numpy.array(my_trans_stack), + numpy.transpose(self.mystack, axes=(2, 0, 1)))) + self.assertIsInstance(my_trans_stack, + ListOfImages) # returnNumpyArray=False by default in getStack + self.assertIs(my_trans_stack.images, loi) + + def testPerspective(self): + self.stackview.setStack(numpy.arange(24).reshape((2, 3, 4))) + self.assertEqual(self.stackview._perspective, 0, + "Default perspective is not 0 (dim1-dim2).") + + self.stackview._StackView__planeSelection.setPerspective(1) + self.assertEqual(self.stackview._perspective, 1, + "Plane selection combobox not updating perspective") + + self.stackview.setStack(numpy.arange(6).reshape((1, 2, 3))) + self.assertEqual(self.stackview._perspective, 0, + "Default perspective not restored in setStack.") + + self.stackview.setStack(numpy.arange(24).reshape((2, 3, 4)), perspective=2) + self.assertEqual(self.stackview._perspective, 2, + "Perspective not set in setStack(..., perspective=2).") + + def testTitle(self): + """Test that the plot title contains the proper Z information""" + self.stackview.setStack(numpy.arange(24).reshape((4, 3, 2)), + calibrations=[(0, 1), (-10, 10), (3.14, 3.14)]) + self.assertEqual(self.stackview._plot.getGraphTitle(), + "Image z=0") + self.stackview.setFrameNumber(2) + self.assertEqual(self.stackview._plot.getGraphTitle(), + "Image z=2") + + self.stackview._StackView__planeSelection.setPerspective(1) + self.stackview.setFrameNumber(0) + self.assertEqual(self.stackview._plot.getGraphTitle(), + "Image z=-10") + self.stackview.setFrameNumber(2) + self.assertEqual(self.stackview._plot.getGraphTitle(), + "Image z=10") + + self.stackview._StackView__planeSelection.setPerspective(2) + self.stackview.setFrameNumber(0) + self.assertEqual(self.stackview._plot.getGraphTitle(), + "Image z=3.14") + self.stackview.setFrameNumber(1) + self.assertEqual(self.stackview._plot.getGraphTitle(), + "Image z=6.28") + + +class TestStackViewMainWindow(TestCaseQt): + """Base class for tests of StackView.""" + + def setUp(self): + super(TestStackViewMainWindow, self).setUp() + self.stackview = StackViewMainWindow() + self.stackview.show() + self.qWaitForWindowExposed(self.stackview) + self.mystack = numpy.fromfunction( + lambda i, j, k: numpy.sin(i/15.) + numpy.cos(j/4.) + 2 * numpy.sin(k/6.), + (10, 20, 30) + ) + + def tearDown(self): + self.stackview.setAttribute(qt.Qt.WA_DeleteOnClose) + self.stackview.close() + del self.stackview + super(TestStackViewMainWindow, self).tearDown() + + def testSetStack(self): + self.stackview.setStack(self.mystack) + self.stackview.setColormap("viridis", autoscale=True) + my_trans_stack, params = self.stackview.getStack() + self.assertEqual(my_trans_stack.shape, self.mystack.shape) + self.assertTrue(numpy.array_equal(self.mystack, + my_trans_stack)) + self.assertEqual(params["colormap"]["name"], + "viridis") + + def testSetStackPerspective(self): + self.stackview.setStack(self.mystack, perspective=1) + my_trans_stack, params = self.stackview.getCurrentView() + # get stack returns the transposed data, depending on the perspective + self.assertEqual(my_trans_stack.shape, + (self.mystack.shape[1], self.mystack.shape[0], self.mystack.shape[2])) + self.assertTrue(numpy.array_equal(numpy.transpose(self.mystack, axes=(1, 0, 2)), + my_trans_stack)) + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestStackView)) + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestStackViewMainWindow)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot3d/Plot3DActions.py b/silx/gui/plot3d/Plot3DActions.py new file mode 100644 index 0000000..2ae2750 --- /dev/null +++ b/silx/gui/plot3d/Plot3DActions.py @@ -0,0 +1,362 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides QAction that can be attached to a plot3DWidget.""" + +from __future__ import absolute_import, division + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "26/01/2017" + + +import logging +import os +import weakref + +import numpy + +from silx.gui import qt +from silx.gui.plot.PlotActions import PrintAction as _PrintAction +from silx.gui.icons import getQIcon +from .utils import mng +from .._utils import convertQImageToArray + + +_logger = logging.getLogger(__name__) + + +class Plot3DAction(qt.QAction): + """QAction associated to a Plot3DWidget + + :param parent: See :class:`QAction` + :param Plot3DWidget plot3d: Plot3DWidget the action is associated with + """ + + def __init__(self, parent, plot3d=None): + super(Plot3DAction, self).__init__(parent) + self._plot3d = None + self.setPlot3DWidget(plot3d) + + def setPlot3DWidget(self, widget): + """Set the Plot3DWidget this action is associated with + + :param Plot3DWidget widget: The Plot3DWidget to use + """ + self._plot3d = None if widget is None else weakref.ref(widget) + + def getPlot3DWidget(self): + """Return the Plot3DWidget associated to this action. + + If no widget is associated, it returns None. + + :rtype: qt.QWidget + """ + return None if self._plot3d is None else self._plot3d() + + +class CopyAction(Plot3DAction): + """QAction to provide copy of a Plot3DWidget + + :param parent: See :class:`QAction` + :param Plot3DWidget plot3d: Plot3DWidget the action is associated with + """ + + def __init__(self, parent, plot3d=None): + super(CopyAction, self).__init__(parent, plot3d) + + self.setIcon(getQIcon('edit-copy')) + self.setText('Copy') + self.setToolTip('Copy a snapshot of the 3D scene to the clipboard') + self.setCheckable(False) + self.setShortcut(qt.QKeySequence.Copy) + self.setShortcutContext(qt.Qt.WidgetShortcut) + self.triggered[bool].connect(self._triggered) + + def _triggered(self, checked=False): + plot3d = self.getPlot3DWidget() + if plot3d is None: + _logger.error('Cannot copy widget, no associated Plot3DWidget') + else: + image = plot3d.grabGL() + qt.QApplication.clipboard().setImage(image) + + +class SaveAction(Plot3DAction): + """QAction to provide save snapshot of a Plot3DWidget + + :param parent: See :class:`QAction` + :param Plot3DWidget plot3d: Plot3DWidget the action is associated with + """ + + def __init__(self, parent, plot3d=None): + super(SaveAction, self).__init__(parent, plot3d) + + self.setIcon(getQIcon('document-save')) + self.setText('Save...') + self.setToolTip('Save a snapshot of the 3D scene') + self.setCheckable(False) + self.setShortcut(qt.QKeySequence.Save) + self.setShortcutContext(qt.Qt.WidgetShortcut) + self.triggered[bool].connect(self._triggered) + + def _triggered(self, checked=False): + plot3d = self.getPlot3DWidget() + if plot3d is None: + _logger.error('Cannot save widget, no associated Plot3DWidget') + else: + dialog = qt.QFileDialog(self.parent()) + dialog.setWindowTitle('Save snapshot as') + dialog.setModal(True) + dialog.setNameFilters(('Plot3D Snapshot PNG (*.png)', + 'Plot3D Snapshot JPEG (*.jpg)')) + + dialog.setFileMode(qt.QFileDialog.AnyFile) + dialog.setAcceptMode(qt.QFileDialog.AcceptSave) + + if not dialog.exec_(): + return + + nameFilter = dialog.selectedNameFilter() + filename = dialog.selectedFiles()[0] + dialog.close() + + # Forces the filename extension to match the chosen filter + extension = nameFilter.split()[-1][2:-1] + if (len(filename) <= len(extension) or + filename[-len(extension):].lower() != extension.lower()): + filename += extension + + image = plot3d.grabGL() + if not image.save(filename): + _logger.error('Failed to save image as %s', filename) + qt.QMessageBox.critical( + self.parent(), + 'Save snapshot as', + 'Failed to save snapshot') + + +class PrintAction(Plot3DAction): + """QAction to provide printing of a Plot3DWidget + + :param parent: See :class:`QAction` + :param Plot3DWidget plot3d: Plot3DWidget the action is associated with + """ + + def __init__(self, parent, plot3d=None): + super(PrintAction, self).__init__(parent, plot3d) + + self.setIcon(getQIcon('document-print')) + self.setText('Print...') + self.setToolTip('Print a snapshot of the 3D scene') + self.setCheckable(False) + self.setShortcut(qt.QKeySequence.Print) + self.setShortcutContext(qt.Qt.WidgetShortcut) + self.triggered[bool].connect(self._triggered) + + def getPrinter(self): + """Return the QPrinter instance used for printing. + + :rtype: qt.QPrinter + """ + # TODO This is a hack to sync with silx plot PrintAction + # This needs to be centralized + if _PrintAction._printer is None: + _PrintAction._printer = qt.QPrinter() + return _PrintAction._printer + + def _triggered(self, checked=False): + plot3d = self.getPlot3DWidget() + if plot3d is None: + _logger.error('Cannot print widget, no associated Plot3DWidget') + else: + printer = self.getPrinter() + dialog = qt.QPrintDialog(printer, plot3d) + dialog.setWindowTitle('Print Plot3D snapshot') + if not dialog.exec_(): + return + + image = plot3d.grabGL() + + # Draw pixmap with painter + painter = qt.QPainter() + if not painter.begin(printer): + return + + if (printer.pageRect().width() < image.width() or + printer.pageRect().height() < image.height()): + # Downscale to page + xScale = printer.pageRect().width() / image.width() + yScale = printer.pageRect().height() / image.height() + scale = min(xScale, yScale) + else: + scale = 1. + + rect = qt.QRectF(0, + 0, + scale * image.width(), + scale * image.height()) + painter.drawImage(rect, image) + painter.end() + + +class VideoAction(Plot3DAction): + """This action triggers the recording of a video of the scene. + + The scene is rotated 360 degrees around a vertical axis. + + :param parent: Action parent see :class:`QAction`. + """ + + PNG_SERIE_FILTER = 'Serie of PNG files (*.png)' + MNG_FILTER = 'Multiple-image Network Graphics file (*.mng)' + + def __init__(self, parent, plot3d=None): + super(VideoAction, self).__init__(parent, plot3d) + self.setText('Record video..') + self.setIcon(getQIcon('camera')) + self.setToolTip( + 'Record a video of a 360 degrees rotation of the 3D scene.') + self.setCheckable(False) + self.triggered[bool].connect(self._triggered) + + def _triggered(self, checked=False): + """Action triggered callback""" + plot3d = self.getPlot3DWidget() + if plot3d is None: + _logger.warning( + 'Ignoring action triggered without Plot3DWidget set') + return + + dialog = qt.QFileDialog(parent=plot3d) + dialog.setWindowTitle('Save video as...') + dialog.setModal(True) + dialog.setNameFilters([self.PNG_SERIE_FILTER, + self.MNG_FILTER]) + dialog.setFileMode(dialog.AnyFile) + dialog.setAcceptMode(dialog.AcceptSave) + + if not dialog.exec_(): + return + + nameFilter = dialog.selectedNameFilter() + filename = dialog.selectedFiles()[0] + + # Forces the filename extension to match the chosen filter + extension = nameFilter.split()[-1][2:-1] + if (len(filename) <= len(extension) or + filename[-len(extension):].lower() != extension.lower()): + filename += extension + + nbFrames = int(4. * 25) # 4 seconds, 25 fps + + if nameFilter == self.PNG_SERIE_FILTER: + self._saveAsPNGSerie(filename, nbFrames) + elif nameFilter == self.MNG_FILTER: + self._saveAsMNG(filename, nbFrames) + else: + _logger.error('Unsupported file filter: %s', nameFilter) + + def _saveAsPNGSerie(self, filename, nbFrames): + """Save video as serie of PNG files. + + It adds a counter to the provided filename before the extension. + + :param str filename: filename to use as template + :param int nbFrames: Number of frames to generate + """ + plot3d = self.getPlot3DWidget() + assert plot3d is not None + + # Define filename template + nbDigits = int(numpy.log10(nbFrames)) + 1 + indexFormat = '%%0%dd' % nbDigits + extensionIndex = filename.rfind('.') + filenameFormat = \ + filename[:extensionIndex] + indexFormat + filename[extensionIndex:] + + try: + for index, image in enumerate(self._video360(nbFrames)): + image.save(filenameFormat % index) + except GeneratorExit: + pass + + def _saveAsMNG(self, filename, nbFrames): + """Save video as MNG file. + + :param str filename: filename to use + :param int nbFrames: Number of frames to generate + """ + plot3d = self.getPlot3DWidget() + assert plot3d is not None + + frames = (convertQImageToArray(im) for im in self._video360(nbFrames)) + try: + with open(filename, 'wb') as file_: + for chunk in mng.convert(frames, nb_images=nbFrames): + file_.write(chunk) + except GeneratorExit: + os.remove(filename) # Saving aborted, delete file + + def _video360(self, nbFrames): + """Run the video and provides the images + + :param int nbFrames: The number of frames to generate for + :return: Iterator of QImage of the video sequence + """ + plot3d = self.getPlot3DWidget() + assert plot3d is not None + + angleStep = 360. / nbFrames + + # Create progress bar dialog + dialog = qt.QDialog(plot3d) + dialog.setWindowTitle('Record Video') + layout = qt.QVBoxLayout(dialog) + progress = qt.QProgressBar() + progress.setRange(0, nbFrames) + layout.addWidget(progress) + + btnBox = qt.QDialogButtonBox(qt.QDialogButtonBox.Abort) + btnBox.rejected.connect(dialog.reject) + layout.addWidget(btnBox) + + dialog.setModal(True) + dialog.show() + + qapp = qt.QApplication.instance() + + for frame in range(nbFrames): + progress.setValue(frame) + image = plot3d.grabGL() + yield image + plot3d.viewport.orbitCamera('left', angleStep) + qapp.processEvents() + if not dialog.isVisible(): + break # It as been rejected by the abort button + else: + dialog.accept() + + if dialog.result() == qt.QDialog.Rejected: + raise GeneratorExit('Aborted') diff --git a/silx/gui/plot3d/Plot3DToolBar.py b/silx/gui/plot3d/Plot3DToolBar.py new file mode 100644 index 0000000..cf11362 --- /dev/null +++ b/silx/gui/plot3d/Plot3DToolBar.py @@ -0,0 +1,119 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides a toolbar with tools for a Plot3DWidget. + +It provides: + +- Copy +- Save +- Print +""" + +from __future__ import absolute_import + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "10/01/2017" + +import logging + +from silx.gui import qt + +from . import Plot3DActions + +_logger = logging.getLogger(__name__) + + +class Plot3DToolBar(qt.QToolBar): + """Toolbar providing icons to copy, save and print the OpenGL scene + + :param parent: See :class:`QWidget` + :param str title: Title of the toolbar. + """ + + def __init__(self, parent=None, title='Plot3D'): + super(Plot3DToolBar, self).__init__(title, parent) + + self._plot3d = None + + self._copyAction = Plot3DActions.CopyAction(parent=self) + self.addAction(self._copyAction) + + self._saveAction = Plot3DActions.SaveAction(parent=self) + self.addAction(self._saveAction) + + self._videoAction = Plot3DActions.VideoAction(parent=self) + self.addAction(self._videoAction) + + self._printAction = Plot3DActions.PrintAction(parent=self) + self.addAction(self._printAction) + + def setPlot3DWidget(self, widget): + """Set the Plot3DWidget this toolbar is associated with + + :param Plot3DWidget widget: The widget to copy/save/print + """ + self._plot3d = widget + self.getCopyAction().setPlot3DWidget(widget) + self.getSaveAction().setPlot3DWidget(widget) + self.getVideoRecordAction().setPlot3DWidget(widget) + self.getPrintAction().setPlot3DWidget(widget) + + def getPlot3DWidget(self): + """Return the Plot3DWidget associated to this toolbar. + + If no widget is associated, it returns None. + + :rtype: qt.QWidget + """ + return self._plot3d + + def getCopyAction(self): + """Returns the QAction performing copy to clipboard of the Plot3DWidget + + :rtype: qt.QAction + """ + return self._copyAction + + def getSaveAction(self): + """Returns the QAction performing save to file of the Plot3DWidget + + :rtype: qt.QAction + """ + return self._saveAction + + def getVideoRecordAction(self): + """Returns the QAction performing record video of the Plot3DWidget + + :rtype: qt.QAction + """ + return self._videoAction + + def getPrintAction(self): + """Returns the QAction performing printing of the Plot3DWidget + + :rtype: qt.QAction + """ + return self._printAction diff --git a/silx/gui/plot3d/Plot3DWidget.py b/silx/gui/plot3d/Plot3DWidget.py new file mode 100644 index 0000000..9c9da0c --- /dev/null +++ b/silx/gui/plot3d/Plot3DWidget.py @@ -0,0 +1,341 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides a Qt widget embedding an OpenGL scene.""" + +from __future__ import absolute_import + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "26/01/2017" + + +import logging + +from silx.gui import qt +from silx.gui.plot.Colors import rgba +from silx.gui.plot3d import Plot3DActions +from .._utils import convertArrayToQImage + +from .._glutils import gl +from .scene import interaction, primitives, transform +from . import scene + +import numpy + + +_logger = logging.getLogger(__name__) + + +class _OverviewViewport(scene.Viewport): + """A scene displaying the orientation of the data in another scene. + + :param Camera camera: The camera to track. + """ + + def __init__(self, camera=None): + super(_OverviewViewport, self).__init__() + self.size = 100, 100 + + self.scene.transforms = [transform.Scale(2.5, 2.5, 2.5)] + + axes = primitives.Axes() + self.scene.children.append(axes) + + if camera is not None: + camera.addListener(self._cameraChanged) + + def _cameraChanged(self, source): + """Listen to camera in other scene for transformation updates. + + Sync the overview camera to point in the same direction + but from a sphere centered on origin. + """ + position = -12. * source.extrinsic.direction + self.camera.extrinsic.position = position + + self.camera.extrinsic.setOrientation( + source.extrinsic.direction, source.extrinsic.up) + + +class Plot3DWidget(qt.QGLWidget): + """QGLWidget with a 3D viewport and an overview.""" + + def __init__(self, parent=None): + if not qt.QGLFormat.hasOpenGL(): # Check if any OpenGL is available + raise RuntimeError( + 'OpenGL is not available on this platform: 3D disabled') + + self._devicePixelRatio = 1.0 # Store GL canvas/QWidget ratio + self._isOpenGL21 = False + self._firstRender = True + + format_ = qt.QGLFormat() + format_.setRgba(True) + format_.setDepth(False) + format_.setStencil(False) + format_.setVersion(2, 1) + format_.setDoubleBuffer(True) + + super(Plot3DWidget, self).__init__(format_, parent) + self.setAutoFillBackground(False) + self.setMouseTracking(True) + + self.setFocusPolicy(qt.Qt.StrongFocus) + self._copyAction = Plot3DActions.CopyAction(parent=self, plot3d=self) + self.addAction(self._copyAction) + + self._updating = False # True if an update is requested + + # Main viewport + self.viewport = scene.Viewport() + self.viewport.background = 0.2, 0.2, 0.2, 1. + + sceneScale = transform.Scale(1., 1., 1.) + self.viewport.scene.transforms = [sceneScale, + transform.Translate(0., 0., 0.)] + + # Overview area + self.overview = _OverviewViewport(self.viewport.camera) + + self.setBackgroundColor((0.2, 0.2, 0.2, 1.)) + + # Window describing on screen area to render + self.window = scene.Window(mode='framebuffer') + self.window.viewports = [self.viewport, self.overview] + + self.eventHandler = interaction.CameraControl( + self.viewport, orbitAroundCenter=False, + mode='position', scaleTransform=sceneScale, + selectCB=None) + + self.viewport.addListener(self._redraw) + + def setProjection(self, projection): + """Change the projection in use. + + :param str projection: In 'perspective', 'orthographic'. + """ + if projection == 'orthographic': + projection = transform.Orthographic(size=self.viewport.size) + elif projection == 'perspective': + projection = transform.Perspective(fovy=30., + size=self.viewport.size) + else: + raise RuntimeError('Unsupported projection: %s' % projection) + + self.viewport.camera.intrinsic = projection + self.viewport.resetCamera() + + def getProjection(self): + """Return the current camera projection mode as a str. + + See :meth:`setProjection` + """ + projection = self.viewport.camera.intrinsic + if isinstance(projection, transform.Orthographic): + return 'orthographic' + elif isinstance(projection, transform.Perspective): + return 'perspective' + else: + raise RuntimeError('Unknown projection in use') + + def setBackgroundColor(self, color): + """Set the background color of the OpenGL view. + + :param color: RGB color of the isosurface: name, #RRGGBB or RGB values + :type color: + QColor, str or array-like of 3 or 4 float in [0., 1.] or uint8 + """ + color = rgba(color) + self.viewport.background = color + self.overview.background = color[0]*0.5, color[1]*0.5, color[2]*0.5, 1. + + def getBackgroundColor(self): + """Returns the RGBA background color (QColor).""" + return qt.QColor.fromRgbF(*self.viewport.background) + + def centerScene(self): + """Position the center of the scene at the center of rotation.""" + self.viewport.resetCamera() + + def resetZoom(self, face='front'): + """Reset the camera position to a default. + + :param str face: The direction the camera is looking at: + side, front, back, top, bottom, right, left. + Default: front. + """ + self.viewport.camera.extrinsic.reset(face=face) + self.centerScene() + + def _redraw(self, source=None): + """Viewport listener to require repaint""" + if not self._updating and self.viewport.dirty: + self._updating = True # Mark that an update is requested + self.update() # Queued repaint (i.e., asynchronous) + + def sizeHint(self): + return qt.QSize(400, 300) + + def initializeGL(self): + # Check if OpenGL2 is available + versionflags = self.format().openGLVersionFlags() + self._isOpenGL21 = bool(versionflags & qt.QGLFormat.OpenGL_Version_2_1) + if not self._isOpenGL21: + _logger.error( + '3D rendering is disabled: OpenGL 2.1 not available') + + messageBox = qt.QMessageBox(parent=self) + messageBox.setIcon(qt.QMessageBox.Critical) + messageBox.setWindowTitle('Error') + messageBox.setText('3D rendering is disabled.\n\n' + 'Reason: OpenGL 2.1 is not available.') + messageBox.addButton(qt.QMessageBox.Ok) + messageBox.setWindowModality(qt.Qt.WindowModal) + messageBox.setAttribute(qt.Qt.WA_DeleteOnClose) + messageBox.show() + + def paintGL(self): + # In case paintGL is called by the system and not through _redraw, + # Mark as updating. + self._updating = True + + if hasattr(self, 'windowHandle'): # Qt 5 + devicePixelRatio = self.windowHandle().devicePixelRatio() + if devicePixelRatio != self._devicePixelRatio: + # Move window from one screen to another one + self._devicePixelRatio = devicePixelRatio + # Resize might not be called, so call it explicitly + self.resizeGL(int(self.width() * devicePixelRatio), + int(self.height() * devicePixelRatio)) + + if not self._isOpenGL21: + # Cannot render scene, just clear the color buffer. + ox, oy = self.viewport.origin + w, h = self.viewport.size + gl.glViewport(ox, oy, w, h) + + gl.glClearColor(*self.viewport.background) + gl.glClear(gl.GL_COLOR_BUFFER_BIT) + + else: + # Update near and far planes only if viewport needs refresh + if self.viewport.dirty: + self.viewport.adjustCameraDepthExtent() + + self.window.render(self.context(), self._devicePixelRatio) + + if self._firstRender: # TODO remove this ugly hack + self._firstRender = False + self.centerScene() + self._updating = False + + def resizeGL(self, width, height): + self.window.size = width, height + self.viewport.size = width, height + overviewWidth, overviewHeight = self.overview.size + self.overview.origin = width - overviewWidth, height - overviewHeight + + def grabGL(self): + """Renders the OpenGL scene into a numpy array + + :returns: OpenGL scene RGB rasterization + :rtype: QImage + """ + if not self._isOpenGL21: + _logger.error('OpenGL 2.1 not available, cannot save OpenGL image') + height, width = self.window.shape + image = numpy.zeros((height, width, 3), dtype=numpy.uint8) + + else: + self.makeCurrent() + image = self.window.grab(qt.QGLContext.currentContext()) + + return convertArrayToQImage(image) + + def wheelEvent(self, event): + xpixel = event.x() * self._devicePixelRatio + ypixel = event.y() * self._devicePixelRatio + if hasattr(event, 'delta'): # Qt4 + angle = event.delta() / 8. + else: # Qt5 + angle = event.angleDelta().y() / 8. + event.accept() + + if angle != 0: + self.makeCurrent() + self.eventHandler.handleEvent('wheel', xpixel, ypixel, angle) + + def keyPressEvent(self, event): + keycode = event.key() + # No need to accept QKeyEvent + + converter = { + qt.Qt.Key_Left: 'left', + qt.Qt.Key_Right: 'right', + qt.Qt.Key_Up: 'up', + qt.Qt.Key_Down: 'down' + } + direction = converter.get(keycode, None) + if direction is not None: + if event.modifiers() == qt.Qt.ControlModifier: + self.viewport.camera.rotate(direction) + elif event.modifiers() == qt.Qt.ShiftModifier: + self.viewport.moveCamera(direction) + else: + self.viewport.orbitCamera(direction) + + else: + # Key not handled, call base class implementation + super(Plot3DWidget, self).keyPressEvent(event) + + # Mouse events # + _MOUSE_BTNS = {1: 'left', 2: 'right', 4: 'middle'} + + def mousePressEvent(self, event): + xpixel = event.x() * self._devicePixelRatio + ypixel = event.y() * self._devicePixelRatio + btn = self._MOUSE_BTNS[event.button()] + event.accept() + + self.makeCurrent() + self.eventHandler.handleEvent('press', xpixel, ypixel, btn) + + def mouseMoveEvent(self, event): + xpixel = event.x() * self._devicePixelRatio + ypixel = event.y() * self._devicePixelRatio + event.accept() + + self.makeCurrent() + self.eventHandler.handleEvent('move', xpixel, ypixel) + + def mouseReleaseEvent(self, event): + xpixel = event.x() * self._devicePixelRatio + ypixel = event.y() * self._devicePixelRatio + btn = self._MOUSE_BTNS[event.button()] + event.accept() + + self.makeCurrent() + self.eventHandler.handleEvent('release', xpixel, ypixel, btn) diff --git a/silx/gui/plot3d/Plot3DWindow.py b/silx/gui/plot3d/Plot3DWindow.py new file mode 100644 index 0000000..4658d38 --- /dev/null +++ b/silx/gui/plot3d/Plot3DWindow.py @@ -0,0 +1,94 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides a QMainWindow with a 3D scene and associated toolbar. +""" + +from __future__ import absolute_import + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "26/01/2017" + + +from silx.gui import qt + +from .Plot3DToolBar import Plot3DToolBar +from .Plot3DWidget import Plot3DWidget +from .ViewpointToolBar import ViewpointToolBar + + +class Plot3DWindow(qt.QMainWindow): + """QGLWidget with a 3D viewport and an overview.""" + + def __init__(self, parent=None): + super(Plot3DWindow, self).__init__(parent) + if parent is not None: + # behave as a widget + self.setWindowFlags(qt.Qt.Widget) + + self._plot3D = Plot3DWidget() + self.setCentralWidget(self._plot3D) + self.addToolBar( + ViewpointToolBar(parent=self, plot3D=self._plot3D)) + toolbar = Plot3DToolBar(parent=self) + toolbar.setPlot3DWidget(self._plot3D) + self.addToolBar(toolbar) + self.addActions(toolbar.actions()) + + def getPlot3DWidget(self): + """Get the :class:`Plot3DWidget` of this window""" + return self._plot3D + + # Proxy to Plot3DWidget + + def setProjection(self, projection): + return self._plot3D.setProjection(projection) + + setProjection.__doc__ = Plot3DWidget.setProjection.__doc__ + + def getProjection(self): + return self._plot3D.getProjection() + + getProjection.__doc__ = Plot3DWidget.getProjection.__doc__ + + def centerScene(self): + return self._plot3D.centerScene() + + centerScene.__doc__ = Plot3DWidget.centerScene.__doc__ + + def resetZoom(self): + return self._plot3D.resetZoom() + + resetZoom.__doc__ = Plot3DWidget.resetZoom.__doc__ + + def getBackgroundColor(self): + return self._plot3D.getBackgroundColor() + + getBackgroundColor.__doc__ = Plot3DWidget.getBackgroundColor.__doc__ + + def setBackgroundColor(self, color): + return self._plot3D.setBackgroundColor(color) + + setBackgroundColor.__doc__ = Plot3DWidget.setBackgroundColor.__doc__ diff --git a/silx/gui/plot3d/SFViewParamTree.py b/silx/gui/plot3d/SFViewParamTree.py new file mode 100644 index 0000000..38d4e37 --- /dev/null +++ b/silx/gui/plot3d/SFViewParamTree.py @@ -0,0 +1,1467 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +""" +This module provides a tree widget to set/view parameters of a ScalarFieldView. +""" + +from __future__ import absolute_import + +__authors__ = ["D. N."] +__license__ = "MIT" +__date__ = "10/01/2017" + +import logging +import sys + +import numpy + +from silx.gui import qt +from silx.gui.icons import getQIcon + +from .ScalarFieldView import Isosurface + + +_logger = logging.getLogger(__name__) + + +class ModelColumns(object): + NameColumn, ValueColumn, ColumnMax = range(3) + ColumnNames = ['Name', 'Value'] + + +class SubjectItem(qt.QStandardItem): + """ + Base class for observers items. + + Subclassing: + ------------ + The following method can/should be reimplemented: + - _init + - _pullData + - _pushData + - _setModelData + - _subjectChanged + - getEditor + - getSignals + - leftClicked + - queryRemove + - setEditorData + + Also the following attributes are available: + - editable + - persistent + + :param subject: object that this item will be observing. + """ + + editable = False + """ boolean: set to True to make the item editable. """ + + persistent = False + """ + boolean: set to True to make the editor persistent. + See : Qt.QAbstractItemView.openPersistentEditor + """ + + def __init__(self, subject, *args): + + super(SubjectItem, self).__init__(*args) + + self.setEditable(self.editable) + + self.__subject = None + self.subject = subject + + def setData(self, value, role=qt.Qt.UserRole, pushData=True): + """ + Overloaded method from QStandardItem. The pushData keyword tells + the item to push data to the subject if the role is equal to EditRole. + This is useful to let this method know if the setData method was called + internaly or from the view. + + :param value: the value ti set to data + :param role: role in the item + :param pushData: if True push value in the existing data. + """ + if role == qt.Qt.EditRole and pushData: + setValue = self._pushData(value, role) + if setValue != value: + value = setValue + super(SubjectItem, self).setData(value, role) + + subject = property(lambda self: self.__subject) + + @subject.setter + def subject(self, subject): + if self.__subject is not None: + raise ValueError('Subject already set ' + ' (subject change not supported).') + self.__subject = subject + if subject is not None: + self._init() + self._connectSignals() + + def _connectSignals(self): + """ + Connects the signals. Called when the subject is set. + """ + + def gen_slot(_sigIdx): + def slotfn(*args, **kwargs): + self._subjectChanged(signalIdx=_sigIdx, + args=args, + kwargs=kwargs) + return slotfn + + if self.__subject is not None: + self.__slots = slots = [] + + signals = self.getSignals() + + if signals: + if not isinstance(signals, (list, tuple)): + signals = [signals] + for sigIdx, signal in enumerate(signals): + slot = gen_slot(sigIdx) + signal.connect(slot) + slots.append((signal, slot)) + + def _disconnectSignals(self): + """ + Disconnects all subject's signal + """ + if self.__slots: + for signal, slot in self.__slots: + try: + signal.disconnect(slot) + except TypeError: + pass + + def _enableRow(self, enable): + """ + Set the enabled state for this cell, or for the whole row + if this item has a parent. + + :param bool enable: True if we wan't to enable the cell + """ + parent = self.parent() + model = self.model() + if model is None or parent is None: + # no parent -> no siblings + self.setEnabled(enable) + return + + for col in range(model.columnCount()): + sibling = parent.child(self.row(), col) + sibling.setEnabled(enable) + + ################################################################# + # Overloadable methods + ################################################################# + + def getSignals(self): + """ + Returns the list of this items subject's signals that + this item will be listening to. + + :return: list. + """ + return None + + def _subjectChanged(self, signalIdx=None, args=None, kwargs=None): + """ + Called when one of the signals is triggered. Default implementation + just calls _pullData, compares the result to the current value stored + as Qt.EditRole, and stores the new value if it is different. It also + stores its str representation as Qt.DisplayRole + + :param signalIdx: index of the triggered signal. The value passed + is the same as the signal position in the list returned by + SubjectItem.getSignals. + :param args: arguments received from the signal + :param kwargs: keyword arguments received from the signal + """ + data = self._pullData() + if data == self.data(qt.Qt.EditRole): + return + self.setData(data, role=qt.Qt.DisplayRole, pushData=False) + self.setData(data, role=qt.Qt.EditRole, pushData=False) + + def _pullData(self): + """ + Pulls data from the subject. + + :return: subject data + """ + return None + + def _pushData(self, value, role=qt.Qt.UserRole): + """ + Pushes data to the subject and returns the actual value that was stored + + :return: the value that was stored + """ + return value + + def _init(self): + """ + Called when the subject is set. + :return: + """ + self._subjectChanged() + + def getEditor(self, parent, option, index): + """ + Returns the editor widget used to edit this item's data. The arguments + are the one passed to the QStyledItemDelegate.createEditor method. + + :param parent: the Qt parent of the editor + :param option: + :param index: + :return: + """ + return None + + def setEditorData(self, editor): + """ + This is called by the View's delegate just before the editor is shown, + its purpose it to setup the editors contents. Return False to use + the delegate's default behaviour. + + :param editor: + :return: + """ + return True + + def _setModelData(self, editor): + """ + This is called by the View's delegate just before the editor is closed, + its allows this item to update itself with data from the editor. + + :param editor: + :return: + """ + return False + + def queryRemove(self, view=None): + """ + This is called by the view to ask this items if it (the view) can + remove it. Return True to let the view know that the item can be + removed. + + :param view: + :return: + """ + return False + + def leftClicked(self): + """ + This method is called by the view when the item's cell if left clicked. + + :return: + """ + pass + + +# View settings ############################################################### + +class ColorItem(SubjectItem): + """color item.""" + editable = True + persistent = True + + def getEditor(self, parent, option, index): + editor = QColorEditor(parent) + editor.color = self.getColor() + editor.sigColorChanged.connect(self._editorSlot) + return editor + + def _editorSlot(self, color): + self.setData(color, qt.Qt.EditRole) + + def _pushData(self, value, role=qt.Qt.UserRole): + self.setColor(value) + return self.getColor() + + def _pullData(self): + self.getColor() + + def setColor(self, color): + """Override to implement actual color setter""" + pass + + +class BackgroundColorItem(ColorItem): + itemName = 'Background' + + def setColor(self, color): + self.subject.setBackgroundColor(color) + + def getColor(self): + return self.subject.getBackgroundColor() + + +class ForegroundColorItem(ColorItem): + itemName = 'Foreground' + + def setColor(self, color): + self.subject.setForegroundColor(color) + + def getColor(self): + return self.subject.getForegroundColor() + + +class HighlightColorItem(ColorItem): + itemName = 'Highlight' + + def setColor(self, color): + self.subject.setHighlightColor(color) + + def getColor(self): + return self.subject.getHighlightColor() + + +class ViewSettingsItem(qt.QStandardItem): + """Viewport settings""" + + def __init__(self, subject, *args): + + super(ViewSettingsItem, self).__init__(*args) + + self.setEditable(False) + + classes = BackgroundColorItem, ForegroundColorItem, HighlightColorItem + for cls in classes: + titleItem = qt.QStandardItem(cls.itemName) + titleItem.setEditable(False) + self.appendRow([titleItem, cls(subject)]) + + +# Data information ############################################################ + +class DataChangedItem(SubjectItem): + """ + Base class for items listening to ScalarFieldView.sigDataChanged + """ + + def getSignals(self): + subject = self.subject + if subject: + return subject.sigDataChanged + return None + + def _init(self): + self._subjectChanged() + + +class DataTypeItem(DataChangedItem): + itemName = 'dtype' + + def _pullData(self): + data = self.subject.getData(copy=False) + return ((data is not None) and str(data.dtype)) or 'N/A' + + +class DataShapeItem(DataChangedItem): + itemName = 'size' + + def _pullData(self): + data = self.subject.getData(copy=False) + if data is None: + return 'N/A' + else: + return str(list(reversed(data.shape))) + + +class OffsetItem(DataChangedItem): + itemName = 'offset' + + def _pullData(self): + offset = self.subject.getTranslation() + return ((offset is not None) and str(offset)) or 'N/A' + + +class ScaleItem(DataChangedItem): + itemName = 'scale' + + def _pullData(self): + scale = self.subject.getScale() + return ((scale is not None) and str(scale)) or 'N/A' + + +class DataSetItem(qt.QStandardItem): + + def __init__(self, subject, *args): + + super(DataSetItem, self).__init__(*args) + + self.setEditable(False) + + klasses = [DataTypeItem, DataShapeItem, OffsetItem, ScaleItem] + for klass in klasses: + titleItem = qt.QStandardItem(klass.itemName) + titleItem.setEditable(False) + self.appendRow([titleItem, klass(subject)]) + + +# Isosurface ################################################################## + +class IsoSurfaceRootItem(SubjectItem): + """ + Root (i.e : column index 0) Isosurface item. + """ + + def getSignals(self): + subject = self.subject + return [subject.sigColorChanged, + subject.sigVisibilityChanged] + + def _subjectChanged(self, signalIdx=None, args=None, kwargs=None): + if signalIdx == 0: + color = self.subject.getColor() + self.setData(color, qt.Qt.DecorationRole) + elif signalIdx == 1: + visible = args[0] + self.setCheckState((visible and qt.Qt.Checked) or qt.Qt.Unchecked) + + def _init(self): + self.setCheckable(True) + + isosurface = self.subject + color = isosurface.getColor() + visible = isosurface.isVisible() + self.setData(color, qt.Qt.DecorationRole) + self.setCheckState((visible and qt.Qt.Checked) or qt.Qt.Unchecked) + + nameItem = qt.QStandardItem('Level') + sliderItem = IsoSurfaceLevelSlider(self.subject) + self.appendRow([nameItem, sliderItem]) + + nameItem = qt.QStandardItem('Color') + nameItem.setEditable(False) + valueItem = IsoSurfaceColorItem(self.subject) + self.appendRow([nameItem, valueItem]) + + nameItem = qt.QStandardItem('Opacity') + nameItem.setTextAlignment(qt.Qt.AlignLeft | qt.Qt.AlignTop) + nameItem.setEditable(False) + valueItem = IsoSurfaceAlphaItem(self.subject) + self.appendRow([nameItem, valueItem]) + + nameItem = qt.QStandardItem() + nameItem.setEditable(False) + valueItem = IsoSurfaceAlphaLegendItem(self.subject) + valueItem.setEditable(False) + self.appendRow([nameItem, valueItem]) + + def queryRemove(self, view=None): + buttons = qt.QMessageBox.Ok | qt.QMessageBox.Cancel + ans = qt.QMessageBox.question(view, + 'Remove isosurface', + 'Remove the selected iso-surface?', + buttons=buttons) + if ans == qt.QMessageBox.Ok: + sfview = self.subject.parent() + if sfview: + sfview.removeIsosurface(self.subject) + return False + return False + + def leftClicked(self): + checked = (self.checkState() == qt.Qt.Checked) + visible = self.subject.isVisible() + if checked != visible: + self.subject.setVisible(checked) + + +class IsoSurfaceLevelItem(SubjectItem): + """ + Base class for the isosurface level items. + """ + editable = True + + def getSignals(self): + subject = self.subject + return [subject.sigLevelChanged, + subject.sigVisibilityChanged] + + def setEditorData(self, editor): + return False + + def _pullData(self): + return self.subject.getLevel() + + def _pushData(self, value, role=qt.Qt.UserRole): + self.subject.setLevel(value) + return self.subject.getLevel() + + +class _IsoLevelSlider(qt.QSlider): + """QSlider used for iso-surface level""" + + def __init__(self, parent, subject): + super(_IsoLevelSlider, self).__init__(parent=parent) + self.subject = subject + + self.sliderReleased.connect(self.__sliderReleased) + + self.subject.sigLevelChanged.connect(self.setLevel) + self.subject.parent().sigDataChanged.connect(self.__dataChanged) + + def setLevel(self, level): + """Set slider from iso-surface level""" + dataRange = self.subject.parent().getDataRange() + + if dataRange is not None and None not in dataRange: + width = dataRange[1] - dataRange[0] + if width > 0: + sliderWidth = self.maximum() - self.minimum() + sliderPosition = sliderWidth * (level - dataRange[0]) / width + self.setValue(sliderPosition) + + def __dataChanged(self): + """Handles data update to refresh slider range if needed""" + self.setLevel(self.subject.getLevel()) + + def __sliderReleased(self): + value = self.value() + dataRange = self.subject.parent().getDataRange() + width = dataRange[1] - dataRange[0] + sliderWidth = self.maximum() - self.minimum() + level = dataRange[0] + width * value / sliderWidth + self.subject.setLevel(level) + + +class IsoSurfaceLevelSlider(IsoSurfaceLevelItem): + """ + Isosurface level item with a slider editor. + """ + nTicks = 1000 + persistent = True + + def getEditor(self, parent, option, index): + editor = _IsoLevelSlider(parent, self.subject) + editor.setOrientation(qt.Qt.Horizontal) + editor.setMinimum(0) + editor.setMaximum(self.nTicks) + + editor.setSingleStep(1) + + editor.setLevel(self.subject.getLevel()) + return editor + + def setEditorData(self, editor): + return True + + def _setModelData(self, editor): + return True + + +class IsoSurfaceColorItem(SubjectItem): + """ + Isosurface color item. + """ + editable = True + persistent = True + + def getSignals(self): + return self.subject.sigColorChanged + + def getEditor(self, parent, option, index): + editor = QColorEditor(parent) + color = self.subject.getColor() + color.setAlpha(255) + editor.color = color + editor.sigColorChanged.connect(self.__editorChanged) + return editor + + def __editorChanged(self, color): + color.setAlpha(self.subject.getColor().alpha()) + self.subject.setColor(color) + + def _pushData(self, value, role=qt.Qt.UserRole): + self.subject.setColor(value) + return self.subject.getColor() + + +class QColorEditor(qt.QWidget): + """ + QColor editor. + """ + sigColorChanged = qt.Signal(object) + + color = property(lambda self: qt.QColor(self.__color)) + + @color.setter + def color(self, color): + self._setColor(color) + self.__previousColor = color + + def __init__(self, *args, **kwargs): + super(QColorEditor, self).__init__(*args, **kwargs) + layout = qt.QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + button = qt.QToolButton() + icon = qt.QIcon(qt.QPixmap(32, 32)) + button.setIcon(icon) + layout.addWidget(button) + button.clicked.connect(self.__showColorDialog) + layout.addStretch(1) + + self.__color = None + self.__previousColor = None + + def sizeHint(self): + return qt.QSize(0, 0) + + def _setColor(self, qColor): + button = self.findChild(qt.QToolButton) + pixmap = qt.QPixmap(32, 32) + pixmap.fill(qColor) + button.setIcon(qt.QIcon(pixmap)) + self.__color = qColor + + def __showColorDialog(self): + dialog = qt.QColorDialog(parent=self) + if sys.platform == 'darwin': + # Use of native color dialog on macos might cause problems + dialog.setOption(qt.QColorDialog.DontUseNativeDialog, True) + + self.__previousColor = self.__color + dialog.setAttribute(qt.Qt.WA_DeleteOnClose) + dialog.setModal(True) + dialog.currentColorChanged.connect(self.__colorChanged) + dialog.finished.connect(self.__dialogClosed) + dialog.show() + + def __colorChanged(self, color): + self.__color = color + self._setColor(color) + self.sigColorChanged.emit(color) + + def __dialogClosed(self, result): + if result == qt.QDialog.Rejected: + self.__colorChanged(self.__previousColor) + self.__previousColor = None + + +class IsoSurfaceAlphaItem(SubjectItem): + """ + Isosurface alpha item. + """ + editable = True + persistent = True + + def _init(self): + pass + + def getSignals(self): + return self.subject.sigColorChanged + + def getEditor(self, parent, option, index): + editor = qt.QSlider(parent) + editor.setOrientation(qt.Qt.Horizontal) + editor.setMinimum(0) + editor.setMaximum(255) + + color = self.subject.getColor() + editor.setValue(color.alpha()) + + editor.valueChanged.connect(self.__editorChanged) + + return editor + + def __editorChanged(self, value): + color = self.subject.getColor() + color.setAlpha(value) + self.subject.setColor(color) + + def setEditorData(self, editor): + return True + + def _setModelData(self, editor): + return True + + +class IsoSurfaceAlphaLegendItem(SubjectItem): + """Legend to place under opacity slider""" + + editable = False + persistent = True + + def getEditor(self, parent, option, index): + layout = qt.QHBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + layout.addWidget(qt.QLabel('0')) + layout.addStretch(1) + layout.addWidget(qt.QLabel('1')) + + editor = qt.QWidget(parent) + editor.setLayout(layout) + return editor + + +class IsoSurfaceCount(SubjectItem): + """ + Item displaying the number of isosurfaces. + """ + + def getSignals(self): + subject = self.subject + return [subject.sigIsosurfaceAdded, subject.sigIsosurfaceRemoved] + + def _pullData(self): + return len(self.subject.getIsosurfaces()) + + +class IsoSurfaceAddRemoveWidget(qt.QWidget): + + sigViewTask = qt.Signal(str) + """Signal for the tree view to perform some task""" + + def __init__(self, parent, item): + super(IsoSurfaceAddRemoveWidget, self).__init__(parent) + self._item = item + layout = qt.QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + + addBtn = qt.QToolButton() + addBtn.setText('+') + addBtn.setToolButtonStyle(qt.Qt.ToolButtonTextOnly) + layout.addWidget(addBtn) + addBtn.clicked.connect(self.__addClicked) + + removeBtn = qt.QToolButton() + removeBtn.setText('-') + removeBtn.setToolButtonStyle(qt.Qt.ToolButtonTextOnly) + layout.addWidget(removeBtn) + removeBtn.clicked.connect(self.__removeClicked) + + layout.addStretch(1) + + def __addClicked(self): + sfview = self._item.subject + if not sfview: + return + dataRange = sfview.getDataRange() + if dataRange is None: + dataRange = [0, 1] + + sfview.addIsosurface(numpy.mean(dataRange), '#0000FF') + + def __removeClicked(self): + self.sigViewTask.emit('remove_iso') + + +class IsoSurfaceAddRemoveItem(SubjectItem): + """ + Item displaying a simple QToolButton allowing to add an isosurface. + """ + persistent = True + + def getEditor(self, parent, option, index): + return IsoSurfaceAddRemoveWidget(parent, self) + + +class IsoSurfaceGroup(SubjectItem): + """ + Root item for the list of isosurface items. + """ + def getSignals(self): + subject = self.subject + return [subject.sigIsosurfaceAdded, subject.sigIsosurfaceRemoved] + + def _subjectChanged(self, signalIdx=None, args=None, kwargs=None): + if signalIdx == 0: + if len(args) >= 1: + isosurface = args[0] + if not isinstance(isosurface, Isosurface): + raise ValueError('Expected an isosurface instance.') + self.__addIsosurface(isosurface) + else: + raise ValueError('Expected an isosurface instance.') + elif signalIdx == 1: + if len(args) >= 1: + isosurface = args[0] + if not isinstance(isosurface, Isosurface): + raise ValueError('Expected an isosurface instance.') + self.__removeIsosurface(isosurface) + else: + raise ValueError('Expected an isosurface instance.') + + def __addIsosurface(self, isosurface): + valueItem = IsoSurfaceRootItem(subject=isosurface) + nameItem = IsoSurfaceLevelItem(subject=isosurface) + self.insertRow(max(0, self.rowCount() - 1), [valueItem, nameItem]) + + def __removeIsosurface(self, isosurface): + for row in range(self.rowCount()): + child = self.child(row) + subject = getattr(child, 'subject', None) + if subject == isosurface: + self.takeRow(row) + break + + def _init(self): + nameItem = IsoSurfaceAddRemoveItem(self.subject) + valueItem = qt.QStandardItem() + valueItem.setEditable(False) + self.appendRow([nameItem, valueItem]) + + subject = self.subject + isosurfaces = subject.getIsosurfaces() + for isosurface in isosurfaces: + self.__addIsosurface(isosurface) + + +# Cutting Plane ############################################################### + +class ColormapBase(SubjectItem): + """ + Mixin class for colormap items. + """ + + def getSignals(self): + return [self.subject.getCutPlanes()[0].sigColormapChanged] + + +class PlaneMinRangeItem(ColormapBase): + """ + colormap minVal item. + Editor is a QLineEdit with a QDoubleValidator + """ + editable = True + + def _pullData(self): + colormap = self.subject.getCutPlanes()[0].getColormap() + auto = colormap.isAutoscale() + if auto == self.isEnabled(): + self._enableRow(not auto) + return colormap.getVMin() + + def _pushData(self, value, role=qt.Qt.UserRole): + self._setVMin(value) + + def _setVMin(self, value): + cutPlane = self.subject.getCutPlanes()[0] + colormap = cutPlane.getColormap() + vMin = value + vMax = colormap.getVMax() + + if vMax is not None and value > vMax: + vMin = vMax + vMax = value + cutPlane.setColormap(name=colormap.getName(), + norm=colormap.getNorm(), + vmin=vMin, + vmax=vMax) + + def getEditor(self, parent, option, index): + editor = qt.QLineEdit(parent) + editor.setValidator(qt.QDoubleValidator()) + return editor + + def setEditorData(self, editor): + editor.setText(str(self._pullData())) + return True + + def _setModelData(self, editor): + value = float(editor.text()) + self._setVMin(value) + return True + + +class PlaneMaxRangeItem(ColormapBase): + """ + colormap maxVal item. + Editor is a QLineEdit with a QDoubleValidator + """ + editable = True + + def _pullData(self): + colormap = self.subject.getCutPlanes()[0].getColormap() + auto = colormap.isAutoscale() + if auto == self.isEnabled(): + self._enableRow(not auto) + return self.subject.getCutPlanes()[0].getColormap().getVMax() + + def _setVMax(self, value): + cutPlane = self.subject.getCutPlanes()[0] + colormap = cutPlane.getColormap() + vMin = colormap.getVMin() + vMax = value + if vMin is not None and value < vMin: + vMax = vMin + vMin = value + cutPlane.setColormap(name=colormap.getName(), + norm=colormap.getNorm(), + vmin=vMin, + vmax=vMax) + + def getEditor(self, parent, option, index): + editor = qt.QLineEdit(parent) + editor.setValidator(qt.QDoubleValidator()) + return editor + + def setEditorData(self, editor): + editor.setText(str(self._pullData())) + return True + + def _setModelData(self, editor): + value = float(editor.text()) + self._setVMax(value) + return True + + +class PlaneOrientationItem(SubjectItem): + """ + Plane orientation item. + Editor is a QComboBox. + """ + editable = True + + _PLANE_ACTIONS = ( + ('3d-plane-normal-x', 'Plane 0', + 'Set plane perpendicular to red axis', (1., 0., 0.)), + ('3d-plane-normal-y', 'Plane 1', + 'Set plane perpendicular to green axis', (0., 1., 0.)), + ('3d-plane-normal-z', 'Plane 2', + 'Set plane perpendicular to blue axis', (0., 0., 1.)), + ) + + def getSignals(self): + return [self.subject.getCutPlanes()[0].sigPlaneChanged] + + def _pullData(self): + currentNormal = self.subject.getCutPlanes()[0].getNormal() + for _, text, _, normal in self._PLANE_ACTIONS: + if numpy.array_equal(normal, currentNormal): + return text + return '' + + def getEditor(self, parent, option, index): + editor = qt.QComboBox(parent) + for iconName, text, tooltip, normal in self._PLANE_ACTIONS: + editor.addItem(getQIcon(iconName), text) + editor.currentIndexChanged[int].connect(self.__editorChanged) + return editor + + def __editorChanged(self, index): + normal = self._PLANE_ACTIONS[index][3] + plane = self.subject.getCutPlanes()[0] + plane.setNormal(normal) + plane.moveToCenter() + + def setEditorData(self, editor): + currentText = self._pullData() + index = 0 + for normIdx, (_, text, _, _) in enumerate(self._PLANE_ACTIONS): + if text == currentText: + index = normIdx + break + editor.setCurrentIndex(index) + return True + + def _setModelData(self, editor): + return True + + +class PlaneInterpolationItem(SubjectItem): + """Toggle cut plane interpolation method: nearest or linear. + + Item is checkable + """ + + def _init(self): + interpolation = self.subject.getCutPlanes()[0].getInterpolation() + self.setCheckable(True) + self.setCheckState( + qt.Qt.Checked if interpolation == 'linear' else qt.Qt.Unchecked) + self.setData(self._pullData(), role=qt.Qt.DisplayRole, pushData=False) + + def getSignals(self): + return [self.subject.getCutPlanes()[0].sigInterpolationChanged] + + def leftClicked(self): + checked = self.checkState() == qt.Qt.Checked + self._setInterpolation('linear' if checked else 'nearest') + + def _pullData(self): + interpolation = self.subject.getCutPlanes()[0].getInterpolation() + self._setInterpolation(interpolation) + return interpolation[0].upper() + interpolation[1:] + + def _setInterpolation(self, interpolation): + self.subject.getCutPlanes()[0].setInterpolation(interpolation) + + +class PlaneColormapItem(ColormapBase): + """ + colormap name item. + Editor is a QComboBox + """ + editable = True + + listValues = ['gray', 'reversed gray', + 'temperature', 'red', + 'green', 'blue'] + + def getEditor(self, parent, option, index): + editor = qt.QComboBox(parent) + editor.addItems(self.listValues) + editor.currentIndexChanged[int].connect(self.__editorChanged) + + return editor + + def __editorChanged(self, index): + colorMapName = self.listValues[index] + colorMap = self.subject.getCutPlanes()[0].getColormap() + self.subject.getCutPlanes()[0].setColormap(name=colorMapName, + norm=colorMap.getNorm(), + vmin=colorMap.getVMin(), + vmax=colorMap.getVMax()) + + def setEditorData(self, editor): + colormapName = self.subject.getCutPlanes()[0].getColormap().getName() + index = self.listValues.index(colormapName) + editor.setCurrentIndex(index) + return True + + def _setModelData(self, editor): + self.__editorChanged(editor.currentIndex()) + return True + + def _pullData(self): + return self.subject.getCutPlanes()[0].getColormap().getName() + + +class PlaneAutoScaleItem(ColormapBase): + """ + colormap autoscale item. + Item is checkable. + """ + + def _init(self): + colorMap = self.subject.getCutPlanes()[0].getColormap() + self.setCheckable(True) + self.setCheckState((colorMap.isAutoscale() and qt.Qt.Checked) + or qt.Qt.Unchecked) + self.setData(self._pullData(), role=qt.Qt.DisplayRole, pushData=False) + + def leftClicked(self): + checked = (self.checkState() == qt.Qt.Checked) + self._setAutoScale(checked) + + def _setAutoScale(self, auto): + view3d = self.subject + cutPlane = view3d.getCutPlanes()[0] + colormap = cutPlane.getColormap() + + if auto != colormap.isAutoscale(): + if auto: + vMin = vMax = None + else: + dataRange = view3d.getDataRange() + if dataRange is None or None in dataRange: + vMin = vMax = None + else: + vMin, vMax = dataRange + cutPlane.setColormap(colormap.getName(), + colormap.getNorm(), + vMin, + vMax) + + def _pullData(self): + auto = self.subject.getCutPlanes()[0].getColormap().isAutoscale() + self._setAutoScale(auto) + if auto: + data = 'Auto' + else: + data = 'User' + return data + + +class NormalizationNode(ColormapBase): + """ + colormap normalization item. + Item is a QComboBox. + """ + editable = True + listValues = ['linear', 'log'] + + def getEditor(self, parent, option, index): + editor = qt.QComboBox(parent) + editor.addItems(self.listValues) + editor.currentIndexChanged[int].connect(self.__editorChanged) + + return editor + + def __editorChanged(self, index): + colorMap = self.subject.getCutPlanes()[0].getColormap() + normalization = self.listValues[index] + self.subject.getCutPlanes()[0].setColormap(name=colorMap.getName(), + norm=normalization, + vmin=colorMap.getVMin(), + vmax=colorMap.getVMax()) + + def setEditorData(self, editor): + normalization = self.subject.getCutPlanes()[0].getColormap().getNorm() + index = self.listValues.index(normalization) + editor.setCurrentIndex(index) + return True + + def _setModelData(self, editor): + self.__editorChanged(editor.currentIndex()) + return True + + def _pullData(self): + return self.subject.getCutPlanes()[0].getColormap().getNorm() + + +class PlaneGroup(SubjectItem): + """ + Root Item for the plane items. + """ + def _init(self): + valueItem = qt.QStandardItem() + valueItem.setEditable(False) + nameItem = PlaneVisibleItem(self.subject, 'Visible') + self.appendRow([nameItem, valueItem]) + + nameItem = qt.QStandardItem('Colormap') + nameItem.setEditable(False) + valueItem = PlaneColormapItem(self.subject) + self.appendRow([nameItem, valueItem]) + + nameItem = qt.QStandardItem('Normalization') + nameItem.setEditable(False) + valueItem = NormalizationNode(self.subject) + self.appendRow([nameItem, valueItem]) + + nameItem = qt.QStandardItem('Orientation') + nameItem.setEditable(False) + valueItem = PlaneOrientationItem(self.subject) + self.appendRow([nameItem, valueItem]) + + nameItem = qt.QStandardItem('Interpolation') + nameItem.setEditable(False) + valueItem = PlaneInterpolationItem(self.subject) + self.appendRow([nameItem, valueItem]) + + nameItem = qt.QStandardItem('Autoscale') + nameItem.setEditable(False) + valueItem = PlaneAutoScaleItem(self.subject) + self.appendRow([nameItem, valueItem]) + + nameItem = qt.QStandardItem('Min') + nameItem.setEditable(False) + valueItem = PlaneMinRangeItem(self.subject) + self.appendRow([nameItem, valueItem]) + + nameItem = qt.QStandardItem('Max') + nameItem.setEditable(False) + valueItem = PlaneMaxRangeItem(self.subject) + self.appendRow([nameItem, valueItem]) + + +class PlaneVisibleItem(SubjectItem): + """ + Plane visibility item. + Item is checkable. + """ + def _init(self): + plane = self.subject.getCutPlanes()[0] + self.setCheckable(True) + self.setCheckState((plane.isVisible() and qt.Qt.Checked) + or qt.Qt.Unchecked) + + def leftClicked(self): + plane = self.subject.getCutPlanes()[0] + checked = (self.checkState() == qt.Qt.Checked) + if checked != plane.isVisible(): + plane.setVisible(checked) + if plane.isVisible(): + plane.moveToCenter() + + +# Tree ######################################################################## + +class ItemDelegate(qt.QStyledItemDelegate): + """ + Delegate for the QTreeView filled with SubjectItems. + """ + + sigDelegateEvent = qt.Signal(str) + + def __init__(self, parent=None): + super(ItemDelegate, self).__init__(parent) + + def createEditor(self, parent, option, index): + item = index.model().itemFromIndex(index) + if item: + if isinstance(item, SubjectItem): + editor = item.getEditor(parent, option, index) + if editor: + editor.setAutoFillBackground(True) + if hasattr(editor, 'sigViewTask'): + editor.sigViewTask.connect(self.__viewTask) + return editor + + editor = super(ItemDelegate, self).createEditor(parent, + option, + index) + return editor + + def updateEditorGeometry(self, editor, option, index): + editor.setGeometry(option.rect) + + def setEditorData(self, editor, index): + item = index.model().itemFromIndex(index) + if item: + if isinstance(item, SubjectItem) and item.setEditorData(editor): + return + super(ItemDelegate, self).setEditorData(editor, index) + + def setModelData(self, editor, model, index): + item = index.model().itemFromIndex(index) + if isinstance(item, SubjectItem) and item._setModelData(editor): + return + super(ItemDelegate, self).setModelData(editor, model, index) + + def __viewTask(self, task): + self.sigDelegateEvent.emit(task) + + +class TreeView(qt.QTreeView): + """ + TreeView displaying the SubjectItems for the ScalarFieldView. + """ + + def __init__(self, parent=None): + super(TreeView, self).__init__(parent) + self.__openedIndex = None + + self.setIconSize(qt.QSize(16, 16)) + + header = self.header() + if hasattr(header, 'setSectionResizeMode'): # Qt5 + header.setSectionResizeMode(qt.QHeaderView.ResizeToContents) + else: # Qt4 + header.setResizeMode(qt.QHeaderView.ResizeToContents) + + delegate = ItemDelegate() + self.setItemDelegate(delegate) + delegate.sigDelegateEvent.connect(self.__delegateEvent) + self.setSelectionBehavior(qt.QAbstractItemView.SelectRows) + self.setSelectionMode(qt.QAbstractItemView.SingleSelection) + + self.clicked.connect(self.__clicked) + + def setSfView(self, sfView): + """ + Sets the ScalarFieldView this view is controlling. + + :param sfView: A `ScalarFieldView` + """ + model = qt.QStandardItemModel() + model.setColumnCount(ModelColumns.ColumnMax) + model.setHorizontalHeaderLabels(['Name', 'Value']) + + item = qt.QStandardItem() + item.setEditable(False) + model.appendRow([ViewSettingsItem(sfView, 'Style'), item]) + + item = qt.QStandardItem() + item.setEditable(False) + model.appendRow([DataSetItem(sfView, 'Data'), item]) + + item = IsoSurfaceCount(sfView) + item.setEditable(False) + model.appendRow([IsoSurfaceGroup(sfView, 'Isosurfaces'), item]) + + item = qt.QStandardItem() + item.setEditable(False) + model.appendRow([PlaneGroup(sfView, 'Cutting Plane'), item]) + + self.setModel(model) + + def setModel(self, model): + """ + Reimplementation of the QTreeView.setModel method. It connects the + rowsRemoved signal and opens the persistent editors. + + :param qt.QStandardItemModel model: the model + """ + + prevModel = self.model() + if prevModel: + self.__openPersistentEditors(qt.QModelIndex(), False) + try: + prevModel.rowsRemoved.disconnect(self.rowsRemoved) + except TypeError: + pass + + super(TreeView, self).setModel(model) + model.rowsRemoved.connect(self.rowsRemoved) + self.__openPersistentEditors(qt.QModelIndex()) + + def __openPersistentEditors(self, parent=None, openEditor=True): + """ + Opens or closes the items persistent editors. + + :param qt.QModelIndex parent: starting index, or None if the whole tree + is to be considered. + :param bool openEditor: True to open the editors, False to close them. + """ + model = self.model() + + if not model: + return + + if not parent or not parent.isValid(): + parent = self.model().invisibleRootItem().index() + + if openEditor: + meth = self.openPersistentEditor + else: + meth = self.closePersistentEditor + + curParent = parent + children = [model.index(row, 0, curParent) + for row in range(model.rowCount(curParent))] + + columnCount = model.columnCount() + + while len(children) > 0: + curParent = children.pop(-1) + + children.extend([model.index(row, 0, curParent) + for row in range(model.rowCount(curParent))]) + + for colIdx in range(columnCount): + sibling = model.sibling(curParent.row(), + colIdx, + curParent) + item = model.itemFromIndex(sibling) + if isinstance(item, SubjectItem) and item.persistent: + meth(sibling) + + def rowsAboutToBeRemoved(self, parent, start, end): + """ + Reimplementation of the QTreeView.rowsAboutToBeRemoved. Closes all + persistent editors under parent. + + :param qt.QModelIndex parent: Parent index + :param int start: Start index from parent index (inclusive) + :param int end: End index from parent index (inclusive) + """ + self.__openPersistentEditors(parent, False) + super(TreeView, self).rowsAboutToBeRemoved(parent, start, end) + + def rowsRemoved(self, parent, start, end): + """ + Called when QTreeView.rowsRemoved is emitted. Opens all persistent + editors under parent. + + :param qt.QModelIndex parent: Parent index + :param int start: Start index from parent index (inclusive) + :param int end: End index from parent index (inclusive) + """ + super(TreeView, self).rowsRemoved(parent, start, end) + self.__openPersistentEditors(parent, True) + + def rowsInserted(self, parent, start, end): + """ + Reimplementation of the QTreeView.rowsInserted. Opens all persistent + editors under parent. + + :param qt.QModelIndex parent: Parent index + :param int start: Start index from parent index + :param int end: End index from parent index + """ + self.__openPersistentEditors(parent, False) + super(TreeView, self).rowsInserted(parent, start, end) + self.__openPersistentEditors(parent) + + def keyReleaseEvent(self, event): + """ + Reimplementation of the QTreeView.keyReleaseEvent. + At the moment only Key_Delete is handled. It calls the selected item's + queryRemove method, and deleted the item if needed. + + :param qt.QKeyEvent event: A key event + """ + + # TODO : better filtering + key = event.key() + modifiers = event.modifiers() + + if key == qt.Qt.Key_Delete and modifiers == qt.Qt.NoModifier: + self.__removeIsosurfaces() + + super(TreeView, self).keyReleaseEvent(event) + + def __removeIsosurfaces(self): + model = self.model() + selected = self.selectedIndexes() + items = [] + # WARNING : the selection mode is set to single, so we re not + # supposed to have more than one item here. + # Multiple selection deletion has not been tested. + # Watch out for index invalidation + for index in selected: + leftIndex = model.sibling(index.row(), 0, index) + leftItem = model.itemFromIndex(leftIndex) + if isinstance(leftItem, SubjectItem) and leftItem not in items: + items.append(leftItem) + + isos = [item for item in items if isinstance(item, IsoSurfaceRootItem)] + if isos: + for iso in isos: + if iso.queryRemove(self): + parentItem = iso.parent() + parentItem.removeRow(iso.row()) + else: + qt.QMessageBox.information( + self, + 'Remove isosurface', + 'Select an iso-surface to remove it') + + def __clicked(self, index): + """ + Called when the QTreeView.clicked signal is emitted. Calls the item's + leftClick method. + + :param qt.QIndex index: An index + """ + item = self.model().itemFromIndex(index) + if isinstance(item, SubjectItem): + item.leftClicked() + + def __delegateEvent(self, task): + if task == 'remove_iso': + self.__removeIsosurfaces() diff --git a/silx/gui/plot3d/ScalarFieldView.py b/silx/gui/plot3d/ScalarFieldView.py new file mode 100644 index 0000000..2eb54a3 --- /dev/null +++ b/silx/gui/plot3d/ScalarFieldView.py @@ -0,0 +1,1385 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides a window to view a 3D scalar field. + +It supports iso-surfaces, a cutting plane and the definition of +a region of interest. +""" + +from __future__ import absolute_import + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "10/01/2017" + +import re +import logging +import time +from collections import deque + +import numpy + +from silx.gui import qt +from silx.gui.plot.Colors import rgba + +from silx.math.marchingcubes import MarchingCubes + +from .scene import axes, cutplane, function, interaction, primitives, transform +from . import scene +from .Plot3DWindow import Plot3DWindow + + +_logger = logging.getLogger(__name__) + + +class _BoundedGroup(scene.Group): + """Group with data bounds""" + + _shape = None # To provide a default value without overriding __init__ + + @property + def shape(self): + """Data shape (depth, height, width) of this group or None""" + return self._shape + + @shape.setter + def shape(self, shape): + if shape is None: + self._shape = None + else: + depth, height, width = shape + self._shape = float(depth), float(height), float(width) + + @property + def size(self): + """Data size (width, height, depth) of this group or None""" + shape = self.shape + if shape is None: + return None + else: + return shape[2], shape[1], shape[0] + + @size.setter + def size(self, size): + if size is None: + self.shape = None + else: + self.shape = size[2], size[1], size[0] + + def _bounds(self, dataBounds=False): + if dataBounds and self.size is not None: + return numpy.array(((0., 0., 0.), self.size), + dtype=numpy.float32) + else: + return super(_BoundedGroup, self)._bounds(dataBounds) + + +class Isosurface(qt.QObject): + """Class representing an iso-surface + + :param parent: The View widget this iso-surface belongs to + """ + + sigLevelChanged = qt.Signal(float) + """Signal emitted when the iso-surface level has changed. + + This signal provides the new level value (might be nan). + """ + + sigColorChanged = qt.Signal() + """Signal emitted when the iso-surface color has changed""" + + sigVisibilityChanged = qt.Signal(bool) + """Signal emitted when the iso-surface visibility has changed. + + This signal provides the new visibility status. + """ + + def __init__(self, parent): + super(Isosurface, self).__init__(parent=parent) + self._level = float('nan') + self._autoLevelFunction = None + self._color = rgba('#FFD700FF') + self._data = None + self._group = scene.Group() + + def _setData(self, data, copy=True): + """Set the data set from which to build the iso-surface. + + :param numpy.ndarray data: The 3D dataset or None + :param bool copy: True to make a copy, False to use as is if possible + """ + if data is None: + self._data = None + else: + self._data = numpy.array(data, copy=copy, order='C') + + self._update() + + def _get3DPrimitive(self): + """Return the group containing the mesh of the iso-surface if any""" + return self._group + + def isVisible(self): + """Returns True if iso-surface is visible, else False""" + return self._group.visible + + def setVisible(self, visible): + """Set the visibility of the iso-surface in the view. + + :param bool visible: True to show the iso-surface, False to hide + """ + visible = bool(visible) + if visible != self._group.visible: + self._group.visible = visible + self.sigVisibilityChanged.emit(visible) + + def getLevel(self): + """Return the level of this iso-surface (float)""" + return self._level + + def setLevel(self, level): + """Set the value at which to build the iso-surface. + + Setting this value reset auto-level function + + :param float level: The value at which to build the iso-surface + """ + self._autoLevelFunction = None + level = float(level) + if level != self._level: + self._level = level + self._update() + self.sigLevelChanged.emit(level) + + def isAutoLevel(self): + """True if iso-level is rebuild for each data set.""" + return self.getAutoLevelFunction() is not None + + def getAutoLevelFunction(self): + """Return the function computing the iso-level (callable or None)""" + return self._autoLevelFunction + + def setAutoLevelFunction(self, autoLevel): + """Set the function used to compute the iso-level. + + WARNING: The function might get called in a thread. + + :param callable autoLevel: + A function taking a 3D numpy.ndarray of float32 and returning + a float used as iso-level. + Example: numpy.mean(data) + numpy.std(data) + """ + assert callable(autoLevel) + self._autoLevelFunction = autoLevel + self._update() + + def getColor(self): + """Return the color of this iso-surface (QColor)""" + return qt.QColor.fromRgbF(*self._color) + + def setColor(self, color): + """Set the color of the iso-surface + + :param color: RGBA color of the isosurface + :type color: QColor, str or array-like of 4 float in [0., 1.] + """ + color = rgba(color) + if color != self._color: + self._color = color + if len(self._group.children) != 0: + self._group.children[0].setAttribute('color', self._color) + self.sigColorChanged.emit() + + def _update(self): + """Update underlying mesh""" + self._group.children = [] + + if self._data is None: + if self.isAutoLevel(): + self._level = float('nan') + + else: + if self.isAutoLevel(): + st = time.time() + try: + level = float(self.getAutoLevelFunction()(self._data)) + + except Exception: + module = self.getAutoLevelFunction().__module__ + name = self.getAutoLevelFunction().__name__ + _logger.error( + "Error while executing iso level function %s.%s", + module, + name, + exc_info=True) + level = float('nan') + + else: + _logger.info( + 'Computed iso-level in %f s.', time.time() - st) + + if level != self._level: + self._level = level + self.sigLevelChanged.emit(level) + + if numpy.isnan(self._level): + return + + st = time.time() + vertices, normals, indices = MarchingCubes( + self._data, + isolevel=self._level) + _logger.info('Computed iso-surface in %f s.', time.time() - st) + + if len(vertices) == 0: + return + else: + mesh = primitives.Mesh3D(vertices, + colors=self._color, + normals=normals, + mode='triangles', + indices=indices) + self._group.children = [mesh] + + +class Colormap(object): + """Description of a colormap + + :param str name: Name of the colormap + :param str norm: Normalization: 'linear' (default) or 'log' + :param float vmin: + Lower bound of the colormap or None for autoscale (default) + :param float vmax: + Upper bounds of the colormap or None for autoscale (default) + """ + + def __init__(self, name, norm='linear', vmin=None, vmax=None): + assert name in function.Colormap.COLORMAPS + self._name = str(name) + + assert norm in ('linear', 'log') + self._norm = str(norm) + + self._vmin = float(vmin) if vmin is not None else None + self._vmax = float(vmax) if vmax is not None else None + + def isAutoscale(self): + """True if both min and max are in autoscale mode""" + return self._vmin is None or self._vmax is None + + def getName(self): + """Return the name of the colormap (str)""" + return self._name + + def getNorm(self): + """Return the normalization of the colormap (str)""" + return self._norm + + def getVMin(self): + """Return the lower bound of the colormap or None""" + return self._vmin + + def getVMax(self): + """Return the upper bounds of the colormap or None""" + return self._vmax + + +class SelectedRegion(object): + """Selection of a 3D region aligned with the axis. + + :param arrayRange: Range of the selection in the array + ((zmin, zmax), (ymin, ymax), (xmin, xmax)) + :param translation: Offset from array to data coordinates (ox, oy, oz) + :param scale: Scale from array to data coordinates (sx, sy, sz) + """ + + def __init__(self, arrayRange, + translation=(0., 0., 0.), + scale=(1., 1., 1.)): + self._arrayRange = numpy.array(arrayRange, copy=True, dtype=numpy.int) + assert self._arrayRange.shape == (3, 2) + assert numpy.all(self._arrayRange[:, 1] >= self._arrayRange[:, 0]) + self._translation = numpy.array(translation, dtype=numpy.float32) + assert self._translation.shape == (3,) + self._scale = numpy.array(scale, dtype=numpy.float32) + assert self._scale.shape == (3,) + + self._dataRange = (self._translation.reshape(3, -1) + + self._arrayRange[::-1] * self._scale.reshape(3, -1)) + + def getArrayRange(self): + """Returns array ranges of the selection: 3x2 array of int + + :return: A numpy array with ((zmin, zmax), (ymin, ymax), (xmin, xmax)) + :rtype: numpy.ndarray + """ + return self._arrayRange.copy() + + def getArraySlices(self): + """Slices corresponding to the selected range in the array + + :return: A numpy array with (zslice, yslice, zslice) + :rtype: numpy.ndarray + """ + return (slice(*self._arrayRange[0]), + slice(*self._arrayRange[1]), + slice(*self._arrayRange[2])) + + def getDataRange(self): + """Range in the data coordinates of the selection: 3x2 array of float + + :return: A numpy array with ((xmin, xmax), (ymin, ymax), (zmin, zmax)) + :rtype: numpy.ndarray + """ + return self._dataRange.copy() + + def getDataScale(self): + """Scale from array to data coordinates: (sx, sy, sz) + + :return: A numpy array with (sx, sy, sz) + :rtype: numpy.ndarray + """ + return self._scale.copy() + + def getDataTranslation(self): + """Offset from array to data coordinates: (ox, oy, oz) + + :return: A numpy array with (ox, oy, oz) + :rtype: numpy.ndarray + """ + return self._translation.copy() + + +class CutPlane(qt.QObject): + """Class representing a cutting plane + + :param ScalarFieldView sfView: Widget in which the cut plane is applied. + """ + + sigVisibilityChanged = qt.Signal(bool) + """Signal emitted when the cut visibility has changed. + + This signal provides the new visibility status. + """ + + sigDataChanged = qt.Signal() + """Signal emitted when the data this plane is cutting has changed.""" + + sigPlaneChanged = qt.Signal() + """Signal emitted when the cut plane has moved""" + + sigColormapChanged = qt.Signal(object) + """Signal emitted when the colormap has changed + + This signal provides the new colormap. + """ + + sigInterpolationChanged = qt.Signal(str) + """Signal emitted when the cut plane interpolation has changed + + This signal provides the new interpolation mode. + """ + + def __init__(self, sfView): + super(CutPlane, self).__init__(parent=sfView) + + self._colormap = Colormap( + name='gray', norm='linear', vmin=None, vmax=None) + + self._dataRange = None + self._positiveMin = None + + self._plane = cutplane.CutPlane(normal=(0, 1, 0)) + self._plane.alpha = 1. + self._plane.visible = self._visible = False + self._plane.addListener(self._planeChanged) + self._plane.plane.addListener(self._planePositionChanged) + + sfView.sigDataChanged.connect(self._sfViewDataChanged) + + def _get3DPrimitive(self): + """Return the cut plane scene node""" + return self._plane + + def _sfViewDataChanged(self): + """Handle data change in the ScalarFieldView this plane belongs to""" + self._plane.setData(self.sender().getData(), copy=False) + self._dataRange = self.sender().getDataRange() + self._positiveMin = None + self.sigDataChanged.emit() + + # Update colormap range when autoscale + if self.getColormap().isAutoscale(): + self._updateColormapRange() + + def _planeChanged(self, source, *args, **kwargs): + """Handle events from the plane primitive""" + # Using _visible for now, until scene as more info in events + if source.visible != self._visible: + self._visible = source.visible + self.sigVisibilityChanged.emit(source.visible) + + def _planePositionChanged(self, source, *args, **kwargs): + """Handle update of cut plane position and normal""" + if self._plane.visible: + self.sigPlaneChanged.emit() + + # Plane position + + def moveToCenter(self): + """Move cut plane to center of data set""" + self._plane.moveToCenter() + + def isValid(self): + """Returns whether the cut plane is defined or not (bool)""" + return self._plane.isValid + + def getNormal(self): + """Returns the normal of the plane (as a unit vector) + + :return: Normal (nx, ny, nz), vector is 0 if no plane is defined + :rtype: numpy.ndarray + """ + return self._plane.plane.normal + + def setNormal(self, normal): + """Set the normal of the plane + + :param normal: 3-tuple of float: nx, ny, nz + """ + self._plane.plane.normal = normal + + def getPoint(self): + """Returns a point on the plane + + :return: (x, y, z) + :rtype: numpy.ndarray + """ + return self._plane.plane.point + + def getParameters(self): + """Returns the plane equation parameters: a*x + b*y + c*z + d = 0 + + :return: Plane equation parameters: (a, b, c, d) + :rtype: numpy.ndarray + """ + return self._plane.plane.parameters + + # Visibility + + def isVisible(self): + """Returns True if the plane is visible, False otherwise""" + return self._plane.visible + + def setVisible(self, visible): + """Set the visibility of the plane + + :param bool visible: True to make plane visible + """ + self._plane.visible = visible + + # Border stroke + + def getStrokeColor(self): + """Returns the color of the plane border (QColor)""" + return qt.QColor.fromRgbF(*self._plane.color) + + def setStrokeColor(self, color): + """Set the color of the plane border. + + :param color: RGB color: name, #RRGGBB or RGB values + :type color: + QColor, str or array-like of 3 or 4 float in [0., 1.] or uint8 + """ + self._plane.color = rgba(color) + + # Data + + def getImageData(self): + """Returns the data and information corresponding to the cut plane. + + The returned data is not interpolated, + it is a slice of the 3D scalar field. + + Image data axes are so that plane normal is towards the point of view. + + :return: An object containing the 2D data slice and information + """ + return _CutPlaneImage(self) + + # Interpolation + + def getInterpolation(self): + """Returns the interpolation used to display to cut plane. + + :return: 'nearest' or 'linear' + :rtype: str + """ + return self._plane.interpolation + + def setInterpolation(self, interpolation): + """Set the interpolation used to display to cut plane + + The default interpolation is 'linear' + + :param str interpolation: 'nearest' or 'linear' + """ + if interpolation != self.getInterpolation(): + self._plane.interpolation = interpolation + self.sigInterpolationChanged.emit(interpolation) + + # Colormap + + # def getAlpha(self): + # """Returns the transparency of the plane as a float in [0., 1.]""" + # return self._plane.alpha + + # def setAlpha(self, alpha): + # """Set the plane transparency. + # + # :param float alpha: Transparency in [0., 1] + # """ + # self._plane.alpha = alpha + + def getColormap(self): + """Returns the colormap set by :meth:`getColormap`. + + :return: The colormap + :rtype: Colormap + """ + return self._colormap + + def setColormap(self, + name='gray', + norm='linear', + vmin=None, + vmax=None): + """Set the colormap to use. + + :param str name: Name of the colormap in + 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'. + :param str norm: Colormap mapping: 'linear' or 'log'. + :param float vmin: The minimum value of the range or None for autoscale + :param float vmax: The maximum value of the range or None for autoscale + """ + _logger.debug('setColormap %s %s (%s, %s)', + name, norm, str(vmin), str(vmax)) + + self._colormap = Colormap( + name=name, norm=norm, vmin=vmin, vmax=vmax) + + self._updateColormapRange() + self.sigColormapChanged.emit(self.getColormap()) + + def getColormapEffectiveRange(self): + """Returns the currently used range of the colormap. + + This range is computed from the data set if colormap is in autoscale. + Range is clipped to positive values when using log scale. + + :return: 2-tuple of float + """ + return self._plane.colormap.range_ + + def _updateColormapRange(self): + """Update the colormap range""" + colormap = self.getColormap() + + self._plane.colormap.name = colormap.getName() + if colormap.isAutoscale(): + range_ = self._dataRange + if range_ is None: # No data, use a default range + range_ = 1., 10. + else: + range_ = colormap.getVMin(), colormap.getVMax() + + if colormap.getNorm() == 'linear': + self._plane.colormap.norm = 'linear' + self._plane.colormap.range_ = range_ + + else: # Log + # Make sure range is strictly positive + if range_[0] <= 0.: + data = self._plane.getData(copy=False) + if data is not None: + if self._positiveMin is None: + # TODO compute this with the range as a combo operation + self._positiveMin = numpy.min(data[data > 0.]) + range_ = (self._positiveMin, + max(range_[1], self._positiveMin)) + + self._plane.colormap.range_ = range_ + self._plane.colormap.norm = colormap.getNorm() + + +class _CutPlaneImage(object): + """Object representing the data sliced by a cut plane + + :param CutPlane cutPlane: The CutPlane from which to generate image info + """ + + def __init__(self, cutPlane): + # Init attributes with default values + self._isValid = False + self._data = numpy.array([]) + self._xLabel = '' + self._yLabel = '' + self._normalLabel = '' + self._scale = 1., 1. + self._translation = 0., 0. + self._index = 0 + self._position = 0. + + sfView = cutPlane.parent() + if not sfView or not cutPlane.isValid(): + _logger.info("No plane available") + return + + data = sfView.getData(copy=False) + if data is None: + _logger.info("No data available") + return + + normal = cutPlane.getNormal() + point = numpy.array(cutPlane.getPoint(), dtype=numpy.int) + + if numpy.all(numpy.equal(normal, (1., 0., 0.))): + index = max(0, min(point[0], data.shape[2] - 1)) + slice_ = data[:, :, index] + xAxisIndex, yAxisIndex, normalAxisIndex = 1, 2, 0 # y, z, x + elif numpy.all(numpy.equal(normal, (0., 1., 0.))): + index = max(0, min(point[1], data.shape[1] - 1)) + slice_ = numpy.transpose(data[:, index, :]) + xAxisIndex, yAxisIndex, normalAxisIndex = 2, 0, 1 # z, x, y + elif numpy.all(numpy.equal(normal, (0., 0., 1.))): + index = max(0, min(point[2], data.shape[0] - 1)) + slice_ = data[index, :, :] + xAxisIndex, yAxisIndex, normalAxisIndex = 0, 1, 2 # x, y, z + else: + _logger.warning('Unsupported normal: (%f, %f, %f)', + normal[0], normal[1], normal[2]) + return + + # Store cut plane image info + + self._isValid = True + self._data = numpy.array(slice_, copy=True) + + labels = sfView.getAxesLabels() + scale = sfView.getScale() + translation = sfView.getTranslation() + + self._xLabel = labels[xAxisIndex] + self._yLabel = labels[yAxisIndex] + self._normalLabel = labels[normalAxisIndex] + + self._scale = scale[xAxisIndex], scale[yAxisIndex] + self._translation = translation[xAxisIndex], translation[yAxisIndex] + + self._index = index + self._position = float(index * scale[normalAxisIndex] + + translation[normalAxisIndex]) + + def isValid(self): + """Returns True if the cut plane image is defined (bool)""" + return self._isValid + + def getData(self, copy=True): + """Returns the image data sliced by the cut plane. + + :param bool copy: True to get a copy, False otherwise + :return: The 2D image data corresponding to the cut plane + :rtype: numpy.ndarray + """ + return numpy.array(self._data, copy=copy) + + def getXLabel(self): + """Returns the label associated to the X axis of the image (str)""" + return self._xLabel + + def getYLabel(self): + """Returns the label associated to the Y axis of the image (str)""" + return self._yLabel + + def getNormalLabel(self): + """Returns the label of the 3D axis of the plane normal (str)""" + return self._normalLabel + + def getScale(self): + """Returns the scales of the data as a 2-tuple of float (sx, sy)""" + return self._scale + + def getTranslation(self): + """Returns the offset of the data as a 2-tuple of float (ox, oy)""" + return self._translation + + def getIndex(self): + """Returns the index in the data array of the cut plane (int)""" + return self._index + + def getPosition(self): + """Returns the cut plane position along the normal axis (flaot)""" + return self._position + + +class ScalarFieldView(Plot3DWindow): + """Widget computing and displaying an iso-surface from a 3D scalar dataset. + + Limitation: Currently, iso-surfaces are generated with higher values + than the iso-level 'inside' the surface. + + :param parent: See :class:`QMainWindow` + """ + + sigDataChanged = qt.Signal() + """Signal emitted when the scalar data field has changed.""" + + sigSelectedRegionChanged = qt.Signal(object) + """Signal emitted when the selected region has changed. + + This signal provides the new selected region. + """ + + def __init__(self, parent=None): + super(ScalarFieldView, self).__init__(parent) + self._colormap = Colormap( + name='gray', norm='linear', vmin=None, vmax=None) + self._selectedRange = None + + # Store iso-surfaces + self._isosurfaces = [] + + # Transformations + self._dataScale = transform.Scale() + self._dataTranslate = transform.Translate() + + self._foregroundColor = 1., 1., 1., 1. + self._highlightColor = 0.7, 0.7, 0., 1. + + self._data = None + self._dataRange = None + + self._group = _BoundedGroup() + self._group.transforms = [self._dataTranslate, self._dataScale] + + self._selectionBox = primitives.Box() + self._selectionBox.strokeSmooth = False + self._selectionBox.strokeWidth = 1. + # self._selectionBox.fillColor = 1., 1., 1., 0.3 + # self._selectionBox.fillCulling = 'back' + self._selectionBox.visible = False + self._group.children.append(self._selectionBox) + + self._cutPlane = CutPlane(sfView=self) + self._cutPlane.sigVisibilityChanged.connect( + self._planeVisibilityChanged) + self._group.children.append(self._cutPlane._get3DPrimitive()) + + self._isogroup = primitives.GroupDepthOffset() + self._isogroup.transforms = [ + # Convert from z, y, x from marching cubes to x, y, z + transform.Matrix(( + (0., 0., 1., 0.), + (0., 1., 0., 0.), + (1., 0., 0., 0.), + (0., 0., 0., 1.))), + # Offset to match cutting plane coords + transform.Translate(0.5, 0.5, 0.5) + ] + self._group.children.append(self._isogroup) + + self._bbox = axes.LabelledAxes() + self._bbox.children = [self._group] + self.getPlot3DWidget().viewport.scene.children.append(self._bbox) + + self._initInteractionToolBar() + + self._updateColors() + + self.getPlot3DWidget().viewport.light.shininess = 32 + + def saveConfig(self, ioDevice): + """ + Saves this view state. Only isosurfaces at the moment. Does not save + the isosurface's function. + + :param qt.QIODevice ioDevice: A `qt.QIODevice`. + """ + + stream = qt.QDataStream(ioDevice) + + stream.writeString('<ScalarFieldView>') + + isoSurfaces = self.getIsosurfaces() + + nIsoSurfaces = len(isoSurfaces) + + # TODO : delegate the serialization to the serialized items + # isosurfaces + if nIsoSurfaces: + tagIn = '<IsoSurfaces nIso={0}>'.format(nIsoSurfaces) + stream.writeString(tagIn) + + for surface in isoSurfaces: + color = surface.getColor() + level = surface.getLevel() + visible = surface.isVisible() + stream << color + stream.writeDouble(level) + stream.writeBool(visible) + + stream.writeString('</IsoSurfaces>') + + stream.writeString('<Style>') + background = self.getBackgroundColor() + foreground = self.getForegroundColor() + highlight = self.getHighlightColor() + stream << background << foreground << highlight + stream.writeString('</Style>') + + stream.writeString('</ScalarFieldView>') + + def loadConfig(self, ioDevice): + """ + Loads this view state. + See ScalarFieldView.saveView to know what is supported at the moment. + + :param qt.QIODevice ioDevice: A `qt.QIODevice`. + """ + + tagStack = deque() + + tagInRegex = re.compile('<(?P<itemId>[^ /]*) *' + '(?P<args>.*)>') + + tagOutRegex = re.compile('</(?P<itemId>[^ ]*)>') + + tagRootInRegex = re.compile('<ScalarFieldView>') + + isoSurfaceArgsRegex = re.compile('nIso=(?P<nIso>[0-9]*)') + + stream = qt.QDataStream(ioDevice) + + tag = stream.readString() + tagMatch = tagRootInRegex.match(tag) + + if tagMatch is None: + # TODO : explicit error + raise ValueError('Unknown data.') + + itemId = 'ScalarFieldView' + + tagStack.append(itemId) + + while True: + + tag = stream.readString() + + tagMatch = tagOutRegex.match(tag) + if tagMatch: + closeId = tagMatch.groupdict()['itemId'] + if closeId != itemId: + # TODO : explicit error + raise ValueError('Unexpected closing tag {0} ' + '(expected {1})' + ''.format(closeId, itemId)) + + if itemId == 'ScalarFieldView': + # reached end + break + else: + itemId = tagStack.pop() + # fetching next tag + continue + + tagMatch = tagInRegex.match(tag) + + if tagMatch is None: + # TODO : explicit error + raise ValueError('Unknown data.') + + tagStack.append(itemId) + + matchDict = tagMatch.groupdict() + + itemId = matchDict['itemId'] + + # TODO : delegate the deserialization to the serialized items + if itemId == 'IsoSurfaces': + argsMatch = isoSurfaceArgsRegex.match(matchDict['args']) + if not argsMatch: + # TODO : explicit error + raise ValueError('Failed to parse args "{0}".' + ''.format(matchDict['args'])) + argsDict = argsMatch.groupdict() + nIso = int(argsDict['nIso']) + if nIso: + for surface in self.getIsosurfaces(): + self.removeIsosurface(surface) + for isoIdx in range(nIso): + color = qt.QColor() + stream >> color + level = stream.readDouble() + visible = stream.readBool() + surface = self.addIsosurface(level, color=color) + surface.setVisible(visible) + elif itemId == 'Style': + background = qt.QColor() + foreground = qt.QColor() + highlight = qt.QColor() + stream >> background >> foreground >> highlight + self.setBackgroundColor(background) + self.setForegroundColor(foreground) + self.setHighlightColor(highlight) + else: + raise ValueError('Unknown entry tag {0}.' + ''.format(itemId)) + + def _initInteractionToolBar(self): + self._interactionToolbar = qt.QToolBar() + self._interactionToolbar.setEnabled(False) + + group = qt.QActionGroup(self._interactionToolbar) + group.setExclusive(True) + + self._cameraAction = qt.QAction(None) + self._cameraAction.setText('camera') + self._cameraAction.setCheckable(True) + self._cameraAction.setToolTip('Control camera') + self._cameraAction.setChecked(True) + group.addAction(self._cameraAction) + + self._planeAction = qt.QAction(None) + self._planeAction.setText('plane') + self._planeAction.setCheckable(True) + self._planeAction.setToolTip('Control cutting plane') + group.addAction(self._planeAction) + group.triggered.connect(self._interactionChanged) + + self._interactionToolbar.addActions(group.actions()) + self.addToolBar(self._interactionToolbar) + + def _planeVisibilityChanged(self, visible): + """Handle visibility events from the plane""" + if visible != self._interactionToolbar.isEnabled(): + if visible: + self._interactionToolbar.setEnabled(True) + self.setInteractiveMode('plane') + else: + self._interactionToolbar.setEnabled(False) + self.setInteractiveMode('camera') + + def _interactionChanged(self, action): + self.setInteractiveMode(action.text()) + + def setInteractiveMode(self, mode): + """Choose the current interaction. + + :param str mode: Either plane or camera + """ + if mode == self.getInteractiveMode(): + return + + sceneScale = self.getPlot3DWidget().viewport.scene.transforms[0] + if mode == 'plane': + self.getPlot3DWidget().eventHandler = \ + interaction.PanPlaneRotateCameraControl( + self.getPlot3DWidget().viewport, + self._cutPlane._get3DPrimitive(), + mode='position', + scaleTransform=sceneScale) + self._planeAction.setChecked(True) + elif mode == 'camera': + self.getPlot3DWidget().eventHandler = interaction.CameraControl( + self.getPlot3DWidget().viewport, orbitAroundCenter=False, + mode='position', scaleTransform=sceneScale, + selectCB=None) + self._cameraAction.setChecked(True) + else: + raise ValueError('Unsupported interactive mode %s', str(mode)) + self._updateColors() + + def getInteractiveMode(self): + """Returns the current interaction mode, see :meth:`setInteractiveMode` + """ + if isinstance(self.getPlot3DWidget().eventHandler, + interaction.PanPlaneRotateCameraControl): + return 'plane' + elif isinstance(self.getPlot3DWidget().eventHandler, + interaction.CameraControl): + return 'camera' + else: + raise RuntimeError('Unknown interactive mode') + + # Handle scalar field + + def setData(self, data, copy=True): + """Set the 3D scalar data set to use for building the iso-surface. + + Dataset order is zyx (i.e., first dimension is z). + + :param data: scalar field from which to extract the iso-surface + :type data: 3D numpy.ndarray of float32 with shape at least (2, 2, 2) + :param bool copy: + True (default) to make a copy, + False to avoid copy (DO NOT MODIFY data afterwards) + """ + if data is None: + self._data = None + self._dataRange = None + self.setSelectedRegion(zrange=None, yrange=None, xrange_=None) + self._group.shape = None + self.centerScene() + + else: + data = numpy.array(data, copy=copy, dtype=numpy.float32, order='C') + assert data.ndim == 3 + assert min(data.shape) >= 2 + + wasData = self._data is not None + previousSelectedRegion = self.getSelectedRegion() + + self._data = data + self._dataRange = self._data.min(), self._data.max() + + if previousSelectedRegion is not None: + # Update selected region to ensure it is clipped to array range + self.setSelectedRegion(*previousSelectedRegion.getArrayRange()) + + self._group.shape = self._data.shape + + if not wasData: + self.centerScene() # Reset viewpoint the first time only + + # Update iso-surfaces + for isosurface in self.getIsosurfaces(): + isosurface._setData(self._data, copy=False) + + self.sigDataChanged.emit() + + def getData(self, copy=True): + """Get the 3D scalar data currently used to build the iso-surface. + + :param bool copy: + True (default) to get a copy, + False to get the internal data (DO NOT modify!) + :return: The data set (or None if not set) + """ + if self._data is None: + return None + else: + return numpy.array(self._data, copy=copy) + + def getDataRange(self): + """Return the range of the data as a 2-tuple (min, max)""" + return self._dataRange + + # Transformations + + def setScale(self, sx=1., sy=1., sz=1.): + """Set the scale of the 3D scalar field (i.e., size of a voxel). + + :param float sx: Scale factor along the X axis + :param float sy: Scale factor along the Y axis + :param float sz: Scale factor along the Z axis + """ + scale = numpy.array((sx, sy, sz), dtype=numpy.float32) + if not numpy.all(numpy.equal(scale, self.getScale())): + self._dataScale.scale = scale + self.centerScene() # Reset viewpoint + + def getScale(self): + """Returns the scales provided by :meth:`setScale` as a numpy.ndarray. + """ + return self._dataScale.scale + + def setTranslation(self, x=0., y=0., z=0.): + """Set the translation of the origin of the data array in data coordinates. + + :param float x: Offset of the data origin on the X axis + :param float y: Offset of the data origin on the Y axis + :param float z: Offset of the data origin on the Z axis + """ + translation = numpy.array((x, y, z), dtype=numpy.float32) + if not numpy.all(numpy.equal(translation, self.getTranslation())): + self._dataTranslate.translation = translation + self.centerScene() # Reset viewpoint + + def getTranslation(self): + """Returns the offset set by :meth:`setTranslation` as a numpy.ndarray. + """ + return self._dataTranslate.translation + + # Axes labels + + def setAxesLabels(self, xlabel=None, ylabel=None, zlabel=None): + """Set the text labels of the axes. + + :param str xlabel: Label of the X axis, None to leave unchanged. + :param str ylabel: Label of the Y axis, None to leave unchanged. + :param str zlabel: Label of the Z axis, None to leave unchanged. + """ + if xlabel is not None: + self._bbox.xlabel = xlabel + + if ylabel is not None: + self._bbox.ylabel = ylabel + + if zlabel is not None: + self._bbox.zlabel = zlabel + + class _Labels(tuple): + """Return type of :meth:`getAxesLabels`""" + + def getXLabel(self): + """Label of the X axis (str)""" + return self[0] + + def getYLabel(self): + """Label of the Y axis (str)""" + return self[1] + + def getZLabel(self): + """Label of the Z axis (str)""" + return self[2] + + def getAxesLabels(self): + """Returns the text labels of the axes + + >>> widget = ScalarFieldView() + >>> widget.setAxesLabels(xlabel='X') + + You can get the labels either as a 3-tuple: + + >>> xlabel, ylabel, zlabel = widget.getAxesLabels() + + Or as an object with methods getXLabel, getYLabel and getZLabel: + + >>> labels = widget.getAxesLabels() + >>> labels.getXLabel() + ... 'X' + + :return: object describing the labels + """ + return self._Labels((self._bbox.xlabel, + self._bbox.ylabel, + self._bbox.zlabel)) + + # Colors + + def _updateColors(self): + """Update item depending on foreground/highlight color""" + self._bbox.tickColor = self._foregroundColor + self._selectionBox.strokeColor = self._foregroundColor + if self.getInteractiveMode() == 'plane': + self._cutPlane.setStrokeColor(self._highlightColor) + self._bbox.color = self._foregroundColor + else: + self._cutPlane.setStrokeColor(self._foregroundColor) + self._bbox.color = self._highlightColor + + def getForegroundColor(self): + """Return color used for text and bounding box (QColor)""" + return qt.QColor.fromRgbF(*self._foregroundColor) + + def setForegroundColor(self, color): + """Set the foreground color. + + :param color: RGB color: name, #RRGGBB or RGB values + :type color: + QColor, str or array-like of 3 or 4 float in [0., 1.] or uint8 + """ + color = rgba(color) + if color != self._foregroundColor: + self._foregroundColor = color + self._updateColors() + + def getHighlightColor(self): + """Return color used for highlighted item bounding box (QColor)""" + return qt.QColor.fromRgbF(*self._highlightColor) + + def setHighlightColor(self, color): + """Set hightlighted item color. + + :param color: RGB color: name, #RRGGBB or RGB values + :type color: + QColor, str or array-like of 3 or 4 float in [0., 1.] or uint8 + """ + color = rgba(color) + if color != self._highlightColor: + self._highlightColor = color + self._updateColors() + + # Cut Plane + + def getCutPlanes(self): + """Return an iterable of all cut planes of the view. + + This includes hidden cut planes. + + For now, there is always one cut plane. + """ + return (self._cutPlane,) + + # Selection + + def setSelectedRegion(self, zrange=None, yrange=None, xrange_=None): + """Set the 3D selected region aligned with the axes. + + Provided range are array indices range. + The provided ranges are clipped to the data. + If a range is None, the range of the array on this dimension is used. + + :param zrange: (zmin, zmax) range of the selection + :param yrange: (ymin, ymax) range of the selection + :param xrange_: (xmin, xmax) range of the selection + """ + # No range given: unset selection + if zrange is None and yrange is None and xrange_ is None: + selectedRange = None + + else: + # Handle default ranges + if self._data is not None: + if zrange is None: + zrange = 0, self._data.shape[0] + if yrange is None: + yrange = 0, self._data.shape[1] + if xrange_ is None: + xrange_ = 0, self._data.shape[2] + + elif None in (xrange_, yrange, zrange): + # One of the range is None and no data available + raise RuntimeError( + 'Data is not set, cannot get default range from it.') + + # Clip selected region to data shape and make sure min <= max + selectedRange = numpy.array(( + (max(0, min(*zrange)), + min(self._data.shape[0], max(*zrange))), + (max(0, min(*yrange)), + min(self._data.shape[1], max(*yrange))), + (max(0, min(*xrange_)), + min(self._data.shape[2], max(*xrange_))), + ), dtype=numpy.int) + + # numpy.equal supports None + if not numpy.all(numpy.equal(selectedRange, self._selectedRange)): + self._selectedRange = selectedRange + + # Update scene accordingly + if self._selectedRange is None: + self._selectionBox.visible = False + else: + self._selectionBox.visible = True + scales = self._selectedRange[:, 1] - self._selectedRange[:, 0] + self._selectionBox.size = scales[::-1] + self._selectionBox.transforms = [ + transform.Translate(*self._selectedRange[::-1, 0])] + + self.sigSelectedRegionChanged.emit(self.getSelectedRegion()) + + def getSelectedRegion(self): + """Returns the currently selected region or None.""" + if self._selectedRange is None: + return None + else: + return SelectedRegion(self._selectedRange, + translation=self.getTranslation(), + scale=self.getScale()) + + # Handle iso-surfaces + + sigIsosurfaceAdded = qt.Signal(object) + """Signal emitted when a new iso-surface is added to the view. + + The newly added iso-surface is provided by this signal + """ + + sigIsosurfaceRemoved = qt.Signal(object) + """Signal emitted when an iso-surface is removed from the view + + The removed iso-surface is provided by this signal. + """ + + def addIsosurface(self, level, color): + """Add an iso-surface to the view. + + :param level: + The value at which to build the iso-surface or a callable + (e.g., a function) taking a 3D numpy.ndarray as input and + returning a float. + Example: numpy.mean(data) + numpy.std(data) + :type level: float or callable + :param color: RGBA color of the isosurface + :type color: str or array-like of 4 float in [0., 1.] + :return: Isosurface object describing this isosurface + """ + isosurface = Isosurface(parent=self) + isosurface.setColor(color) + if callable(level): + isosurface.setAutoLevelFunction(level) + else: + isosurface.setLevel(level) + isosurface._setData(self._data, copy=False) + isosurface.sigLevelChanged.connect(self._updateIsosurfaces) + + self._isosurfaces.append(isosurface) + + self._updateIsosurfaces() + + self.sigIsosurfaceAdded.emit(isosurface) + return isosurface + + def getIsosurfaces(self): + """Return an iterable of all iso-surfaces of the view""" + return tuple(self._isosurfaces) + + def removeIsosurface(self, isosurface): + """Remove an iso-surface from the view. + + :param isosurface: The isosurface object to remove""" + if isosurface not in self.getIsosurfaces(): + _logger.warning( + "Try to remove isosurface that is not in the list: %s", + str(isosurface)) + else: + isosurface.sigLevelChanged.disconnect(self._updateIsosurfaces) + self._isosurfaces.remove(isosurface) + self._updateIsosurfaces() + self.sigIsosurfaceRemoved.emit(isosurface) + + def clearIsosurfaces(self): + """Remove all iso-surfaces from the view.""" + for isosurface in self.getIsosurfaces(): + self.removeIsosurface(isosurface) + + def _updateIsosurfaces(self, level=None): + """Handle updates of iso-surfaces level and add/remove""" + # Sorting using minus, this supposes data 'object' to be max values + sortedIso = sorted(self.getIsosurfaces(), + key=lambda iso: - iso.getLevel()) + self._isogroup.children = [iso._get3DPrimitive() for iso in sortedIso] diff --git a/silx/gui/plot3d/ViewpointToolBar.py b/silx/gui/plot3d/ViewpointToolBar.py new file mode 100644 index 0000000..d062c1b --- /dev/null +++ b/silx/gui/plot3d/ViewpointToolBar.py @@ -0,0 +1,114 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides a toolbar to control Plot3DWidget viewpoint.""" + +from __future__ import absolute_import + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "15/09/2016" + + +from silx.gui import qt +from silx.gui.icons import getQIcon + + +class ViewpointActionGroup(qt.QActionGroup): + """ActionGroup of actions to reset the viewpoint. + + As for QActionGroup, add group's actions to the widget with: + `widget.addActions(actionGroup.actions())` + + :param Plot3DWidget plot3D: The widget for which to control the viewpoint + :param parent: See :class:`QActionGroup` + """ + + # Action information: icon name, text, tooltip + _RESET_CAMERA_ACTIONS = ( + ('cube-front', 'Front', 'View along the -Z axis'), + ('cube-back', 'Back', 'View along the +Z axis'), + ('cube-top', 'Top', 'View along the -Y'), + ('cube-bottom', 'Bottom', 'View along the +Y'), + ('cube-right', 'Right', 'View along the -X'), + ('cube-left', 'Left', 'View along the +X'), + ('cube', 'Side', 'Side view') + ) + + def __init__(self, plot3D, parent=None): + super(ViewpointActionGroup, self).__init__(parent) + self.setExclusive(False) + + self._plot3D = plot3D + + for actionInfo in self._RESET_CAMERA_ACTIONS: + iconname, text, tooltip = actionInfo + + action = qt.QAction(getQIcon(iconname), text, None) + action.setCheckable(False) + action.setToolTip(tooltip) + self.addAction(action) + + self.triggered[qt.QAction].connect(self._actionGroupTriggered) + + def _actionGroupTriggered(self, action): + actionname = action.text().lower() + + self._plot3D.viewport.camera.extrinsic.reset(face=actionname) + self._plot3D.centerScene() + + +class ViewpointToolBar(qt.QToolBar): + """A toolbar providing icons to reset the viewpoint. + + :param parent: See :class:`QToolBar` + :param Plot3DWidget plot3D: The widget to control + :param str title: Title of the toolbar + """ + + def __init__(self, parent=None, plot3D=None, title='Viewpoint control'): + super(ViewpointToolBar, self).__init__(title, parent) + + self._actionGroup = ViewpointActionGroup(plot3D) + assert plot3D is not None + self._plot3D = plot3D + self.addActions(self._actionGroup.actions()) + + # Choosing projection disabled for now + # Add projection combo box + # comboBoxProjection = qt.QComboBox() + # comboBoxProjection.addItem('Perspective') + # comboBoxProjection.addItem('Parallel') + # comboBoxProjection.setToolTip( + # 'Choose the projection:' + # ' perspective or parallel (i.e., orthographic)') + # comboBoxProjection.currentIndexChanged[(str)].connect( + # self._comboBoxProjectionCurrentIndexChanged) + # self.addWidget(qt.QLabel('Projection:')) + # self.addWidget(comboBoxProjection) + + # def _comboBoxProjectionCurrentIndexChanged(self, text): + # """Projection combo box listener""" + # self._plot3D.setProjection( + # 'perspective' if text == 'Perspective' else 'orthographic') diff --git a/silx/gui/plot3d/__init__.py b/silx/gui/plot3d/__init__.py new file mode 100644 index 0000000..ad45424 --- /dev/null +++ b/silx/gui/plot3d/__init__.py @@ -0,0 +1,45 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +""" +This package provides widgets displaying 3D content based on OpenGL. + +It depends on PyOpenGL and QtOpenGL. +""" +from __future__ import absolute_import + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "18/01/2017" + + +from .. import qt as _qt + +if not _qt.HAS_OPENGL: + raise ImportError('Qt.QtOpenGL is not available') + +try: + import OpenGL as _OpenGL +except ImportError: + raise ImportError('PyOpenGL is not installed') diff --git a/silx/gui/plot3d/scene/__init__.py b/silx/gui/plot3d/scene/__init__.py new file mode 100644 index 0000000..25a7171 --- /dev/null +++ b/silx/gui/plot3d/scene/__init__.py @@ -0,0 +1,34 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides a 3D graphics scene graph structure.""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "08/11/2016" + + +from .core import Elem, Group, PrivateGroup # noqa +from .viewport import Viewport # noqa +from .window import Window # noqa diff --git a/silx/gui/plot3d/scene/axes.py b/silx/gui/plot3d/scene/axes.py new file mode 100644 index 0000000..528e4f7 --- /dev/null +++ b/silx/gui/plot3d/scene/axes.py @@ -0,0 +1,224 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Primitive displaying a text field in the scene.""" + +from __future__ import absolute_import, division, unicode_literals + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "17/10/2016" + + +import logging +import numpy + +from ...plot._utils import ticklayout + +from . import core, primitives, text, transform + + +_logger = logging.getLogger(__name__) + + +class LabelledAxes(primitives.GroupBBox): + """A group displaying a bounding box with axes labels around its children. + """ + + def __init__(self): + super(LabelledAxes, self).__init__() + self._ticksForBounds = None + + self._font = text.Font() + + # TODO offset labels from anchor in pixels + + self._xlabel = text.Text2D(font=self._font) + self._xlabel.align = 'center' + self._xlabel.transforms = [self._boxTransforms, + transform.Translate(tx=0.5)] + self._children.insert(-1, self._xlabel) + + self._ylabel = text.Text2D(font=self._font) + self._ylabel.align = 'center' + self._ylabel.transforms = [self._boxTransforms, + transform.Translate(ty=0.5)] + self._children.insert(-1, self._ylabel) + + self._zlabel = text.Text2D(font=self._font) + self._zlabel.align = 'center' + self._zlabel.transforms = [self._boxTransforms, + transform.Translate(tz=0.5)] + self._children.insert(-1, self._zlabel) + + # Init tick lines with dummy pos + self._tickLines = primitives.DashedLines( + positions=((0., 0., 0.), (0., 0., 0.))) + self._tickLines.dash = 5, 10 + self._tickLines.visible = False + self._children.insert(-1, self._tickLines) + + self._tickLabels = core.Group() + self._children.insert(-1, self._tickLabels) + + # Sync color + self.tickColor = 1., 1., 1., 1. + + @property + def tickColor(self): + """Color of ticks and text labels. + + This does NOT set bounding box color. + Use :attr:`color` for the bounding box. + """ + return self._xlabel.foreground + + @tickColor.setter + def tickColor(self, color): + self._xlabel.foreground = color + self._ylabel.foreground = color + self._zlabel.foreground = color + transparentColor = color[0], color[1], color[2], color[3] * 0.6 + self._tickLines.setAttribute('color', transparentColor) + for label in self._tickLabels.children: + label.foreground = color + + @property + def font(self): + """Font of axes text labels (Font)""" + return self._font + + @font.setter + def font(self, font): + self._font = font + self._xlabel.font = font + self._ylabel.font = font + self._zlabel.font = font + for label in self._tickLabels.children: + label.font = font + + @property + def xlabel(self): + """Text label of the X axis (str)""" + return self._xlabel.text + + @xlabel.setter + def xlabel(self, text): + self._xlabel.text = text + + @property + def ylabel(self): + """Text label of the Y axis (str)""" + return self._ylabel.text + + @ylabel.setter + def ylabel(self, text): + self._ylabel.text = text + + @property + def zlabel(self): + """Text label of the Z axis (str)""" + return self._zlabel.text + + @zlabel.setter + def zlabel(self, text): + self._zlabel.text = text + + def _updateTicks(self): + """Check if ticks need update and update them if needed.""" + bounds = self._group.bounds(transformed=False, dataBounds=True) + if bounds is None: # No content + if self._ticksForBounds is not None: + self._ticksForBounds = None + self._tickLines.visible = False + self._tickLabels.children = [] # Reset previous labels + + elif (self._ticksForBounds is None or + not numpy.all(numpy.equal(bounds, self._ticksForBounds))): + self._ticksForBounds = bounds + + # Update ticks + ticklength = numpy.abs(bounds[1] - bounds[0]) + + xticks, xlabels = ticklayout.ticks(*bounds[:, 0]) + yticks, ylabels = ticklayout.ticks(*bounds[:, 1]) + zticks, zlabels = ticklayout.ticks(*bounds[:, 2]) + + # Update tick lines + coords = numpy.empty( + ((len(xticks) + len(yticks) + len(zticks)), 4, 3), + dtype=numpy.float32) + coords[:, :, :] = bounds[0, :] # account for offset from origin + + xcoords = coords[:len(xticks)] + xcoords[:, :, 0] = numpy.asarray(xticks)[:, numpy.newaxis] + xcoords[:, 1, 1] += ticklength[1] # X ticks on XY plane + xcoords[:, 3, 2] += ticklength[2] # X ticks on XZ plane + + ycoords = coords[len(xticks):len(xticks) + len(yticks)] + ycoords[:, :, 1] = numpy.asarray(yticks)[:, numpy.newaxis] + ycoords[:, 1, 0] += ticklength[0] # Y ticks on XY plane + ycoords[:, 3, 2] += ticklength[2] # Y ticks on YZ plane + + zcoords = coords[len(xticks) + len(yticks):] + zcoords[:, :, 2] = numpy.asarray(zticks)[:, numpy.newaxis] + zcoords[:, 1, 0] += ticklength[0] # Z ticks on XZ plane + zcoords[:, 3, 1] += ticklength[1] # Z ticks on YZ plane + + self._tickLines.setPositions(coords.reshape(-1, 3)) + self._tickLines.visible = True + + # Update labels + color = self.tickColor + offsets = bounds[0] - ticklength / 20. + labels = [] + for tick, label in zip(xticks, xlabels): + text2d = text.Text2D(text=label, font=self.font) + text2d.align = 'center' + text2d.foreground = color + text2d.transforms = [transform.Translate( + tx=tick, ty=offsets[1], tz=offsets[2])] + labels.append(text2d) + + for tick, label in zip(yticks, ylabels): + text2d = text.Text2D(text=label, font=self.font) + text2d.align = 'center' + text2d.foreground = color + text2d.transforms = [transform.Translate( + tx=offsets[0], ty=tick, tz=offsets[2])] + labels.append(text2d) + + for tick, label in zip(zticks, zlabels): + text2d = text.Text2D(text=label, font=self.font) + text2d.align = 'center' + text2d.foreground = color + text2d.transforms = [transform.Translate( + tx=offsets[0], ty=offsets[1], tz=tick)] + labels.append(text2d) + + self._tickLabels.children = labels # Reset previous labels + + def prepareGL2(self, context): + self._updateTicks() + super(LabelledAxes, self).prepareGL2(context) diff --git a/silx/gui/plot3d/scene/camera.py b/silx/gui/plot3d/scene/camera.py new file mode 100644 index 0000000..8cc279d --- /dev/null +++ b/silx/gui/plot3d/scene/camera.py @@ -0,0 +1,350 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides classes to handle a perspective projection in 3D.""" + +from __future__ import absolute_import, division, unicode_literals + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "25/07/2016" + + +import numpy + +from . import transform + + +# CameraExtrinsic ############################################################# + +class CameraExtrinsic(transform.Transform): + """Transform matrix to handle camera position and orientation. + + :param position: Coordinates of the point of view. + :type position: numpy.ndarray-like of 3 float32. + :param direction: Sight direction vector. + :type direction: numpy.ndarray-like of 3 float32. + :param up: Vector pointing upward in the image plane. + :type up: numpy.ndarray-like of 3 float32. + """ + + def __init__(self, position=(0., 0., 0.), + direction=(0., 0., -1.), + up=(0., 1., 0.)): + + super(CameraExtrinsic, self).__init__() + self._position = None + self.position = position # set _position + self._side = 1., 0., 0. + self._up = 0., 1., 0. + self._direction = 0., 0., -1. + self.setOrientation(direction=direction, up=up) # set _direction, _up + + def _makeMatrix(self): + return transform.mat4LookAtDir(self._position, + self._direction, self._up) + + def copy(self): + """Return an independent copy""" + return CameraExtrinsic(self.position, self.direction, self.up) + + def setOrientation(self, direction=None, up=None): + """Set the rotation of the point of view. + + :param direction: Sight direction vector or + None to keep the current one. + :type direction: numpy.ndarray-like of 3 float32 or None. + :param up: Vector pointing upward in the image plane or + None to keep the current one. + :type up: numpy.ndarray-like of 3 float32 or None. + :raises RuntimeError: if the direction and up are parallel. + """ + if direction is None: # Use current direction + direction = self.direction + else: + assert len(direction) == 3 + direction = numpy.array(direction, copy=True, dtype=numpy.float32) + direction /= numpy.linalg.norm(direction) + + if up is None: # Use current up + up = self.up + else: + assert len(up) == 3 + up = numpy.array(up, copy=True, dtype=numpy.float32) + + # Update side and up to make sure they are perpendicular and normalized + side = numpy.cross(direction, up) + sidenormal = numpy.linalg.norm(side) + if sidenormal == 0.: + raise RuntimeError('direction and up vectors are parallel.') + # Alternative: when one of the input parameter is None, it is + # possible to guess correct vectors using previous direction and up + side /= sidenormal + up = numpy.cross(side, direction) + up /= numpy.linalg.norm(up) + + self._side = side + self._up = up + self._direction = direction + self.notify() + + @property + def position(self): + """Coordinates of the point of view as a numpy.ndarray of 3 float32.""" + return self._position.copy() + + @position.setter + def position(self, position): + assert len(position) == 3 + self._position = numpy.array(position, copy=True, dtype=numpy.float32) + self.notify() + + @property + def direction(self): + """Sight direction (ndarray of 3 float32).""" + return self._direction.copy() + + @direction.setter + def direction(self, direction): + self.setOrientation(direction=direction) + + @property + def up(self): + """Vector pointing upward in the image plane (ndarray of 3 float32). + """ + return self._up.copy() + + @up.setter + def up(self, up): + self.setOrientation(up=up) + + @property + def side(self): + """Vector pointing towards the side of the image plane. + + ndarray of 3 float32""" + return self._side.copy() + + def move(self, direction, step=1.): + """Move the camera relative to the image plane. + + :param str direction: Direction relative to image plane. + One of: 'up', 'down', 'left', 'right', + 'forward', 'backward'. + :param float step: The step of the pan to perform in the coordinate + in which the camera position is defined. + """ + if direction in ('up', 'down'): + vector = self.up * (1. if direction == 'up' else -1.) + elif direction in ('left', 'right'): + vector = self.side * (1. if direction == 'right' else -1.) + elif direction in ('forward', 'backward'): + vector = self.direction * (1. if direction == 'forward' else -1.) + else: + raise ValueError('Unsupported direction: %s' % direction) + + self.position += step * vector + + def rotate(self, direction, angle=1.): + """First-person rotation of the camera towards the direction. + + :param str direction: Direction of movement relative to image plane. + In: 'up', 'down', 'left', 'right'. + :param float angle: The angle in degrees of the rotation. + """ + if direction in ('up', 'down'): + axis = self.side * (1. if direction == 'up' else -1.) + elif direction in ('left', 'right'): + axis = self.up * (1. if direction == 'left' else -1.) + else: + raise ValueError('Unsupported direction: %s' % direction) + + matrix = transform.mat4RotateFromAngleAxis(numpy.radians(angle), *axis) + newdir = numpy.dot(matrix[:3, :3], self.direction) + + if direction in ('up', 'down'): + # Rotate up to avoid up and new direction to be (almost) co-linear + newup = numpy.dot(matrix[:3, :3], self.up) + self.setOrientation(newdir, newup) + else: + # No need to rotate up here as it is the rotation axis + self.direction = newdir + + def orbit(self, direction, center=(0., 0., 0.), angle=1.): + """Rotate the camera around a point. + + :param str direction: Direction of movement relative to image plane. + In: 'up', 'down', 'left', 'right'. + :param center: Position around which to rotate the point of view. + :type center: numpy.ndarray-like of 3 float32. + :param float angle: he angle in degrees of the rotation. + """ + if direction in ('up', 'down'): + axis = self.side * (1. if direction == 'down' else -1.) + elif direction in ('left', 'right'): + axis = self.up * (1. if direction == 'right' else -1.) + else: + raise ValueError('Unsupported direction: %s' % direction) + + # Rotate viewing direction + rotmatrix = transform.mat4RotateFromAngleAxis( + numpy.radians(angle), *axis) + self.direction = numpy.dot(rotmatrix[:3, :3], self.direction) + + # Rotate position around center + center = numpy.array(center, copy=False, dtype=numpy.float32) + matrix = numpy.dot(transform.mat4Translate(*center), rotmatrix) + matrix = numpy.dot(matrix, transform.mat4Translate(*(-center))) + position = numpy.append(self.position, 1.) + self.position = numpy.dot(matrix, position)[:3] + + _RESET_CAMERA_ORIENTATIONS = { + 'side': ((-1., -1., -1.), (0., 1., 0.)), + 'front': ((0., 0., -1.), (0., 1., 0.)), + 'back': ((0., 0., 1.), (0., 1., 0.)), + 'top': ((0., -1., 0.), (0., 0., -1.)), + 'bottom': ((0., 1., 0.), (0., 0., 1.)), + 'right': ((-1., 0., 0.), (0., 1., 0.)), + 'left': ((1., 0., 0.), (0., 1., 0.)) + } + + def reset(self, face=None): + """Reset the camera position to pre-defined orientations. + + :param str face: The direction of the camera in: + side, front, back, top, bottom, right, left. + """ + if face not in self._RESET_CAMERA_ORIENTATIONS: + raise ValueError('Unsupported face: %s' % face) + + distance = numpy.linalg.norm(self.position) + direction, up = self._RESET_CAMERA_ORIENTATIONS[face] + self.setOrientation(direction, up) + self.position = - self.direction * distance + + +class Camera(transform.Transform): + """Combination of camera projection and position. + + See :class:`Perspective` and :class:`CameraExtrinsic`. + + :param float fovy: Vertical field-of-view in degrees. + :param float near: The near clipping plane Z coord (strictly positive). + :param float far: The far clipping plane Z coord (> near). + :param size: Viewport's size used to compute the aspect ratio. + :type size: 2-tuple of float (width, height). + :param position: Coordinates of the point of view. + :type position: numpy.ndarray-like of 3 float32. + :param direction: Sight direction vector. + :type direction: numpy.ndarray-like of 3 float32. + :param up: Vector pointing upward in the image plane. + :type up: numpy.ndarray-like of 3 float32. + """ + + def __init__(self, fovy=30., near=0.1, far=1., size=(1., 1.), + position=(0., 0., 0.), + direction=(0., 0., -1.), up=(0., 1., 0.)): + super(Camera, self).__init__() + self._intrinsic = transform.Perspective(fovy, near, far, size) + self._intrinsic.addListener(self._transformChanged) + self._extrinsic = CameraExtrinsic(position, direction, up) + self._extrinsic.addListener(self._transformChanged) + + def _makeMatrix(self): + return numpy.dot(self.intrinsic.matrix, self.extrinsic.matrix) + + def _transformChanged(self, source): + """Listener of intrinsic and extrinsic camera parameters instances.""" + if source is not self: + self.notify() + + def resetCamera(self, bounds): + """Change camera to have the bounds in the viewing frustum. + + It updates the camera position and depth extent. + Camera sight direction and up are not affected. + + :param bounds: The axes-aligned bounds to include. + :type bounds: numpy.ndarray: ((xMin, yMin, zMin), (xMax, yMax, zMax)) + """ + + center = 0.5 * (bounds[0] + bounds[1]) + radius = numpy.linalg.norm(0.5 * (bounds[1] - bounds[0])) + + if isinstance(self.intrinsic, transform.Perspective): + # Get the viewpoint distance from the bounds center + minfov = numpy.radians(self.intrinsic.fovy) + width, height = self.intrinsic.size + if width < height: + minfov *= width / height + + offset = radius / numpy.sin(0.5 * minfov) + + # Update camera + self.extrinsic.position = \ + center - offset * self.extrinsic.direction + self.intrinsic.setDepthExtent(offset - radius, offset + radius) + + elif isinstance(self.intrinsic, transform.Orthographic): + # Y goes up + self.intrinsic.setClipping( + left=center[0] - radius, + right=center[0] + radius, + bottom=center[1] - radius, + top=center[1] + radius) + + # Update camera + self.extrinsic.position = 0, 0, 0 + self.intrinsic.setDepthExtent(center[2] - radius, + center[2] + radius) + else: + raise RuntimeError('Unsupported camera: %s' % self.intrinsic) + + @property + def intrinsic(self): + """Intrinsic camera parameters, i.e., projection matrix.""" + return self._intrinsic + + @intrinsic.setter + def intrinsic(self, intrinsic): + self._intrinsic.removeListener(self._transformChanged) + self._intrinsic = intrinsic + self._intrinsic.addListener(self._transformChanged) + + @property + def extrinsic(self): + """Extrinsic camera parameters, i.e., position and orientation.""" + return self._extrinsic + + def move(self, *args, **kwargs): + """See :meth:`CameraExtrinsic.move`.""" + self.extrinsic.move(*args, **kwargs) + + def rotate(self, *args, **kwargs): + """See :meth:`CameraExtrinsic.rotate`.""" + self.extrinsic.rotate(*args, **kwargs) + + def orbit(self, *args, **kwargs): + """See :meth:`CameraExtrinsic.orbit`.""" + self.extrinsic.orbit(*args, **kwargs) diff --git a/silx/gui/plot3d/scene/core.py b/silx/gui/plot3d/scene/core.py new file mode 100644 index 0000000..a293f28 --- /dev/null +++ b/silx/gui/plot3d/scene/core.py @@ -0,0 +1,334 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides the base scene structure. + +This module provides the classes for describing a tree structure with +rendering and picking API. +All nodes inherit from :class:`Base`. +Nodes with children are provided with :class:`PrivateGroup` and +:class:`Group` classes. +Leaf rendering nodes should inherit from :class:`Elem`. +""" + +from __future__ import absolute_import, division, unicode_literals + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "25/07/2016" + + +import itertools +import weakref + +import numpy + +from . import event +from . import transform + +from .viewport import Viewport + + +# Nodes ####################################################################### + +class Base(event.Notifier): + """A scene node with common features.""" + + def __init__(self): + super(Base, self).__init__() + self._visible = True + self._pickable = False + + self._parentRef = None + + self._transforms = transform.TransformList() + self._transforms.addListener(self._transformChanged) + + # notifying properties + + visible = event.notifyProperty('_visible', + doc="Visibility flag of the node") + pickable = event.notifyProperty('_pickable', + doc="True to make node pickable") + + # Access to tree path + + @property + def parent(self): + """Parent or None if no parent""" + return None if self._parentRef is None else self._parentRef() + + def _setParent(self, parent): + """Set the parent of this node. + + For internal use. + + :param Base parent: The parent. + """ + if parent is not None and self._parentRef is not None: + raise RuntimeError('Trying to add a node at two places.') + # Alternative: remove it from previous children list + self._parentRef = None if parent is None else weakref.ref(parent) + + @property + def path(self): + """Tuple of scene nodes, from the tip of the tree down to this node. + + If this tree is attached to a :class:`Viewport`, + then the :class:`Viewport` is the first element of path. + """ + if self.parent is None: + return self, + elif isinstance(self.parent, Viewport): + return self.parent, self + else: + return self.parent.path + (self, ) + + @property + def viewport(self): + """The viewport this node is attached to or None.""" + root = self.path[0] + return root if isinstance(root, Viewport) else None + + @property + def objectToNDCTransform(self): + """Transform from object to normalized device coordinates. + + Do not forget perspective divide. + """ + # Using the Viewport's transforms property to proxy the camera + path = self.path + assert isinstance(path[0], Viewport) + return transform.StaticTransformList(elem.transforms for elem in path) + + @property + def objectToSceneTransform(self): + """Transform from object to scene. + + Combine transforms up to the Viewport (not including it). + """ + path = self.path + if isinstance(path[0], Viewport): + path = path[1:] # Remove viewport to remove camera transforms + return transform.StaticTransformList(elem.transforms for elem in path) + + # transform + + @property + def transforms(self): + """List of transforms defining the frame of this node relative + to its parent.""" + return self._transforms + + @transforms.setter + def transforms(self, iterable): + self._transforms.removeListener(self._transformChanged) + if isinstance(iterable, transform.TransformList): + # If it is a TransformList, do not create one to enable sharing. + self._transforms = iterable + else: + assert hasattr(iterable, '__iter__') + self._transforms = transform.TransformList(iterable) + self._transforms.addListener(self._transformChanged) + + def _transformChanged(self, source): + self.notify() # Broadcast transform notification + + # Bounds + + _CUBE_CORNERS = numpy.array(list(itertools.product((0., 1.), repeat=3)), + dtype=numpy.float32) + """Unit cube corners used to transform bounds""" + + def _bounds(self, dataBounds=False): + """Override in subclass to provide bounds in object coordinates""" + return None + + def bounds(self, transformed=False, dataBounds=False): + """Returns the bounds of this node aligned with the axis, + with or without transform applied. + + :param bool transformed: False to give bounds in object coordinates + (the default), True to apply this object's + transforms. + :param bool dataBounds: False to give bounds of vertices (the default), + True to give bounds of the represented data. + :return: The bounds: ((xMin, yMin, zMin), (xMax, yMax, zMax)) or None + if no bounds. + :rtype: numpy.ndarray of float + """ + bounds = self._bounds(dataBounds) + + if transformed and bounds is not None: + bounds = self.transforms.transformBounds(bounds) + + return bounds + + # Rendering + + def prepareGL2(self, ctx): + """Called before the rendering to prepare OpenGL resources. + + Override in subclass. + """ + pass + + def renderGL2(self, ctx): + """Called to perform the OpenGL rendering. + + Override in subclass. + """ + pass + + def render(self, ctx): + """Called internally to perform rendering.""" + if self.visible: + ctx.pushTransform(self.transforms) + self.prepareGL2(ctx) + self.renderGL2(ctx) + ctx.popTransform() + + def postRender(self, ctx): + """Hook called when parent's node render is finished. + + Called in the reverse of rendering order (i.e., last child first). + + Meant for nodes that modify the :class:`RenderContext` ctx to + reset their modifications. + """ + pass + + def pick(self, ctx, x, y, depth=None): + """True/False picking, should be fast""" + if self.pickable: + pass + + def pickRay(self, ctx, ray): + """Picking returning list of ray intersections.""" + if self.pickable: + pass + + +class Elem(Base): + """A scene node that does some rendering.""" + + def __init__(self): + super(Elem, self).__init__() + # self.showBBox = False # Here or outside scene graph? + # self.clipPlane = None # This needs to be handled in the shader + + +class PrivateGroup(Base): + """A scene node that renders its (private) childern. + + :param iterable children: :class:`Base` nodes to add as children + """ + + class ChildrenList(event.NotifierList): + """List of children with notification and children's parent update.""" + + def _listWillChangeHook(self, methodName, *args, **kwargs): + super(PrivateGroup.ChildrenList, self)._listWillChangeHook( + methodName, *args, **kwargs) + for item in self: + item._setParent(None) + + def _listWasChangedHook(self, methodName, *args, **kwargs): + for item in self: + item._setParent(self._parentRef()) + super(PrivateGroup.ChildrenList, self)._listWasChangedHook( + methodName, *args, **kwargs) + + def __init__(self, parent, children): + self._parentRef = weakref.ref(parent) + super(PrivateGroup.ChildrenList, self).__init__(children) + + def __init__(self, children=()): + super(PrivateGroup, self).__init__() + self.__children = PrivateGroup.ChildrenList(self, children) + self.__children.addListener(self._updated) + + @property + def _children(self): + """List of children to be rendered. + + This private attribute is meant to be used by subclass. + """ + return self.__children + + @_children.setter + def _children(self, iterable): + self.__children.removeListener(self._updated) + for item in self.__children: + item._setParent(None) + del self.__children # This is needed + self.__children = PrivateGroup.ChildrenList(self, iterable) + self.__children.addListener(self._updated) + self.notify() + + def _updated(self, source, *args, **kwargs): + """Listen for updates""" + if source is not self: # Avoid infinite recursion + self.notify(*args, **kwargs) + + def _bounds(self, dataBounds=False): + """Compute the bounds from transformed children bounds""" + bounds = [] + for child in self._children: + if child.visible: + childBounds = child.bounds( + transformed=True, dataBounds=dataBounds) + if childBounds is not None: + bounds.append(childBounds) + + if len(bounds) == 0: + return None + else: + bounds = numpy.array(bounds, dtype=numpy.float32) + return numpy.array((bounds[:, 0, :].min(axis=0), + bounds[:, 1, :].max(axis=0)), + dtype=numpy.float32) + + def prepareGL2(self, ctx): + pass + + def renderGL2(self, ctx): + """Render all children""" + for child in self._children: + child.render(ctx) + for child in reversed(self._children): + child.postRender(ctx) + + +class Group(PrivateGroup): + """A scene node that renders its (public) children.""" + + @property + def children(self): + """List of children to be rendered.""" + return self._children + + @children.setter + def children(self, iterable): + self._children = iterable diff --git a/silx/gui/plot3d/scene/cutplane.py b/silx/gui/plot3d/scene/cutplane.py new file mode 100644 index 0000000..79b4168 --- /dev/null +++ b/silx/gui/plot3d/scene/cutplane.py @@ -0,0 +1,374 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""A cut plane in a 3D texture: hackish implementation... +""" + +from __future__ import absolute_import, division, unicode_literals + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "05/10/2016" + +import string +import numpy + +from ... import _glutils +from ..._glutils import gl + +from .function import Colormap +from .primitives import Box, Geometry, PlaneInGroup +from . import transform, utils + + +class ColormapMesh3D(Geometry): + """A 3D mesh with color from a 3D texture.""" + + _shaders = (""" + attribute vec3 position; + attribute vec3 normal; + + uniform mat4 matrix; + uniform mat4 transformMat; + //uniform mat3 matrixInvTranspose; + uniform vec3 dataScale; + + varying vec4 vCameraPosition; + varying vec3 vPosition; + varying vec3 vNormal; + varying vec3 vTexCoords; + + void main(void) + { + vCameraPosition = transformMat * vec4(position, 1.0); + //vNormal = matrixInvTranspose * normalize(normal); + vPosition = position; + vTexCoords = dataScale * position; + vNormal = normal; + gl_Position = matrix * vec4(position, 1.0); + } + """, + string.Template(""" + varying vec4 vCameraPosition; + varying vec3 vPosition; + varying vec3 vNormal; + varying vec3 vTexCoords; + uniform sampler3D data; + uniform float alpha; + + $colormapDecl + + $clippingDecl + $lightingFunction + + void main(void) + { + float value = texture3D(data, vTexCoords).r; + vec4 color = $colormapCall(value); + color.a = alpha; + + $clippingCall(vCameraPosition); + + gl_FragColor = $lightingCall(color, vPosition, vNormal); + } + """)) + + def __init__(self, position, normal, data, copy=True, + mode='triangles', indices=None, colormap=None): + assert mode in self._TRIANGLE_MODES + data = numpy.array(data, copy=copy, order='C') + assert data.ndim == 3 + self._data = data + self._texture = None + self._update_texture = True + self._update_texture_filter = False + self._alpha = 1. + self._colormap = colormap or Colormap() # Default colormap + self._colormap.addListener(self._cmapChanged) + self._interpolation = 'linear' + super(ColormapMesh3D, self).__init__(mode, + indices, + position=position, + normal=normal) + + self.isBackfaceVisible = True + + def setData(self, data, copy=True): + data = numpy.array(data, copy=copy, order='C') + assert data.ndim == 3 + self._data = data + self._update_texture = True + + def getData(self, copy=True): + return numpy.array(self._data, copy=copy) + + @property + def interpolation(self): + """The texture interpolation mode: 'linear' or 'nearest'""" + return self._interpolation + + @interpolation.setter + def interpolation(self, interpolation): + assert interpolation in ('linear', 'nearest') + self._interpolation = interpolation + self._update_texture_filter = True + self.notify() + + @property + def alpha(self): + """Transparency of the plane, float in [0, 1]""" + return self._alpha + + @alpha.setter + def alpha(self, alpha): + self._alpha = float(alpha) + + @property + def colormap(self): + """The colormap used by this primitive""" + return self._colormap + + def _cmapChanged(self, source, *args, **kwargs): + """Broadcast colormap changes""" + self.notify(*args, **kwargs) + + def prepareGL2(self, ctx): + if self._texture is None or self._update_texture: + if self._texture is not None: + self._texture.discard() + + if self.interpolation == 'nearest': + filter_ = gl.GL_NEAREST + else: + filter_ = gl.GL_LINEAR + self._update_texture = False + self._update_texture_filter = False + self._texture = _glutils.Texture( + gl.GL_R32F, self._data, gl.GL_RED, + minFilter=filter_, + magFilter=filter_, + wrap=gl.GL_CLAMP_TO_EDGE) + + if self._update_texture_filter: + self._update_texture_filter = False + if self.interpolation == 'nearest': + filter_ = gl.GL_NEAREST + else: + filter_ = gl.GL_LINEAR + self._texture.minFilter = filter_ + self._texture.magFilter = filter_ + + super(ColormapMesh3D, self).prepareGL2(ctx) + + def renderGL2(self, ctx): + fragment = self._shaders[1].substitute( + clippingDecl=ctx.clipper.fragDecl, + clippingCall=ctx.clipper.fragCall, + lightingFunction=ctx.viewport.light.fragmentDef, + lightingCall=ctx.viewport.light.fragmentCall, + colormapDecl=self.colormap.decl, + colormapCall=self.colormap.call + ) + program = ctx.glCtx.prog(self._shaders[0], fragment) + program.use() + + ctx.viewport.light.setupProgram(ctx, program) + self.colormap.setupProgram(ctx, program) + + if not self.isBackfaceVisible: + gl.glCullFace(gl.GL_BACK) + gl.glEnable(gl.GL_CULL_FACE) + + program.setUniformMatrix('matrix', ctx.objectToNDC.matrix) + program.setUniformMatrix('transformMat', + ctx.objectToCamera.matrix, + safe=True) + gl.glUniform1f(program.uniforms['alpha'], self._alpha) + + shape = self._data.shape + scales = 1./shape[2], 1./shape[1], 1./shape[0] + gl.glUniform3f(program.uniforms['dataScale'], *scales) + + gl.glUniform1i(program.uniforms['data'], self._texture.texUnit) + + ctx.clipper.setupProgram(ctx, program) + + self._texture.bind() + self._draw(program) + + if not self.isBackfaceVisible: + gl.glDisable(gl.GL_CULL_FACE) + + +class CutPlane(PlaneInGroup): + """A cutting plane in a 3D texture""" + + def __init__(self, point=(0., 0., 0.), normal=(0., 0., 1.)): + self._data = None + self._mesh = None + self._alpha = 1. + self._interpolation = 'linear' + self._colormap = Colormap() + super(CutPlane, self).__init__(point, normal) + + def setData(self, data, copy=True): + if data is None: + self._data = None + if self._mesh is not None: + self._children.remove(self._mesh) + self._mesh = None + + else: + data = numpy.array(data, copy=copy, order='C') + assert data.ndim == 3 + self._data = data + if self._mesh is not None: + self._mesh.setData(data, copy=False) + + def getData(self, copy=True): + return None if self._mesh is None else self._mesh.getData(copy=copy) + + @property + def alpha(self): + return self._alpha + + @alpha.setter + def alpha(self, alpha): + self._alpha = float(alpha) + if self._mesh is not None: + self._mesh.alpha = alpha + + @property + def colormap(self): + return self._colormap + + @property + def interpolation(self): + """The texture interpolation mode: 'linear' (default) or 'nearest'""" + return self._interpolation + + @interpolation.setter + def interpolation(self, interpolation): + assert interpolation in ('nearest', 'linear') + if interpolation != self.interpolation: + self._interpolation = interpolation + if self._mesh is not None: + self._mesh.interpolation = interpolation + + def prepareGL2(self, ctx): + if self.isValid: + + contourVertices = self.contourVertices + + if (self.interpolation == 'nearest' and + contourVertices is not None and len(contourVertices)): + # Avoid cut plane co-linear with array bin edges + for index, normal in enumerate(((1., 0., 0.), (0., 1., 0.), (0., 0., 1.))): + if (numpy.all(numpy.equal(self.plane.normal, normal)) and + int(self.plane.point[index]) == self.plane.point[index]): + contourVertices += self.plane.normal * 0.01 # Add an offset + break + + if self._mesh is None and self._data is not None: + self._mesh = ColormapMesh3D(contourVertices, + normal=self.plane.normal, + data=self._data, + copy=False, + mode='fan', + colormap=self.colormap) + self._mesh.alpha = self._alpha + self._interpolation = self.interpolation + self._children.insert(0, self._mesh) + + if self._mesh is not None: + if (contourVertices is None or + len(contourVertices) == 0): + self._mesh.visible = False + else: + self._mesh.visible = True + self._mesh.setAttribute('normal', self.plane.normal) + self._mesh.setAttribute('position', contourVertices) + + super(CutPlane, self).prepareGL2(ctx) + + def renderGL2(self, ctx): + with self.viewport.light.turnOff(): + super(CutPlane, self).renderGL2(ctx) + + def _bounds(self, dataBounds=False): + if not dataBounds: + vertices = self.contourVertices + if vertices is not None: + return numpy.array( + (vertices.min(axis=0), vertices.max(axis=0)), + dtype=numpy.float32) + else: + return None # Plane in not slicing the data volume + else: + if self._data is None: + return None + else: + depth, height, width = self._data.shape + return numpy.array(((0., 0., 0.), + (width, height, depth)), + dtype=numpy.float32) + + @property + def contourVertices(self): + """The vertices of the contour of the plane/bounds intersection.""" + # TODO copy from PlaneInGroup, refactor all that! + bounds = self.bounds(dataBounds=True) + if bounds is None: + return None # No bounds: no vertices + + # Check if cache is valid and return it + cachebounds, cachevertices = self._cache + if numpy.all(numpy.equal(bounds, cachebounds)): + return cachevertices + + # Cache is not OK, rebuild it + boxvertices = bounds[0] + Box._vertices.copy()*(bounds[1] - bounds[0]) + lineindices = Box._lineIndices + vertices = utils.boxPlaneIntersect( + boxvertices, lineindices, self.plane.normal, self.plane.point) + + self._cache = bounds, vertices if len(vertices) != 0 else None + + return self._cache[1] + + # Render transforms RW, TODO refactor this! + @property + def transforms(self): + return self._transforms + + @transforms.setter + def transforms(self, iterable): + self._transforms.removeListener(self._transformChanged) + if isinstance(iterable, transform.TransformList): + # If it is a TransformList, do not create one to enable sharing. + self._transforms = iterable + else: + assert hasattr(iterable, '__iter__') + self._transforms = transform.TransformList(iterable) + self._transforms.addListener(self._transformChanged) diff --git a/silx/gui/plot3d/scene/event.py b/silx/gui/plot3d/scene/event.py new file mode 100644 index 0000000..7b85434 --- /dev/null +++ b/silx/gui/plot3d/scene/event.py @@ -0,0 +1,225 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides a simple generic notification system.""" + +from __future__ import absolute_import, division, unicode_literals + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "25/07/2016" + + +import logging + +from silx.utils.weakref import WeakList + +_logger = logging.getLogger(__name__) + + +# Notifier #################################################################### + +class Notifier(object): + """Base class for object with notification mechanism.""" + + def __init__(self): + self._listeners = WeakList() + + def addListener(self, listener): + """Register a listener. + + Adding an already registered listener has no effect. + + :param callable listener: The function or method to register. + """ + if listener not in self._listeners: + self._listeners.append(listener) + else: + _logger.warning('Ignoring addition of an already registered listener') + + def removeListener(self, listener): + """Remove a previously registered listener. + + :param callable listener: The function or method to unregister. + """ + try: + self._listeners.remove(listener) + except ValueError: + _logger.warn('Trying to remove a listener that is not registered') + + def notify(self, *args, **kwargs): + """Notify all registered listeners with the given parameters. + + Listeners are called directly in this method. + Listeners are called in the order they were registered. + """ + for listener in self._listeners: + listener(self, *args, **kwargs) + + +def notifyProperty(attrName, copy=False, converter=None, doc=None): + """Create a property that adds notification to an attribute. + + :param str attrName: The name of the attribute to wrap. + :param bool copy: Whether to return a copy of the attribute + or not (the default). + :param converter: Function converting input value to appropriate type + This function takes a single argument and return the + converted value. + It can be used to perform some asserts. + :param str doc: The docstring of the property + :return: A property with getter and setter + """ + if copy: + def getter(self): + return getattr(self, attrName).copy() + else: + def getter(self): + return getattr(self, attrName) + + if converter is None: + def setter(self, value): + if getattr(self, attrName) != value: + setattr(self, attrName, value) + self.notify() + + else: + def setter(self, value): + value = converter(value) + if getattr(self, attrName) != value: + setattr(self, attrName, value) + self.notify() + + return property(getter, setter, doc=doc) + + +class HookList(list): + """List with hooks before and after modification.""" + + def __init__(self, iterable): + super(HookList, self).__init__(iterable) + + self._listWasChangedHook('__init__', iterable) + + def _listWillChangeHook(self, methodName, *args, **kwargs): + """To override. Called before modifying the list. + + This method is called with the name of the method called to + modify the list and its parameters. + """ + pass + + def _listWasChangedHook(self, methodName, *args, **kwargs): + """To override. Called after modifying the list. + + This method is called with the name of the method called to + modify the list and its parameters. + """ + pass + + # Wrapping methods that modify the list + + def _wrapper(self, methodName, *args, **kwargs): + """Generic wrapper of list methods calling the hooks.""" + self._listWillChangeHook(methodName, *args, **kwargs) + result = getattr(super(HookList, self), + methodName)(*args, **kwargs) + self._listWasChangedHook(methodName, *args, **kwargs) + return result + + # Add methods + + def __iadd__(self, *args, **kwargs): + return self._wrapper('__iadd__', *args, **kwargs) + + def __imul__(self, *args, **kwargs): + return self._wrapper('__imul__', *args, **kwargs) + + def append(self, *args, **kwargs): + return self._wrapper('append', *args, **kwargs) + + def extend(self, *args, **kwargs): + return self._wrapper('extend', *args, **kwargs) + + def insert(self, *args, **kwargs): + return self._wrapper('insert', *args, **kwargs) + + # Remove methods + + def __delitem__(self, *args, **kwargs): + return self._wrapper('__delitem__', *args, **kwargs) + + def __delslice__(self, *args, **kwargs): + return self._wrapper('__delslice__', *args, **kwargs) + + def remove(self, *args, **kwargs): + return self._wrapper('remove', *args, **kwargs) + + def pop(self, *args, **kwargs): + return self._wrapper('pop', *args, **kwargs) + + # Set methods + + def __setitem__(self, *args, **kwargs): + return self._wrapper('__setitem__', *args, **kwargs) + + def __setslice__(self, *args, **kwargs): + return self._wrapper('__setslice__', *args, **kwargs) + + # In place methods + + def sort(self, *args, **kwargs): + return self._wrapper('sort', *args, **kwargs) + + def reverse(self, *args, **kwargs): + return self._wrapper('reverse', *args, **kwargs) + + +class NotifierList(HookList, Notifier): + """List of Notifiers with notification mechanism. + + This class registers itself as a listener of the list items. + + The default listener method forward notification from list items + to the listeners of the list. + """ + + def __init__(self, iterable=()): + Notifier.__init__(self) + HookList.__init__(self, iterable) + + def _listWillChangeHook(self, methodName, *args, **kwargs): + for item in self: + item.removeListener(self._notified) + + def _listWasChangedHook(self, methodName, *args, **kwargs): + for item in self: + item.addListener(self._notified) + self.notify() + + def _notified(self, source, *args, **kwargs): + """Default listener forwarding list item changes to its listeners.""" + # Avoid infinite recursion if the list is listening itself + if source is not self: + self.notify(*args, **kwargs) diff --git a/silx/gui/plot3d/scene/function.py b/silx/gui/plot3d/scene/function.py new file mode 100644 index 0000000..80ac820 --- /dev/null +++ b/silx/gui/plot3d/scene/function.py @@ -0,0 +1,471 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides functions to add to shaders.""" + +from __future__ import absolute_import, division, unicode_literals + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "08/11/2016" + + +import contextlib +import logging +import numpy + +from ..._glutils import gl + +from . import event +from . import utils + + +_logger = logging.getLogger(__name__) + + +class ProgramFunction(object): + """Class providing a function to add to a GLProgram shaders. + """ + + def setupProgram(self, context, program): + """Sets-up uniforms of a program using this shader function. + + :param RenderContext context: The current rendering context + :param GLProgram program: The program to set-up. + It MUST be in use and using this function. + """ + pass + + +class ClippingPlane(ProgramFunction): + """Description of a clipping plane and rendering. + + Convention: Clipping is performed in camera/eye space. + + :param point: Local coordinates of a point on the plane. + :type point: numpy.ndarray-like of 3 float32 + :param normal: Local coordinates of the plane normal. + :type normal: numpy.ndarray-like of 3 float32 + """ + + _fragDecl = """ + /* Clipping plane */ + /* as rx + gy + bz + a > 0, clipping all positive */ + uniform vec4 planeEq; + + /* Position is in camera/eye coordinates */ + + bool isClipped(vec4 position) { + vec4 tmp = planeEq * position; + float value = tmp.x + tmp.y + tmp.z + planeEq.a; + return (value < 0.0001); + } + + void clipping(vec4 position) { + if (isClipped(position)) { + discard; + } + } + /* End of clipping */ + """ + + _fragDeclNoop = """ + bool isClipped(vec4 position) + { + return false; + } + + void clipping(vec4 position) {} + """ + + def __init__(self, point=(0., 0., 0.), normal=(0., 0., 0.)): + self._plane = utils.Plane(point, normal) + + @property + def plane(self): + """Plane parameters in camera space.""" + return self._plane + + # GL2 + + @property + def fragDecl(self): + return self._fragDecl if self.plane.isPlane else self._fragDeclNoop + + @property + def fragCall(self): + return "clipping" + + def setupProgram(self, context, program): + """Sets-up uniforms of a program using this shader function. + + :param RenderContext context: The current rendering context + :param GLProgram program: The program to set-up. + It MUST be in use and using this function. + """ + if self.plane.isPlane: + gl.glUniform4f(program.uniforms['planeEq'], *self.plane.parameters) + + +class DirectionalLight(event.Notifier, ProgramFunction): + """Description of a directional Phong light. + + :param direction: The direction of the light or None to disable light + :type direction: ndarray of 3 floats or None + :param ambient: RGB ambient light + :type ambient: ndarray of 3 floats in [0, 1], default: (1., 1., 1.) + :param diffuse: RGB diffuse light parameter + :type diffuse: ndarray of 3 floats in [0, 1], default: (0., 0., 0.) + :param specular: RGB specular light parameter + :type specular: ndarray of 3 floats in [0, 1], default: (1., 1., 1.) + :param int shininess: The shininess of the material for specular term, + default: 0 which disables specular component. + """ + + fragmentShaderFunction = """ + /* Lighting */ + struct DLight { + vec3 lightDir; // Direction of light in object space + vec3 ambient; + vec3 diffuse; + vec3 specular; + float shininess; + vec3 viewPos; // Camera position in object space + }; + + uniform DLight dLight; + + vec4 lighting(vec4 color, vec3 position, vec3 normal) + { + normal = normalize(normal); + // 1-sided + float nDotL = max(0.0, dot(normal, - dLight.lightDir)); + + // 2-sided + //float nDotL = dot(normal, - dLight.lightDir); + //if (nDotL < 0.) { + // nDotL = - nDotL; + // normal = - normal; + //} + + float specFactor = 0.; + if (dLight.shininess > 0. && nDotL > 0.) { + vec3 reflection = reflect(dLight.lightDir, normal); + vec3 viewDir = normalize(dLight.viewPos - position); + specFactor = max(0.0, dot(reflection, viewDir)); + if (specFactor > 0.) { + specFactor = pow(specFactor, dLight.shininess); + } + } + + vec3 enlightedColor = color.rgb * (dLight.ambient + + dLight.diffuse * nDotL) + + dLight.specular * specFactor; + + return vec4(enlightedColor.rgb, color.a); + } + /* End of Lighting */ + """ + + fragmentShaderFunctionNoop = """ + vec4 lighting(vec4 color, vec3 position, vec3 normal) + { + return color; + } + """ + + def __init__(self, direction=None, + ambient=(1., 1., 1.), diffuse=(0., 0., 0.), + specular=(1., 1., 1.), shininess=0): + super(DirectionalLight, self).__init__() + self._direction = None + self.direction = direction # Set _direction + self._isOn = True + self._ambient = ambient + self._diffuse = diffuse + self._specular = specular + self._shininess = shininess + + ambient = event.notifyProperty('_ambient') + diffuse = event.notifyProperty('_diffuse') + specular = event.notifyProperty('_specular') + shininess = event.notifyProperty('_shininess') + + @property + def isOn(self): + """True if light is on, False otherwise.""" + return self._isOn and self._direction is not None + + @isOn.setter + def isOn(self, isOn): + self._isOn = bool(isOn) + + @contextlib.contextmanager + def turnOff(self): + """Context manager to temporary turn off lighting during rendering. + + >>> with light.turnOff(): + ... # Do some rendering without lighting + """ + wason = self._isOn + self._isOn = False + yield + self._isOn = wason + + @property + def direction(self): + """The direction of the light, or None if light is not on.""" + return self._direction + + @direction.setter + def direction(self, direction): + if direction is None: + self._direction = None + else: + assert len(direction) == 3 + direction = numpy.array(direction, dtype=numpy.float32, copy=True) + norm = numpy.linalg.norm(direction) + assert norm != 0 + self._direction = direction / norm + self.notify() + + # GL2 + + @property + def fragmentDef(self): + """Definition to add to fragment shader""" + if self.isOn: + return self.fragmentShaderFunction + else: + return self.fragmentShaderFunctionNoop + + @property + def fragmentCall(self): + """Function name to call in fragment shader""" + return "lighting" + + def setupProgram(self, context, program): + """Sets-up uniforms of a program using this shader function. + + :param RenderContext context: The current rendering context + :param GLProgram program: The program to set-up. + It MUST be in use and using this function. + """ + if self.isOn and self._direction is not None: + # Transform light direction from camera space to object coords + lightdir = context.objectToCamera.transformDir( + self._direction, direct=False) + lightdir /= numpy.linalg.norm(lightdir) + + gl.glUniform3f(program.uniforms['dLight.lightDir'], *lightdir) + + # Convert view position to object coords + viewpos = context.objectToCamera.transformPoint( + numpy.array((0., 0., 0., 1.), dtype=numpy.float32), + direct=False, + perspectiveDivide=True)[:3] + gl.glUniform3f(program.uniforms['dLight.viewPos'], *viewpos) + + gl.glUniform3f(program.uniforms['dLight.ambient'], *self.ambient) + gl.glUniform3f(program.uniforms['dLight.diffuse'], *self.diffuse) + gl.glUniform3f(program.uniforms['dLight.specular'], *self.specular) + gl.glUniform1f(program.uniforms['dLight.shininess'], + self.shininess) + + +class Colormap(event.Notifier, ProgramFunction): + # TODO use colors for out-of-bound values, for <=0 with log, for nan + # TODO texture-based colormap + + decl = """ + #define CMAP_GRAY 0 + #define CMAP_R_GRAY 1 + #define CMAP_RED 2 + #define CMAP_GREEN 3 + #define CMAP_BLUE 4 + #define CMAP_TEMP 5 + + uniform struct { + int id; + bool isLog; + float min; + float oneOverRange; + } cmap; + + const float oneOverLog10 = 0.43429448190325176; + + vec4 colormap(float value) { + if (cmap.isLog) { /* Log10 mapping */ + if (value > 0.0) { + value = clamp(cmap.oneOverRange * + (oneOverLog10 * log(value) - cmap.min), + 0.0, 1.0); + } else { + value = 0.0; + } + } else { /* Linear mapping */ + value = clamp(cmap.oneOverRange * (value - cmap.min), 0.0, 1.0); + } + + if (cmap.id == CMAP_GRAY) { + return vec4(value, value, value, 1.0); + } + else if (cmap.id == CMAP_R_GRAY) { + float invValue = 1.0 - value; + return vec4(invValue, invValue, invValue, 1.0); + } + else if (cmap.id == CMAP_RED) { + return vec4(value, 0.0, 0.0, 1.0); + } + else if (cmap.id == CMAP_GREEN) { + return vec4(0.0, value, 0.0, 1.0); + } + else if (cmap.id == CMAP_BLUE) { + return vec4(0.0, 0.0, value, 1.0); + } + else if (cmap.id == CMAP_TEMP) { + //red: 0.5->0.75: 0->1 + //green: 0.->0.25: 0->1; 0.75->1.: 1->0 + //blue: 0.25->0.5: 1->0 + return vec4( + clamp(4.0 * value - 2.0, 0.0, 1.0), + 1.0 - clamp(4.0 * abs(value - 0.5) - 1.0, 0.0, 1.0), + 1.0 - clamp(4.0 * value - 1.0, 0.0, 1.0), + 1.0); + } + else { + /* Unknown colormap */ + return vec4(0.0, 0.0, 0.0, 1.0); + } + } + """ + + call = "colormap" + + _COLORMAPS = { + 'gray': 0, + 'reversed gray': 1, + 'red': 2, + 'green': 3, + 'blue': 4, + 'temperature': 5 + } + + COLORMAPS = tuple(_COLORMAPS.keys()) + """Tuple of supported colormap names.""" + + NORMS = 'linear', 'log' + """Tuple of supported normalizations.""" + + def __init__(self, name='gray', norm='linear', range_=(1., 10.)): + """Shader function to apply a colormap to a value. + + :param str name: Name of the colormap. + :param str norm: Normalization to apply: 'linear' (default) or 'log'. + :param range_: Range of value to map to the colormap. + :type range_: 2-tuple of float (begin, end). + """ + super(Colormap, self).__init__() + + # Init privates to default + self._name, self._norm, self._range = 'gray', 'linear', (1., 10.) + + # Set to param values through properties to go through asserts + self.name = name + self.norm = norm + self.range_ = range_ + + @property + def name(self): + """Name of the colormap in use.""" + return self._name + + @name.setter + def name(self, name): + if name != self._name: + assert name in self.COLORMAPS + self._name = name + self.notify() + + @property + def norm(self): + """Normalization to use for colormap mapping. + + Either 'linear' (the default) or 'log' for log10 mapping. + With 'log' normalization, values <= 0. are set to 1. (i.e. log == 0) + """ + return self._norm + + @norm.setter + def norm(self, norm): + if norm != self._norm: + assert norm in self.NORMS + self._norm = norm + if norm == 'log': + self.range_ = self.range_ # To test for positive range_ + self.notify() + + @property + def range_(self): + """Range of values to map to the colormap. + + 2-tuple of floats: (begin, end). + The begin value is mapped to the origin of the colormap and the + end value is mapped to the other end of the colormap. + The colormap is reversed if begin > end. + """ + return self._range + + @range_.setter + def range_(self, range_): + assert len(range_) == 2 + range_ = float(range_[0]), float(range_[1]) + + if self.norm == 'log' and (range_[0] <= 0. or range_[1] <= 0.): + _logger.warn( + "Log normalization and negative range: updating range.") + minPos = numpy.finfo(numpy.float32).tiny + range_ = max(range_[0], minPos), max(range_[1], minPos) + + if range_ != self._range: + self._range = range_ + self.notify() + + def setupProgram(self, context, program): + """Sets-up uniforms of a program using this shader function. + + :param RenderContext context: The current rendering context + :param GLProgram program: The program to set-up. + It MUST be in use and using this function. + """ + gl.glUniform1i(program.uniforms['cmap.id'], self._COLORMAPS[self.name]) + gl.glUniform1i(program.uniforms['cmap.isLog'], self._norm == 'log') + + min_, max_ = self.range_ + if self._norm == 'log': + min_, max_ = numpy.log10(min_), numpy.log10(max_) + + gl.glUniform1f(program.uniforms['cmap.min'], min_) + gl.glUniform1f(program.uniforms['cmap.oneOverRange'], + (1. / (max_ - min_)) if max_ != min_ else 0.) diff --git a/silx/gui/plot3d/scene/interaction.py b/silx/gui/plot3d/scene/interaction.py new file mode 100644 index 0000000..68bfc13 --- /dev/null +++ b/silx/gui/plot3d/scene/interaction.py @@ -0,0 +1,652 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides interaction to plug on the scene graph.""" + +from __future__ import absolute_import + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "25/07/2016" + +import logging +import numpy + +from silx.gui.plot.Interaction import \ + StateMachine, State, LEFT_BTN, RIGHT_BTN # , MIDDLE_BTN + +from . import transform + + +_logger = logging.getLogger(__name__) + + +# ClickOrDrag ################################################################# + +# TODO merge with silx.gui.plot.Interaction.ClickOrDrag +class ClickOrDrag(StateMachine): + """Click or drag interaction for a given button.""" + + DRAG_THRESHOLD_SQUARE_DIST = 5 ** 2 + + class Idle(State): + def onPress(self, x, y, btn): + if btn == self.machine.button: + self.goto('clickOrDrag', x, y) + return True + + class ClickOrDrag(State): + def enterState(self, x, y): + self.initPos = x, y + + enter = enterState # silx v.0.3 support, remove when 0.4 out + + def onMove(self, x, y): + dx = (x - self.initPos[0]) ** 2 + dy = (y - self.initPos[1]) ** 2 + if (dx ** 2 + dy ** 2) >= self.machine.DRAG_THRESHOLD_SQUARE_DIST: + self.goto('drag', self.initPos, (x, y)) + + def onRelease(self, x, y, btn): + if btn == self.machine.button: + self.machine.click(x, y) + self.goto('idle') + + class Drag(State): + def enterState(self, initPos, curPos): + self.initPos = initPos + self.machine.beginDrag(*initPos) + self.machine.drag(*curPos) + + enter = enterState # silx v.0.3 support, remove when 0.4 out + + def onMove(self, x, y): + self.machine.drag(x, y) + + def onRelease(self, x, y, btn): + if btn == self.machine.button: + self.machine.endDrag(self.initPos, (x, y)) + self.goto('idle') + + def __init__(self, button=LEFT_BTN): + self.button = button + states = { + 'idle': ClickOrDrag.Idle, + 'clickOrDrag': ClickOrDrag.ClickOrDrag, + 'drag': ClickOrDrag.Drag + } + super(ClickOrDrag, self).__init__(states, 'idle') + + def click(self, x, y): + """Called upon a left or right button click. + To override in a subclass. + """ + pass + + def beginDrag(self, x, y): + """Called at the beginning of a drag gesture with left button + pressed. + To override in a subclass. + """ + pass + + def drag(self, x, y): + """Called on mouse moved during a drag gesture. + To override in a subclass. + """ + pass + + def endDrag(self, x, y): + """Called at the end of a drag gesture when the left button is + released. + To override in a subclass. + """ + pass + + +# CameraRotate ################################################################ + +class CameraRotate(ClickOrDrag): + """Camera rotation using an arcball-like interaction.""" + + def __init__(self, viewport, orbitAroundCenter=True, button=RIGHT_BTN): + self._viewport = viewport + self._orbitAroundCenter = orbitAroundCenter + self._reset() + super(CameraRotate, self).__init__(button) + + def _reset(self): + self._origin, self._center = None, None + self._startExtrinsic = None + + def click(self, x, y): + pass # No interaction yet + + def beginDrag(self, x, y): + centerPos = None + if not self._orbitAroundCenter: + # Try to use picked object position as center of rotation + ndcZ = self._viewport._pickNdcZGL(x, y) + if ndcZ != 1.: + # Hit an object, use picked point as center + centerPos = self._viewport._getXZYGL(x, y) # Can return None + + if centerPos is None: + # Not using picked position, use scene center + bounds = self._viewport.scene.bounds(transformed=True) + centerPos = 0.5 * (bounds[0] + bounds[1]) + + self._center = transform.Translate(*centerPos) + self._origin = x, y + self._startExtrinsic = self._viewport.camera.extrinsic.copy() + + def drag(self, x, y): + if self._center is None: + return + + dx, dy = self._origin[0] - x, self._origin[1] - y + + if dx == 0 and dy == 0: + direction = self._startExtrinsic.direction + up = self._startExtrinsic.up + position = self._startExtrinsic.position + else: + minsize = min(self._viewport.size) + distance = numpy.sqrt(dx ** 2 + dy ** 2) + angle = distance / minsize * numpy.pi + + # Take care of y inversion + direction = dx * self._startExtrinsic.side - \ + dy * self._startExtrinsic.up + direction /= numpy.linalg.norm(direction) + axis = numpy.cross(direction, self._startExtrinsic.direction) + axis /= numpy.linalg.norm(axis) + + # Orbit start camera with current angle and axis + # Rotate viewing direction + rotation = transform.Rotate(numpy.degrees(angle), *axis) + direction = rotation.transformDir(self._startExtrinsic.direction) + up = rotation.transformDir(self._startExtrinsic.up) + + # Rotate position around center + trlist = transform.StaticTransformList(( + self._center, + rotation, + self._center.inverse())) + position = trlist.transformPoint(self._startExtrinsic.position) + + camerapos = self._viewport.camera.extrinsic + camerapos.setOrientation(direction, up) + camerapos.position = position + + def endDrag(self, x, y): + self._reset() + + +# CameraSelectPan ############################################################# + +class CameraSelectPan(ClickOrDrag): + """Picking on click and pan camera on drag.""" + + def __init__(self, viewport, button=LEFT_BTN, selectCB=None): + self._viewport = viewport + self._selectCB = selectCB + self._lastPosNdc = None + super(CameraSelectPan, self).__init__(button) + + def click(self, x, y): + if self._selectCB is not None: + ndcZ = self._viewport._pickNdcZGL(x, y) + position = self._viewport._getXZYGL(x, y) + # This assume no object lie on the far plane + # Alternative, change the depth range so that far is < 1 + if ndcZ != 1. and position is not None: + self._selectCB((x, y, ndcZ), position) + + def beginDrag(self, x, y): + ndc = self._viewport.windowToNdc(x, y) + ndcZ = self._viewport._pickNdcZGL(x, y) + # ndcZ is the panning plane + if ndc is not None and ndcZ is not None: + self._lastPosNdc = numpy.array((ndc[0], ndc[1], ndcZ, 1.), + dtype=numpy.float32) + else: + self._lastPosNdc = None + + def drag(self, x, y): + if self._lastPosNdc is not None: + ndc = self._viewport.windowToNdc(x, y) + if ndc is not None: + ndcPos = numpy.array((ndc[0], ndc[1], self._lastPosNdc[2], 1.), + dtype=numpy.float32) + + # Convert last and current NDC positions to scene coords + scenePos = self._viewport.camera.transformPoint( + ndcPos, direct=False, perspectiveDivide=True) + lastScenePos = self._viewport.camera.transformPoint( + self._lastPosNdc, direct=False, perspectiveDivide=True) + + # Get translation in scene coords + translation = scenePos[:3] - lastScenePos[:3] + self._viewport.camera.extrinsic.position -= translation + + # Store for next drag + self._lastPosNdc = ndcPos + + def endDrag(self, x, y): + self._lastPosNdc = None + + +# CameraWheel ################################################################# + +class CameraWheel(object): + """StateMachine like class, just handling wheel events.""" + + # TODO choose scale of motion? Translation or Scale? + def __init__(self, viewport, mode='center', scaleTransform=None): + assert mode in ('center', 'position', 'scale') + self._viewport = viewport + if mode == 'center': + self._zoomTo = self._zoomToCenter + elif mode == 'position': + self._zoomTo = self._zoomToPosition + elif mode == 'scale': + self._zoomTo = self._zoomByScale + self._scale = scaleTransform + else: + raise ValueError('Unsupported mode: %s' % mode) + + def handleEvent(self, eventName, *args, **kwargs): + if eventName == 'wheel': + return self._zoomTo(*args, **kwargs) + + def _zoomToCenter(self, x, y, angleInDegrees): + """Zoom to center of display. + + Only works with perspective camera. + """ + direction = 'forward' if angleInDegrees > 0 else 'backward' + self._viewport.camera.move(direction) + return True + + def _zoomToPositionAbsolute(self, x, y, angleInDegrees): + """Zoom while keeping pixel under mouse invariant. + + Only works with perspective camera. + """ + ndc = self._viewport.windowToNdc(x, y) + if ndc is not None: + near = numpy.array((ndc[0], ndc[1], -1., 1.), dtype=numpy.float32) + + nearscene = self._viewport.camera.transformPoint( + near, direct=False, perspectiveDivide=True) + + far = numpy.array((ndc[0], ndc[1], 1., 1.), dtype=numpy.float32) + farscene = self._viewport.camera.transformPoint( + far, direct=False, perspectiveDivide=True) + + dirscene = farscene[:3] - nearscene[:3] + dirscene /= numpy.linalg.norm(dirscene) + + if angleInDegrees < 0: + dirscene *= -1. + + # TODO which scale + self._viewport.camera.extrinsic.position += dirscene + return True + + def _zoomToPosition(self, x, y, angleInDegrees): + """Zoom while keeping pixel under mouse invariant.""" + projection = self._viewport.camera.intrinsic + extrinsic = self._viewport.camera.extrinsic + + if isinstance(projection, transform.Perspective): + # For perspective projection, move camera + ndc = self._viewport.windowToNdc(x, y) + if ndc is not None: + ndcz = self._viewport._pickNdcZGL(x, y) + + position = numpy.array((ndc[0], ndc[1], ndcz), + dtype=numpy.float32) + positionscene = self._viewport.camera.transformPoint( + position, direct=False, perspectiveDivide=True) + + camtopos = extrinsic.position - positionscene + + step = 0.2 * (1. if angleInDegrees < 0 else -1.) + extrinsic.position += step * camtopos + + elif isinstance(projection, transform.Orthographic): + # For orthographic projection, change projection borders + ndcx, ndcy = self._viewport.windowToNdc(x, y, checkInside=False) + + step = 0.2 * (1. if angleInDegrees < 0 else -1.) + + dx = (ndcx + 1) / 2. + stepwidth = step * (projection.right - projection.left) + left = projection.left - dx * stepwidth + right = projection.right + (1. - dx) * stepwidth + + dy = (ndcy + 1) / 2. + stepheight = step * (projection.top - projection.bottom) + bottom = projection.bottom - dy * stepheight + top = projection.top + (1. - dy) * stepheight + + projection.setClipping(left, right, bottom, top) + + else: + raise RuntimeError('Unsupported camera', projection) + return True + + def _zoomByScale(self, x, y, angleInDegrees): + """Zoom by scaling scene (do not keep pixel under mouse invariant).""" + scalefactor = 1.1 + if angleInDegrees < 0.: + scalefactor = 1. / scalefactor + self._scale.scale = scalefactor * self._scale.scale + + self._viewport.adjustCameraDepthExtent() + return True + + +# FocusManager ################################################################ + +class FocusManager(StateMachine): + """Manages focus across multiple event handlers + + On press an event handler can acquire focus. + By default it looses focus when all buttons are released. + """ + class Idle(State): + def onPress(self, x, y, btn): + for eventHandler in self.machine.eventHandlers: + requestfocus = eventHandler.handleEvent('press', x, y, btn) + if requestfocus: + self.goto('focus', eventHandler, btn) + break + + def _processEvent(self, *args): + for eventHandler in self.machine.eventHandlers: + consumeevent = eventHandler.handleEvent(*args) + if consumeevent: + break + + def onMove(self, x, y): + self._processEvent('move', x, y) + + def onRelease(self, x, y, btn): + self._processEvent('release', x, y, btn) + + def onWheel(self, x, y, angle): + self._processEvent('wheel', x, y, angle) + + class Focus(State): + def enterState(self, eventHandler, btn): + self.eventHandler = eventHandler + self.focusBtns = {btn} # Set + + enter = enterState # silx v.0.3 support, remove when 0.4 out + + def onPress(self, x, y, btn): + self.focusBtns.add(btn) + self.eventHandler.handleEvent('press', x, y, btn) + + def onMove(self, x, y): + self.eventHandler.handleEvent('move', x, y) + + def onRelease(self, x, y, btn): + self.focusBtns.discard(btn) + requestfocus = self.eventHandler.handleEvent('release', x, y, btn) + if len(self.focusBtns) == 0 and not requestfocus: + self.goto('idle') + + def onWheel(self, x, y, angleInDegrees): + self.eventHandler.handleEvent('wheel', x, y, angleInDegrees) + + def __init__(self, eventHandlers=()): + self.eventHandlers = list(eventHandlers) + + states = { + 'idle': FocusManager.Idle, + 'focus': FocusManager.Focus + } + super(FocusManager, self).__init__(states, 'idle') + + def cancel(self): + for handler in self.eventHandlers: + handler.cancel() + + +# CameraControl ############################################################### + +class CameraControl(FocusManager): + """Combine wheel, selectPan and rotate state machine.""" + def __init__(self, viewport, + orbitAroundCenter=False, + mode='center', scaleTransform=None, + selectCB=None): + handlers = (CameraWheel(viewport, mode, scaleTransform), + CameraSelectPan(viewport, LEFT_BTN, selectCB), + CameraRotate(viewport, orbitAroundCenter, RIGHT_BTN)) + super(CameraControl, self).__init__(handlers) + + +# PlaneRotate ################################################################# + +class PlaneRotate(ClickOrDrag): + """Plane rotation using arcball interaction. + + Arcball ref.: + Ken Shoemake. ARCBALL: A user interface for specifying three-dimensional + orientation using a mouse. In Proc. GI '92. (1992). pp. 151-156. + """ + + def __init__(self, viewport, plane, button=RIGHT_BTN): + self._viewport = viewport + self._plane = plane + self._reset() + super(PlaneRotate, self).__init__(button) + + def _reset(self): + self._beginNormal, self._beginCenter = None, None + + def click(self, x, y): + pass # No interaction + + @staticmethod + def _sphereUnitVector(radius, center, position): + """Returns the unit vector of the projection of position on a sphere. + + It assumes an orthographic projection. + For perspective projection, it gives an approximation, but it + simplifies computations and results in consistent arcball control + in control space. + + All parameters must be in screen coordinate system + (either pixels or normalized coordinates). + + :param float radius: The radius of the sphere. + :param center: (x, y) coordinates of the center. + :param position: (x, y) coordinates of the cursor position. + :return: Unit vector. + :rtype: numpy.ndarray of 3 floats. + """ + center, position = numpy.array(center), numpy.array(position) + + # Normalize x and y on a unit circle + spherecoords = (position - center) / float(radius) + squarelength = numpy.sum(spherecoords ** 2) + + # Project on the unit sphere and compute z coordinates + if squarelength > 1.0: # Outside sphere: project + spherecoords /= numpy.sqrt(squarelength) + zsphere = 0.0 + else: # In sphere: compute z + zsphere = numpy.sqrt(1. - squarelength) + + spherecoords = numpy.append(spherecoords, zsphere) + return spherecoords + + def beginDrag(self, x, y): + # Makes sure the point defining the plane is at the center as + # it will be the center of rotation (as rotation is applied to normal) + self._plane.plane.point = self._plane.center + + # Store the plane normal + self._beginNormal = self._plane.plane.normal + + _logger.debug( + 'Begin arcball, plane center %s', str(self._plane.center)) + + # Do the arcball on the screen + radius = min(self._viewport.size) + if self._plane.center is None: + self._beginCenter = None + + else: + center = self._plane.objectToNDCTransform.transformPoint( + self._plane.center, perspectiveDivide=True) + self._beginCenter = self._viewport.ndcToWindow( + center[0], center[1], checkInside=False) + + self._startVector = self._sphereUnitVector( + radius, self._beginCenter, (x, y)) + + def drag(self, x, y): + if self._beginCenter is None: + return + + # Compute rotation: this is twice the rotation of the arcball + radius = min(self._viewport.size) + currentvector = self._sphereUnitVector( + radius, self._beginCenter, (x, y)) + crossprod = numpy.cross(self._startVector, currentvector) + dotprod = numpy.dot(self._startVector, currentvector) + + quaternion = numpy.append(crossprod, dotprod) + # Rotation was computed with Y downward, but apply in NDC, invert Y + quaternion[1] *= -1. + + rotation = transform.Rotate() + rotation.quaternion = quaternion + + # Convert to NDC, rotate, convert back to object + normal = self._plane.objectToNDCTransform.transformNormal( + self._beginNormal) + normal = rotation.transformNormal(normal) + normal = self._plane.objectToNDCTransform.transformNormal( + normal, direct=False) + self._plane.plane.normal = normal + + def endDrag(self, x, y): + self._reset() + + +# PlanePan ################################################################### + +class PlanePan(ClickOrDrag): + """Pan a plane along its normal on drag.""" + + def __init__(self, viewport, plane, button=LEFT_BTN): + self._plane = plane + self._viewport = viewport + self._beginPlanePoint = None + self._beginPos = None + self._dragNdcZ = 0. + super(PlanePan, self).__init__(button) + + def click(self, x, y): + pass + + def beginDrag(self, x, y): + ndc = self._viewport.windowToNdc(x, y) + ndcZ = self._viewport._pickNdcZGL(x, y) + # ndcZ is the panning plane + if ndc is not None and ndcZ is not None: + ndcPos = numpy.array((ndc[0], ndc[1], ndcZ, 1.), + dtype=numpy.float32) + scenePos = self._viewport.camera.transformPoint( + ndcPos, direct=False, perspectiveDivide=True) + self._beginPos = self._plane.objectToSceneTransform.transformPoint( + scenePos, direct=False) + self._dragNdcZ = ndcZ + else: + self._beginPos = None + self._dragNdcZ = 0. + + self._beginPlanePoint = self._plane.plane.point + + def drag(self, x, y): + if self._beginPos is not None: + ndc = self._viewport.windowToNdc(x, y) + if ndc is not None: + ndcPos = numpy.array((ndc[0], ndc[1], self._dragNdcZ, 1.), + dtype=numpy.float32) + + # Convert last and current NDC positions to scene coords + scenePos = self._viewport.camera.transformPoint( + ndcPos, direct=False, perspectiveDivide=True) + curPos = self._plane.objectToSceneTransform.transformPoint( + scenePos, direct=False) + + # Get translation in scene coords + translation = curPos[:3] - self._beginPos[:3] + + newPoint = self._beginPlanePoint + translation + + # Keep plane point in bounds + bounds = self._plane.parent.bounds(dataBounds=True) + if bounds is not None: + newPoint = numpy.clip( + newPoint, a_min=bounds[0], a_max=bounds[1]) + + # Only update plane if it is in some bounds + self._plane.plane.point = newPoint + + def endDrag(self, x, y): + self._beginPlanePoint = None + + +# PlaneControl ################################################################ + +class PlaneControl(FocusManager): + """Combine wheel, selectPan and rotate state machine for plane control.""" + def __init__(self, viewport, plane, + mode='center', scaleTransform=None): + handlers = (CameraWheel(viewport, mode, scaleTransform), + PlanePan(viewport, plane, LEFT_BTN), + PlaneRotate(viewport, plane, RIGHT_BTN)) + super(PlaneControl, self).__init__(handlers) + + +class PanPlaneRotateCameraControl(FocusManager): + """Combine wheel, pan plane and camera rotate state machine.""" + def __init__(self, viewport, plane, + mode='center', scaleTransform=None): + handlers = (CameraWheel(viewport, mode, scaleTransform), + PlanePan(viewport, plane, LEFT_BTN), + CameraRotate(viewport, + orbitAroundCenter=False, + button=RIGHT_BTN)) + super(PanPlaneRotateCameraControl, self).__init__(handlers) diff --git a/silx/gui/plot3d/scene/primitives.py b/silx/gui/plot3d/scene/primitives.py new file mode 100644 index 0000000..ca2616a --- /dev/null +++ b/silx/gui/plot3d/scene/primitives.py @@ -0,0 +1,1764 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ + +from __future__ import absolute_import, division, unicode_literals + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "25/07/2016" + + +import collections +import ctypes +from functools import reduce +import logging +import string + +import numpy + +from silx.gui.plot.Colors import rgba + +from ... import _glutils +from ..._glutils import gl + +from . import event +from . import core +from . import transform +from . import utils + +_logger = logging.getLogger(__name__) + + +# Geometry #################################################################### + +class Geometry(core.Elem): + """Set of vertices with normals and colors. + + :param str mode: OpenGL drawing mode: + lines, line_strip, loop, triangles, triangle_strip, fan + :param indices: Array of vertex indices or None + :param bool copy: True (default) to copy the data, False to use as is. + :param attributes: Provide list of attributes as extra parameters. + """ + + _ATTR_INFO = { + 'position': {'dims': (1, 2), 'lastDim': (2, 3, 4)}, + 'normal': {'dims': (1, 2), 'lastDim': (3,)}, + 'color': {'dims': (1, 2), 'lastDim': (3, 4)}, + } + + _MODE_CHECKS = { # Min, Modulo + 'lines': (2, 2), 'line_strip': (2, 0), 'loop': (2, 0), + 'points': (1, 0), + 'triangles': (3, 3), 'triangle_strip': (3, 0), 'fan': (3, 0) + } + + _MODES = { + 'lines': gl.GL_LINES, + 'line_strip': gl.GL_LINE_STRIP, + 'loop': gl.GL_LINE_LOOP, + + 'points': gl.GL_POINTS, + + 'triangles': gl.GL_TRIANGLES, + 'triangle_strip': gl.GL_TRIANGLE_STRIP, + 'fan': gl.GL_TRIANGLE_FAN + } + + _LINE_MODES = 'lines', 'line_strip', 'loop' + + _TRIANGLE_MODES = 'triangles', 'triangle_strip', 'fan' + + def __init__(self, mode, indices=None, copy=True, **attributes): + super(Geometry, self).__init__() + + self._vbos = {} # Store current vbos + self._unsyncAttributes = [] # Store attributes to copy to vbos + self.__bounds = None # Cache object's bounds + + assert mode in self._MODES + self._mode = mode + + # Set attributes + self._attributes = {} + for name, data in attributes.items(): + self.setAttribute(name, data, copy=copy) + + # Set indices + self._indices = None + self.setIndices(indices, copy=copy) + + # More consistency checks + mincheck, modulocheck = self._MODE_CHECKS[self._mode] + if self._indices is not None: + nbvertices = len(self._indices) + else: + nbvertices = self.nbVertices + assert nbvertices >= mincheck + if modulocheck != 0: + assert (nbvertices % modulocheck) == 0 + + @staticmethod + def _glReadyArray(array, copy=True): + """Making a contiguous array, checking float types. + + :param iterable array: array-like data to prepare for attribute + :param bool copy: True to make a copy of the array, False to use as is + """ + # Convert single value (int, float, numpy types) to tuple + if not isinstance(array, collections.Iterable): + array = (array, ) + + # Makes sure it is an array + array = numpy.array(array, copy=False) + + # Cast all float to float32 + dtype = None + if numpy.dtype(array.dtype).kind == 'f': + dtype = numpy.float32 + + return numpy.array(array, dtype=dtype, order='C', copy=copy) + + @property + def nbVertices(self): + """Returns the number of vertices of current attributes. + + It returns None if there is no attributes. + """ + for array in self._attributes.values(): + if len(array.shape) == 2: + return len(array) + return None + + def setAttribute(self, name, array, copy=True): + """Set attribute with provided array. + + :param str name: The name of the attribute + :param array: Array-like attribute data or None to remove attribute + :param bool copy: True (default) to copy the data, False to use as is + """ + # This triggers associated GL resources to be garbage collected + self._vbos.pop(name, None) + + if array is None: + self._attributes.pop(name, None) + + else: + array = self._glReadyArray(array, copy=copy) + + if name not in self._ATTR_INFO: + _logger.info('Not checking attibute %s dimensions', name) + else: + checks = self._ATTR_INFO[name] + + if (len(array.shape) == 1 and checks['lastDim'] == (1,) and + len(array) > 1): + array = array.reshape((len(array), 1)) + + # Checks + assert len(array.shape) in checks['dims'], "Attr %s" % name + assert array.shape[-1] in checks['lastDim'], "Attr %s" % name + + # Check length against another attribute array + # Causes problems when updating + # nbVertices = self.nbVertices + # if len(array.shape) == 2 and nbVertices is not None: + # assert len(array) == nbVertices + + self._attributes[name] = array + if len(array.shape) == 2: # Store this in a VBO + self._unsyncAttributes.append(name) + + if name == 'position': # Reset bounds + self.__bounds = None + + self.notify() + + def getAttribute(self, name, copy=True): + """Returns the numpy.ndarray corresponding to the name attribute. + + :param str name: The name of the attribute to get. + :param bool copy: True to get a copy (default), + False to get internal array (DO NOT MODIFY) + :return: The corresponding array or None if no corresponding attribute. + :rtype: numpy.ndarray + """ + attr = self._attributes.get(name, None) + return None if attr is None else numpy.array(attr, copy=copy) + + def useAttribute(self, program, name=None): + """Enable and bind attribute(s) for a specific program. + + This MUST be called with OpenGL context active and after prepareGL2 + has been called. + + :param GLProgram program: The program for which to set the attributes + :param str name: The attribute name to set or None to set then all + """ + if name is None: + for name in program.attributes: + self.useAttribute(program, name) + + else: + attribute = program.attributes.get(name) + if attribute is None: + return + + vboattrib = self._vbos.get(name) + if vboattrib is not None: + gl.glEnableVertexAttribArray(attribute) + vboattrib.setVertexAttrib(attribute) + + elif name not in self._attributes: + gl.glDisableVertexAttribArray(attribute) + + else: + array = self._attributes[name] + assert array is not None + + if len(array.shape) == 1: + assert len(array) in (1, 2, 3, 4) + gl.glDisableVertexAttribArray(attribute) + _glVertexAttribFunc = getattr( + _glutils.gl, 'glVertexAttrib{}f'.format(len(array))) + _glVertexAttribFunc(attribute, *array) + else: + # TODO As is this is a never event, remove? + gl.glEnableVertexAttribArray(attribute) + gl.glVertexAttribPointer( + attribute, + array.shape[-1], + _glutils.numpyToGLType(array.dtype), + gl.GL_FALSE, + 0, + array) + + def setIndices(self, indices, copy=True): + """Set the primitive indices to use. + + :param indices: Array-like of uint primitive indices or None to unset + :param bool copy: True (default) to copy the data, False to use as is + """ + # Trigger garbage collection of previous indices VBO if any + self._vbos.pop('__indices__', None) + + if indices is None: + self._indices = None + else: + indices = self._glReadyArray(indices, copy=copy).ravel() + assert indices.dtype.name in ('uint8', 'uint16', 'uint32') + if _logger.getEffectiveLevel() <= logging.DEBUG: + # This might be a costy check + assert indices.max() < self.nbVertices + self._indices = indices + + def getIndices(self, copy=True): + """Returns the numpy.ndarray corresponding to the indices. + + :param bool copy: True to get a copy (default), + False to get internal array (DO NOT MODIFY) + :return: The primitive indices array or None if not set. + :rtype: numpy.ndarray or None + """ + if self._indices is None: + return None + else: + return numpy.array(self._indices, copy=copy) + + def _bounds(self, dataBounds=False): + if self.__bounds is None: + self.__bounds = numpy.zeros((2, 3), dtype=numpy.float32) + # Support vertex with to 2 to 4 coordinates + positions = self._attributes['position'] + self.__bounds[0, :positions.shape[1]] = positions.min(axis=0)[:3] + self.__bounds[1, :positions.shape[1]] = positions.max(axis=0)[:3] + return self.__bounds.copy() + + def prepareGL2(self, ctx): + # TODO manage _vbo and multiple GL context + allow to share them ! + # TODO make one or multiple VBO depending on len(vertices), + # TODO use a general common VBO for small amount of data + for name in self._unsyncAttributes: + array = self._attributes[name] + self._vbos[name] = ctx.glCtx.makeVboAttrib(array) + self._unsyncAttributes = [] + + if self._indices is not None and '__indices__' not in self._vbos: + vbo = ctx.glCtx.makeVbo(self._indices, + usage=gl.GL_STATIC_DRAW, + target=gl.GL_ELEMENT_ARRAY_BUFFER) + self._vbos['__indices__'] = vbo + + def _draw(self, program=None, nbVertices=None): + """Perform OpenGL draw calls. + + :param GLProgram program: + If not None, call :meth:`useAttribute` for this program. + :param int nbVertices: + The number of vertices to render or None to render all vertices. + """ + if program is not None: + self.useAttribute(program) + + if self._indices is None: + if nbVertices is None: + nbVertices = self.nbVertices + gl.glDrawArrays(self._MODES[self._mode], 0, nbVertices) + else: + if nbVertices is None: + nbVertices = self._indices.size + with self._vbos['__indices__']: + gl.glDrawElements(self._MODES[self._mode], + nbVertices, + _glutils.numpyToGLType(self._indices.dtype), + ctypes.c_void_p(0)) + + +# Lines ####################################################################### + +class Lines(Geometry): + """A set of segments""" + _shaders = (""" + attribute vec3 position; + attribute vec3 normal; + attribute vec4 color; + + uniform mat4 matrix; + uniform mat4 transformMat; + + varying vec4 vCameraPosition; + varying vec3 vPosition; + varying vec3 vNormal; + varying vec4 vColor; + + void main(void) + { + gl_Position = matrix * vec4(position, 1.0); + vCameraPosition = transformMat * vec4(position, 1.0); + vPosition = position; + vNormal = normal; + vColor = color; + } + """, + string.Template(""" + varying vec4 vCameraPosition; + varying vec3 vPosition; + varying vec3 vNormal; + varying vec4 vColor; + + $clippingDecl + $lightingFunction + + void main(void) + { + $clippingCall(vCameraPosition); + gl_FragColor = $lightingCall(vColor, vPosition, vNormal); + } + """)) + + def __init__(self, positions, normals=None, colors=(1., 1., 1., 1.), + indices=None, mode='lines', width=1.): + if mode == 'strip': + mode = 'line_strip' + assert mode in self._LINE_MODES + + self._width = width + self._smooth = True + + super(Lines, self).__init__(mode, indices, + position=positions, + normal=normals, + color=colors) + + width = event.notifyProperty('_width', converter=float, + doc="Width of the line in pixels.") + + smooth = event.notifyProperty( + '_smooth', + converter=bool, + doc="Smooth line rendering enabled (bool, default: True)") + + def renderGL2(self, ctx): + # Prepare program + isnormals = 'normal' in self._attributes + if isnormals: + fraglightfunction = ctx.viewport.light.fragmentDef + else: + fraglightfunction = ctx.viewport.light.fragmentShaderFunctionNoop + + fragment = self._shaders[1].substitute( + clippingDecl=ctx.clipper.fragDecl, + clippingCall=ctx.clipper.fragCall, + lightingFunction=fraglightfunction, + lightingCall=ctx.viewport.light.fragmentCall) + prog = ctx.glCtx.prog(self._shaders[0], fragment) + prog.use() + + if isnormals: + ctx.viewport.light.setupProgram(ctx, prog) + + gl.glLineWidth(self.width) + + prog.setUniformMatrix('matrix', ctx.objectToNDC.matrix) + prog.setUniformMatrix('transformMat', + ctx.objectToCamera.matrix, + safe=True) + + ctx.clipper.setupProgram(ctx, prog) + + with gl.enabled(gl.GL_LINE_SMOOTH, self._smooth): + self._draw(prog) + + +class DashedLines(Lines): + """Set of dashed lines + + This MUST be defined as a set of lines (no strip or loop). + """ + + _shaders = (""" + attribute vec3 position; + attribute vec3 origin; + attribute vec3 normal; + attribute vec4 color; + + uniform mat4 matrix; + uniform mat4 transformMat; + uniform vec2 viewportSize; /* Width, height of the viewport */ + + varying vec4 vCameraPosition; + varying vec3 vPosition; + varying vec3 vNormal; + varying vec4 vColor; + varying vec2 vOriginFragCoord; + + void main(void) + { + gl_Position = matrix * vec4(position, 1.0); + vCameraPosition = transformMat * vec4(position, 1.0); + vPosition = position; + vNormal = normal; + vColor = color; + + vec4 clipOrigin = matrix * vec4(origin, 1.0); + vec4 ndcOrigin = clipOrigin / clipOrigin.w; /* Perspective divide */ + /* Convert to same frame as gl_FragCoord: lower-left, pixel center at 0.5, 0.5 */ + vOriginFragCoord = (ndcOrigin.xy + vec2(1.0, 1.0)) * 0.5 * viewportSize + vec2(0.5, 0.5); + } + """, # noqa + string.Template(""" + varying vec4 vCameraPosition; + varying vec3 vPosition; + varying vec3 vNormal; + varying vec4 vColor; + varying vec2 vOriginFragCoord; + + uniform vec2 dash; + + $clippingDecl + $lightingFunction + + void main(void) + { + /* Discard off dash fragments */ + float lineDist = distance(vOriginFragCoord, gl_FragCoord.xy); + if (mod(lineDist, dash.x + dash.y) > dash.x) { + discard; + } + $clippingCall(vCameraPosition); + gl_FragColor = $lightingCall(vColor, vPosition, vNormal); + } + """)) + + def __init__(self, positions, colors=(1., 1., 1., 1.), + indices=None, width=1.): + self._dash = 1, 0 + super(DashedLines, self).__init__(positions=positions, + colors=colors, + indices=indices, + mode='lines', + width=width) + + @property + def dash(self): + """Dash of the line as a 2-tuple of lengths in pixels: (on, off)""" + return self._dash + + @dash.setter + def dash(self, dash): + dash = float(dash[0]), float(dash[1]) + if dash != self._dash: + self._dash = dash + self.notify() + + def getPositions(self, copy=True): + """Get coordinates of lines. + + :param bool copy: True to get a copy, False otherwise + :returns: Coordinates of lines + :rtype: numpy.ndarray of float32 of shape (N, 2, Ndim) + """ + return self.getAttribute('position', copy=copy) + + def setPositions(self, positions, copy=True): + """Set line coordinates. + + :param positions: Array of line coordinates + :param bool copy: True to copy input array, False to use as is + """ + self.setAttribute('position', positions, copy=copy) + # Update line origins from given positions + origins = numpy.array(positions, copy=True, order='C') + origins[1::2] = origins[::2] + self.setAttribute('origin', origins, copy=False) + + def renderGL2(self, context): + # Prepare program + isnormals = 'normal' in self._attributes + if isnormals: + fraglightfunction = context.viewport.light.fragmentDef + else: + fraglightfunction = \ + context.viewport.light.fragmentShaderFunctionNoop + + fragment = self._shaders[1].substitute( + clippingDecl=context.clipper.fragDecl, + clippingCall=context.clipper.fragCall, + lightingFunction=fraglightfunction, + lightingCall=context.viewport.light.fragmentCall) + program = context.glCtx.prog(self._shaders[0], fragment) + program.use() + + if isnormals: + context.viewport.light.setupProgram(context, program) + + gl.glLineWidth(self.width) + + program.setUniformMatrix('matrix', context.objectToNDC.matrix) + program.setUniformMatrix('transformMat', + context.objectToCamera.matrix, + safe=True) + + gl.glUniform2f( + program.uniforms['viewportSize'], *context.viewport.size) + gl.glUniform2f(program.uniforms['dash'], *self.dash) + + context.clipper.setupProgram(context, program) + + self._draw(program) + + +class Box(core.PrivateGroup): + """Rectangular box""" + + _lineIndices = numpy.array(( + (0, 1), (1, 2), (2, 3), (3, 0), # Lines with z=0 + (0, 4), (1, 5), (2, 6), (3, 7), # Lines from z=0 to z=1 + (4, 5), (5, 6), (6, 7), (7, 4)), # Lines with z=1 + dtype=numpy.uint8) + + _faceIndices = numpy.array( + (0, 3, 1, 2, 5, 6, 4, 7, 7, 6, 6, 2, 7, 3, 4, 0, 5, 1), + dtype=numpy.uint8) + + _vertices = numpy.array(( + # Corners with z=0 + (0., 0., 0.), (1., 0., 0.), (1., 1., 0.), (0., 1., 0.), + # Corners with z=1 + (0., 0., 1.), (1., 0., 1.), (1., 1., 1.), (0., 1., 1.)), + dtype=numpy.float32) + + def __init__(self, size=(1., 1., 1.), + stroke=(1., 1., 1., 1.), + fill=(1., 1., 1., 0.)): + super(Box, self).__init__() + + self._fill = Mesh3D(self._vertices, + colors=rgba(fill), + mode='triangle_strip', + indices=self._faceIndices) + self._fill.visible = self.fillColor[-1] != 0. + + self._stroke = Lines(self._vertices, + indices=self._lineIndices, + colors=rgba(stroke), + mode='lines') + self._stroke.visible = self.strokeColor[-1] != 0. + self.strokeWidth = 1. + + self._children = [self._stroke, self._fill] + + self._size = None + self.size = size + + @property + def size(self): + """Size of the box (sx, sy, sz)""" + return self._size + + @size.setter + def size(self, size): + assert len(size) == 3 + size = tuple(size) + if size != self.size: + self._size = size + self._fill.setAttribute( + 'position', + self._vertices * numpy.array(size, dtype=numpy.float32)) + self._stroke.setAttribute( + 'position', + self._vertices * numpy.array(size, dtype=numpy.float32)) + self.notify() + + @property + def strokeSmooth(self): + """True to draw smooth stroke, False otherwise""" + return self._stroke.smooth + + @strokeSmooth.setter + def strokeSmooth(self, smooth): + smooth = bool(smooth) + if smooth != self._stroke.smooth: + self._stroke.smooth = smooth + self.notify() + + @property + def strokeWidth(self): + """Width of the stroke (float)""" + return self._stroke.width + + @strokeWidth.setter + def strokeWidth(self, width): + width = float(width) + if width != self.strokeWidth: + self._stroke.width = width + self.notify() + + @property + def strokeColor(self): + """RGBA color of the box lines (4-tuple of float in [0, 1])""" + return tuple(self._stroke.getAttribute('color', copy=False)) + + @strokeColor.setter + def strokeColor(self, color): + color = rgba(color) + if color != self.strokeColor: + self._stroke.setAttribute('color', color) + # Fully transparent = hidden + self._stroke.visible = color[-1] != 0. + self.notify() + + @property + def fillColor(self): + """RGBA color of the box faces (4-tuple of float in [0, 1])""" + return tuple(self._fill.getAttribute('color', copy=False)) + + @fillColor.setter + def fillColor(self, color): + color = rgba(color) + if color != self.fillColor: + self._fill.setAttribute('color', color) + # Fully transparent = hidden + self._fill.visible = color[-1] != 0. + self.notify() + + @property + def fillCulling(self): + return self._fill.culling + + @fillCulling.setter + def fillCulling(self, culling): + self._fill.culling = culling + + +class Axes(Lines): + """3D RGB orthogonal axes""" + _vertices = numpy.array(((0., 0., 0.), (1., 0., 0.), + (0., 0., 0.), (0., 1., 0.), + (0., 0., 0.), (0., 0., 1.)), + dtype=numpy.float32) + + _colors = numpy.array(((255, 0, 0, 255), (255, 0, 0, 255), + (0, 255, 0, 255), (0, 255, 0, 255), + (0, 0, 255, 255), (0, 0, 255, 255)), + dtype=numpy.uint8) + + def __init__(self): + super(Axes, self).__init__(self._vertices, + colors=self._colors, + width=3.) + + +class BoxWithAxes(Lines): + """Rectangular box with RGB OX, OY, OZ axes + + :param color: RGBA color of the box + """ + + _vertices = numpy.array(( + # Axes corners + (0., 0., 0.), (1., 0., 0.), + (0., 0., 0.), (0., 1., 0.), + (0., 0., 0.), (0., 0., 1.), + # Box corners with z=0 + (1., 0., 0.), (1., 1., 0.), (0., 1., 0.), + # Box corners with z=1 + (0., 0., 1.), (1., 0., 1.), (1., 1., 1.), (0., 1., 1.)), + dtype=numpy.float32) + + _axesColors = numpy.array(((1., 0., 0., 1.), (1., 0., 0., 1.), + (0., 1., 0., 1.), (0., 1., 0., 1.), + (0., 0., 1., 1.), (0., 0., 1., 1.)), + dtype=numpy.float32) + + _lineIndices = numpy.array(( + (0, 1), (2, 3), (4, 5), # Axes lines + (6, 7), (7, 8), # Box lines with z=0 + (6, 10), (7, 11), (8, 12), # Box lines from z=0 to z=1 + (9, 10), (10, 11), (11, 12), (12, 9)), # Box lines with z=1 + dtype=numpy.uint8) + + def __init__(self, color=(1., 1., 1., 1.)): + self._color = (1., 1., 1., 1.) + colors = numpy.ones((len(self._vertices), 4), dtype=numpy.float32) + colors[:len(self._axesColors), :] = self._axesColors + + super(BoxWithAxes, self).__init__(self._vertices, + indices=self._lineIndices, + colors=colors, + width=2.) + self.color = color + + @property + def color(self): + """The RGBA color to use for the box: 4 float in [0, 1]""" + return self._color + + @color.setter + def color(self, color): + color = rgba(color) + if color != self._color: + self._color = color + colors = numpy.empty((len(self._vertices), 4), dtype=numpy.float32) + colors[:len(self._axesColors), :] = self._axesColors + colors[len(self._axesColors):, :] = self._color + self.setAttribute('color', colors) # Do the notification + + +class PlaneInGroup(core.PrivateGroup): + """A plane using its parent bounds to display a contour. + + If plane is outside the bounds of its parent, it is not visible. + + Cannot set the transform attribute of this primitive. + This primitive never has any bounds. + """ + # TODO inherit from Lines directly?, make sure the plane remains visible? + + def __init__(self, point=(0., 0., 0.), normal=(0., 0., 1.)): + super(PlaneInGroup, self).__init__() + self._cache = None, None # Store bounds, vertices + self._outline = None + + self._color = None + self.color = 1., 1., 1., 1. # Set _color + self._width = 2. + + self._plane = utils.Plane(point, normal) + self._plane.addListener(self._planeChanged) + + def moveToCenter(self): + """Place the plane at the center of the data, not changing orientation. + """ + if self.parent is not None: + bounds = self.parent.bounds(dataBounds=True) + if bounds is not None: + center = (bounds[0] + bounds[1]) / 2. + _logger.debug('Moving plane to center: %s', str(center)) + self.plane.point = center + + @property + def color(self): + """Plane outline color (array of 4 float in [0, 1]).""" + return self._color.copy() + + @color.setter + def color(self, color): + self._color = numpy.array(color, copy=True, dtype=numpy.float32) + if self._outline is not None: + self._outline.setAttribute('color', self._color) + self.notify() # This is OK as Lines are rebuild for each rendering + + @property + def width(self): + """Width of the plane stroke in pixels""" + return self._width + + @width.setter + def width(self, width): + self._width = float(width) + if self._outline is not None: + self._outline.width = self._width # Sync width + + # Plane access + + @property + def plane(self): + """The plane parameters in the frame of the object.""" + return self._plane + + def _planeChanged(self, source): + """Listener of plane changes: clear cache and notify listeners.""" + self._cache = None, None + self.notify() + + # Disable some scene features + + @property + def transforms(self): + # Ready-only transforms to prevent using it + return self._transforms + + def _bounds(self, dataBounds=False): + # This is bound less as it uses the bounds of its parent. + return None + + @property + def contourVertices(self): + """The vertices of the contour of the plane/bounds intersection.""" + parent = self.parent + if parent is None: + return None # No parent: no vertices + + bounds = parent.bounds(dataBounds=True) + if bounds is None: + return None # No bounds: no vertices + + # Check if cache is valid and return it + cachebounds, cachevertices = self._cache + if numpy.all(numpy.equal(bounds, cachebounds)): + return cachevertices + + # Cache is not OK, rebuild it + boxvertices = bounds[0] + Box._vertices.copy()*(bounds[1] - bounds[0]) + lineindices = Box._lineIndices + vertices = utils.boxPlaneIntersect( + boxvertices, lineindices, self.plane.normal, self.plane.point) + + self._cache = bounds, vertices if len(vertices) != 0 else None + + return self._cache[1] + + @property + def center(self): + """The center of the plane/bounds intersection points.""" + if not self.isValid: + return None + else: + return numpy.mean(self.contourVertices, axis=0) + + @property + def isValid(self): + """True if a contour is defined, False otherwise.""" + return self.plane.isPlane and self.contourVertices is not None + + def prepareGL2(self, ctx): + if self.isValid: + if self._outline is None: # Init outline + self._outline = Lines(self.contourVertices, + mode='loop', + colors=self.color) + self._outline.width = self._width + self._children.append(self._outline) + + # Update vertices, TODO only when necessary + self._outline.setAttribute('position', self.contourVertices) + + super(PlaneInGroup, self).prepareGL2(ctx) + + def renderGL2(self, ctx): + if self.isValid: + super(PlaneInGroup, self).renderGL2(ctx) + + +# Points ###################################################################### + +_POINTS_ATTR_INFO = Geometry._ATTR_INFO.copy() +_POINTS_ATTR_INFO.update(value={'dims': (1, 2), 'lastDim': (1,)}, + size={'dims': (1, 2), 'lastDim': (1,)}, + symbol={'dims': (1, 2), 'lastDim': (1,)}) + + +class Points(Geometry): + """A set of data points with an associated value and size.""" + _shaders = (""" + #version 120 + + attribute vec3 position; + attribute float symbol; + attribute float value; + attribute float size; + + uniform mat4 matrix; + uniform mat4 transformMat; + + uniform vec2 valRange; + + varying vec4 vCameraPosition; + varying float vSymbol; + varying float vNormValue; + varying float vSize; + + void main(void) + { + vSymbol = symbol; + + vNormValue = clamp((value - valRange.x) / (valRange.y - valRange.x), + 0.0, 1.0); + + bool isValueInRange = value >= valRange.x && value <= valRange.y; + if (isValueInRange) { + gl_Position = matrix * vec4(position, 1.0); + } else { + gl_Position = vec4(2.0, 0.0, 0.0, 1.0); /* Get clipped */ + } + vCameraPosition = transformMat * vec4(position, 1.0); + + gl_PointSize = size; + vSize = size; + } + """, + string.Template(""" + #version 120 + + varying vec4 vCameraPosition; + varying float vSize; + varying float vSymbol; + varying float vNormValue; + + $clippinDecl + + /* Circle */ + #define SYMBOL_CIRCLE 1.0 + + float alphaCircle(vec2 coord, float size) { + float radius = 0.5; + float r = distance(coord, vec2(0.5, 0.5)); + return clamp(size * (radius - r), 0.0, 1.0); + } + + /* Half lines */ + #define SYMBOL_H_LINE 2.0 + #define LEFT 1.0 + #define RIGHT 2.0 + #define SYMBOL_V_LINE 3.0 + #define UP 1.0 + #define DOWN 2.0 + + float alphaLine(vec2 coord, float size, float direction) + { + vec2 delta = abs(size * (coord - 0.5)); + + if (direction == SYMBOL_H_LINE) { + return (delta.y < 0.5) ? 1.0 : 0.0; + } + else if (direction == SYMBOL_H_LINE + LEFT) { + return (coord.x <= 0.5 && delta.y < 0.5) ? 1.0 : 0.0; + } + else if (direction == SYMBOL_H_LINE + RIGHT) { + return (coord.x >= 0.5 && delta.y < 0.5) ? 1.0 : 0.0; + } + else if (direction == SYMBOL_V_LINE) { + return (delta.x < 0.5) ? 1.0 : 0.0; + } + else if (direction == SYMBOL_V_LINE + UP) { + return (coord.y <= 0.5 && delta.x < 0.5) ? 1.0 : 0.0; + } + else if (direction == SYMBOL_V_LINE + DOWN) { + return (coord.y >= 0.5 && delta.x < 0.5) ? 1.0 : 0.0; + } + return 1.0; + } + + void main(void) + { + $clippingCall(vCameraPosition); + + gl_FragColor = vec4(0.5 * vNormValue + 0.5, 0.0, 0.0, 1.0); + + float alpha = 1.0; + float symbol = floor(vSymbol); + if (1 == 1) { //symbol == SYMBOL_CIRCLE) { + alpha = alphaCircle(gl_PointCoord, vSize); + } + else if (symbol >= SYMBOL_H_LINE && + symbol <= (SYMBOL_V_LINE + DOWN)) { + alpha = alphaLine(gl_PointCoord, vSize, symbol); + } + if (alpha == 0.0) { + discard; + } + gl_FragColor.a *= alpha; + } + """)) + + _ATTR_INFO = _POINTS_ATTR_INFO + + # TODO Add colormap, light? + + def __init__(self, vertices, values=0., sizes=1., indices=None, + symbols=0., + minValue=None, maxValue=None): + super(Points, self).__init__('points', indices, + position=vertices, + value=values, + size=sizes, + symbol=symbols) + + values = self._attributes['value'] + self._minValue = values.min() if minValue is None else minValue + self._maxValue = values.max() if maxValue is None else maxValue + + minValue = event.notifyProperty('_minValue') + maxValue = event.notifyProperty('_maxValue') + + def renderGL2(self, ctx): + fragment = self._shaders[1].substitute( + clippingDecl=ctx.clipper.fragDecl, + clippingCall=ctx.clipper.fragCall) + prog = ctx.glCtx.prog(self._shaders[0], fragment) + prog.use() + + gl.glEnable(gl.GL_VERTEX_PROGRAM_POINT_SIZE) # OpenGL 2 + gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2 + # gl.glEnable(gl.GL_PROGRAM_POINT_SIZE) + + prog.setUniformMatrix('matrix', ctx.objectToNDC.matrix) + prog.setUniformMatrix('transformMat', + ctx.objectToCamera.matrix, + safe=True) + + ctx.clipper.setupProgram(ctx, prog) + + gl.glUniform2f(prog.uniforms['valRange'], self.minValue, self.maxValue) + + self._draw(prog) + + +class ColorPoints(Geometry): + """A set of points with an associated color and size.""" + + _shaders = (""" + #version 120 + + attribute vec3 position; + attribute float symbol; + attribute vec4 color; + attribute float size; + + uniform mat4 matrix; + uniform mat4 transformMat; + + varying vec4 vCameraPosition; + varying float vSymbol; + varying vec4 vColor; + varying float vSize; + + void main(void) + { + vCameraPosition = transformMat * vec4(position, 1.0); + vSymbol = symbol; + vColor = color; + gl_Position = matrix * vec4(position, 1.0); + gl_PointSize = size; + vSize = size; + } + """, + string.Template(""" + #version 120 + + varying vec4 vCameraPosition; + varying float vSize; + varying float vSymbol; + varying vec4 vColor; + + $clippingDecl; + + /* Circle */ + #define SYMBOL_CIRCLE 1.0 + + float alphaCircle(vec2 coord, float size) { + float radius = 0.5; + float r = distance(coord, vec2(0.5, 0.5)); + return clamp(size * (radius - r), 0.0, 1.0); + } + + /* Half lines */ + #define SYMBOL_H_LINE 2.0 + #define LEFT 1.0 + #define RIGHT 2.0 + #define SYMBOL_V_LINE 3.0 + #define UP 1.0 + #define DOWN 2.0 + + float alphaLine(vec2 coord, float size, float direction) + { + vec2 delta = abs(size * (coord - 0.5)); + + if (direction == SYMBOL_H_LINE) { + return (delta.y < 0.5) ? 1.0 : 0.0; + } + else if (direction == SYMBOL_H_LINE + LEFT) { + return (coord.x <= 0.5 && delta.y < 0.5) ? 1.0 : 0.0; + } + else if (direction == SYMBOL_H_LINE + RIGHT) { + return (coord.x >= 0.5 && delta.y < 0.5) ? 1.0 : 0.0; + } + else if (direction == SYMBOL_V_LINE) { + return (delta.x < 0.5) ? 1.0 : 0.0; + } + else if (direction == SYMBOL_V_LINE + UP) { + return (coord.y <= 0.5 && delta.x < 0.5) ? 1.0 : 0.0; + } + else if (direction == SYMBOL_V_LINE + DOWN) { + return (coord.y >= 0.5 && delta.x < 0.5) ? 1.0 : 0.0; + } + return 1.0; + } + + void main(void) + { + $clippingCall(vCameraPosition); + + gl_FragColor = vColor; + + float alpha = 1.0; + float symbol = floor(vSymbol); + if (1 == 1) { //symbol == SYMBOL_CIRCLE) { + alpha = alphaCircle(gl_PointCoord, vSize); + } + else if (symbol >= SYMBOL_H_LINE && + symbol <= (SYMBOL_V_LINE + DOWN)) { + alpha = alphaLine(gl_PointCoord, vSize, symbol); + } + if (alpha == 0.0) { + discard; + } + gl_FragColor.a *= alpha; + } + """)) + + _ATTR_INFO = _POINTS_ATTR_INFO + + def __init__(self, vertices, colors=(1., 1., 1., 1.), sizes=1., + indices=None, symbols=0., + minValue=None, maxValue=None): + super(ColorPoints, self).__init__('points', indices, + position=vertices, + color=colors, + size=sizes, + symbol=symbols) + + def renderGL2(self, ctx): + fragment = self._shaders[1].substitute( + clippingDecl=ctx.clipper.fragDecl, + clippingCall=ctx.clipper.fragCall) + prog = ctx.glCtx.prog(self._shaders[0], fragment) + prog.use() + + gl.glEnable(gl.GL_VERTEX_PROGRAM_POINT_SIZE) # OpenGL 2 + gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2 + # gl.glEnable(gl.GL_PROGRAM_POINT_SIZE) + + prog.setUniformMatrix('matrix', ctx.objectToNDC.matrix) + prog.setUniformMatrix('transformMat', + ctx.objectToCamera.matrix, + safe=True) + + ctx.clipper.setupProgram(ctx, prog) + + self._draw(prog) + + +class GridPoints(Geometry): + # GLSL 1.30 ! + """Data points on a regular grid with an associated value and size.""" + _shaders = (""" + #version 130 + + in float value; + in float size; + + uniform ivec3 gridDims; + uniform mat4 matrix; + uniform mat4 transformMat; + uniform vec2 valRange; + + out vec4 vCameraPosition; + out float vNormValue; + + //ivec3 coordsFromIndex(int index, ivec3 shape) + //{ + /*Assumes that data is stored as z-major, then y, contiguous on x + */ + // int yxPlaneSize = shape.y * shape.x; /* nb of elem in 2d yx plane */ + // int z = index / yxPlaneSize; + // int yxIndex = index - z * yxPlaneSize; /* index in 2d yx plane */ + // int y = yxIndex / shape.x; + // int x = yxIndex - y * shape.x; + // return ivec3(x, y, z); + // } + + ivec3 coordsFromIndex(int index, ivec3 shape) + { + /*Assumes that data is stored as x-major, then y, contiguous on z + */ + int yzPlaneSize = shape.y * shape.z; /* nb of elem in 2d yz plane */ + int x = index / yzPlaneSize; + int yzIndex = index - x * yzPlaneSize; /* index in 2d yz plane */ + int y = yzIndex / shape.z; + int z = yzIndex - y * shape.z; + return ivec3(x, y, z); + } + + void main(void) + { + vNormValue = clamp((value - valRange.x) / (valRange.y - valRange.x), + 0.0, 1.0); + + bool isValueInRange = value >= valRange.x && value <= valRange.y; + if (isValueInRange) { + /* Retrieve 3D position from gridIndex */ + vec3 coords = vec3(coordsFromIndex(gl_VertexID, gridDims)); + vec3 position = coords / max(vec3(gridDims) - 1.0, 1.0); + gl_Position = matrix * vec4(position, 1.0); + vCameraPosition = transformMat * vec4(position, 1.0); + } else { + gl_Position = vec4(2.0, 0.0, 0.0, 1.0); /* Get clipped */ + vCameraPosition = vec4(0.0, 0.0, 0.0, 0.0); + } + + gl_PointSize = size; + } + """, + string.Template(""" + #version 130 + + in vec4 vCameraPosition; + in float vNormValue; + out vec4 fragColor; + + $clippingDecl + + void main(void) + { + $clippingCall(vCameraPosition); + + fragColor = vec4(0.5 * vNormValue + 0.5, 0.0, 0.0, 1.0); + } + """)) + + _ATTR_INFO = { + 'value': {'dims': (1, 2), 'lastDim': (1,)}, + 'size': {'dims': (1, 2), 'lastDim': (1,)} + } + + # TODO Add colormap, shape? + # TODO could also use a texture to store values + + def __init__(self, values=0., shape=None, sizes=1., indices=None, + minValue=None, maxValue=None): + if isinstance(values, collections.Iterable): + values = numpy.array(values, copy=False) + + # Test if gl_VertexID will overflow + assert values.size < numpy.iinfo(numpy.int32).max + + self._shape = values.shape + values = values.ravel() # 1D to add as a 1D vertex attribute + + else: + assert shape is not None + self._shape = tuple(shape) + + assert len(self._shape) in (1, 2, 3) + + super(GridPoints, self).__init__('points', indices, + value=values, + size=sizes) + + data = self.getAttribute('value', copy=False) + self._minValue = data.min() if minValue is None else minValue + self._maxValue = data.max() if maxValue is None else maxValue + + minValue = event.notifyProperty('_minValue') + maxValue = event.notifyProperty('_maxValue') + + def _bounds(self, dataBounds=False): + # Get bounds from values shape + bounds = numpy.zeros((2, 3), dtype=numpy.float32) + bounds[1, :] = self._shape + bounds[1, :] -= 1 + return bounds + + def renderGL2(self, ctx): + fragment = self._shaders[1].substitute( + clippingDecl=ctx.clipper.fragDecl, + clippingCall=ctx.clipper.fragCall) + prog = ctx.glCtx.prog(self._shaders[0], fragment) + prog.use() + + gl.glEnable(gl.GL_VERTEX_PROGRAM_POINT_SIZE) # OpenGL 2 + gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2 + # gl.glEnable(gl.GL_PROGRAM_POINT_SIZE) + + prog.setUniformMatrix('matrix', ctx.objectToNDC.matrix) + prog.setUniformMatrix('transformMat', + ctx.objectToCamera.matrix, + safe=True) + + ctx.clipper.setupProgram(ctx, prog) + + gl.glUniform3i(prog.uniforms['gridDims'], + self._shape[2] if len(self._shape) == 3 else 1, + self._shape[1] if len(self._shape) >= 2 else 1, + self._shape[0]) + + gl.glUniform2f(prog.uniforms['valRange'], self.minValue, self.maxValue) + + self._draw(prog, nbVertices=reduce(lambda a, b: a * b, self._shape)) + + +# Spheres ##################################################################### + +class Spheres(Geometry): + """A set of spheres. + + Spheres are rendered as circles using points. + This brings some limitations: + - Do not support non-uniform scaling. + - Assume the projection keeps ratio. + - Do not render distorion by perspective projection. + - If the sphere center is clipped, the whole sphere is not displayed. + """ + # TODO check those links + # Accounting for perspective projection + # http://iquilezles.org/www/articles/sphereproj/sphereproj.htm + + # Michael Mara and Morgan McGuire. + # 2D Polyhedral Bounds of a Clipped, Perspective-Projected 3D Sphere + # Journal of Computer Graphics Techniques, Vol. 2, No. 2, 2013. + # http://jcgt.org/published/0002/02/05/paper.pdf + # https://research.nvidia.com/publication/2d-polyhedral-bounds-clipped-perspective-projected-3d-sphere + + # TODO some issues with small scaling and regular grid or due to sampling + + _shaders = (""" + #version 120 + + attribute vec3 position; + attribute vec4 color; + attribute float radius; + + uniform mat4 transformMat; + uniform mat4 projMat; + uniform vec2 screenSize; + + varying vec4 vCameraPosition; + varying vec3 vPosition; + varying vec4 vColor; + varying float vViewDepth; + varying float vViewRadius; + + void main(void) + { + vCameraPosition = transformMat * vec4(position, 1.0); + gl_Position = projMat * vCameraPosition; + + vPosition = gl_Position.xyz / gl_Position.w; + + /* From object space radius to view space diameter. + * Do not support non-uniform scaling */ + vec4 viewSizeVector = transformMat * vec4(2.0 * radius, 0.0, 0.0, 0.0); + float viewSize = length(viewSizeVector.xyz); + + /* Convert to pixel size at the xy center of the view space */ + vec4 projSize = projMat * vec4(0.5 * viewSize, 0.0, + vCameraPosition.z, vCameraPosition.w); + gl_PointSize = max(1.0, screenSize[0] * projSize.x / projSize.w); + + vColor = color; + vViewRadius = 0.5 * viewSize; + vViewDepth = vCameraPosition.z; + } + """, + string.Template(""" + # version 120 + + uniform mat4 projMat; + + varying vec4 vCameraPosition; + varying vec3 vPosition; + varying vec4 vColor; + varying float vViewDepth; + varying float vViewRadius; + + $clippingDecl + $lightingFunction + + void main(void) + { + $clippingCall(vCameraPosition); + + /* Get normal from point coords */ + vec3 normal; + normal.xy = 2.0 * gl_PointCoord - vec2(1.0); + normal.y *= -1.0; /*Invert y to match NDC orientation*/ + float sqLength = dot(normal.xy, normal.xy); + if (sqLength > 1.0) { /* Length -> out of sphere */ + discard; + } + normal.z = sqrt(1.0 - sqLength); + + /*Lighting performed in NDC*/ + /*TODO update this when lighting changed*/ + //XXX vec3 position = vPosition + vViewRadius * normal; + gl_FragColor = $lightingCall(vColor, vPosition, normal); + + /*Offset depth*/ + float viewDepth = vViewDepth + vViewRadius * normal.z; + vec2 clipZW = viewDepth * projMat[2].zw + projMat[3].zw; + gl_FragDepth = 0.5 * (clipZW.x / clipZW.y) + 0.5; + } + """)) + + _ATTR_INFO = { + 'position': {'dims': (2, ), 'lastDim': (2, 3, 4)}, + 'radius': {'dims': (1, 2), 'lastDim': (1, )}, + 'color': {'dims': (1, 2), 'lastDim': (3, 4)}, + } + + def __init__(self, positions, radius=1., colors=(1., 1., 1., 1.)): + self.__bounds = None + super(Spheres, self).__init__('points', None, + position=positions, + radius=radius, + color=colors) + + def renderGL2(self, ctx): + fragment = self._shaders[1].substitute( + clippingDecl=ctx.clipper.fragDecl, + clippingCall=ctx.clipper.fragCall, + lightingFunction=ctx.viewport.light.fragmentDef, + lightingCall=ctx.viewport.light.fragmentCall) + prog = ctx.glCtx.prog(self._shaders[0], fragment) + prog.use() + + ctx.viewport.light.setupProgram(ctx, prog) + + gl.glEnable(gl.GL_VERTEX_PROGRAM_POINT_SIZE) # OpenGL 2 + gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2 + # gl.glEnable(gl.GL_PROGRAM_POINT_SIZE) + + prog.setUniformMatrix('projMat', ctx.projection.matrix) + prog.setUniformMatrix('transformMat', + ctx.objectToCamera.matrix, + safe=True) + + ctx.clipper.setupProgram(ctx, prog) + + gl.glUniform2f(prog.uniforms['screenSize'], *ctx.viewport.size) + + self._draw(prog) + + def _bounds(self, dataBounds=False): + if self.__bounds is None: + self.__bounds = numpy.zeros((2, 3), dtype=numpy.float32) + # Support vertex with to 2 to 4 coordinates + positions = self._attributes['position'] + radius = self._attributes['radius'] + self.__bounds[0, :positions.shape[1]] = \ + (positions - radius).min(axis=0)[:3] + self.__bounds[1, :positions.shape[1]] = \ + (positions + radius).max(axis=0)[:3] + return self.__bounds.copy() + + +# Meshes ###################################################################### + +class Mesh3D(Geometry): + """A conventional 3D mesh""" + + _shaders = (""" + attribute vec3 position; + attribute vec3 normal; + attribute vec4 color; + + uniform mat4 matrix; + uniform mat4 transformMat; + //uniform mat3 matrixInvTranspose; + + varying vec4 vCameraPosition; + varying vec3 vPosition; + varying vec3 vNormal; + varying vec4 vColor; + + void main(void) + { + vCameraPosition = transformMat * vec4(position, 1.0); + //vNormal = matrixInvTranspose * normalize(normal); + vPosition = position; + vNormal = normal; + vColor = color; + gl_Position = matrix * vec4(position, 1.0); + } + """, + string.Template(""" + varying vec4 vCameraPosition; + varying vec3 vPosition; + varying vec3 vNormal; + varying vec4 vColor; + + $clippingDecl + $lightingFunction + + void main(void) + { + $clippingCall(vCameraPosition); + + gl_FragColor = $lightingCall(vColor, vPosition, vNormal); + } + """)) + + def __init__(self, + positions, + colors, + normals=None, + mode='triangles', + indices=None): + assert mode in self._TRIANGLE_MODES + super(Mesh3D, self).__init__(mode, indices, + position=positions, + normal=normals, + color=colors) + + self._culling = None + + @property + def culling(self): + """Face culling (str) + + One of 'back', 'front' or None. + """ + return self._culling + + @culling.setter + def culling(self, culling): + assert culling in ('back', 'front', None) + if culling != self._culling: + self._culling = culling + self.notify() + + def renderGL2(self, ctx): + isnormals = 'normal' in self._attributes + if isnormals: + fragLightFunction = ctx.viewport.light.fragmentDef + else: + fragLightFunction = ctx.viewport.light.fragmentShaderFunctionNoop + + fragment = self._shaders[1].substitute( + clippingDecl=ctx.clipper.fragDecl, + clippingCall=ctx.clipper.fragCall, + lightingFunction=fragLightFunction, + lightingCall=ctx.viewport.light.fragmentCall) + prog = ctx.glCtx.prog(self._shaders[0], fragment) + prog.use() + + if isnormals: + ctx.viewport.light.setupProgram(ctx, prog) + + if self.culling is not None: + cullFace = gl.GL_FRONT if self.culling == 'front' else gl.GL_BACK + gl.glCullFace(cullFace) + gl.glEnable(gl.GL_CULL_FACE) + + prog.setUniformMatrix('matrix', ctx.objectToNDC.matrix) + prog.setUniformMatrix('transformMat', + ctx.objectToCamera.matrix, + safe=True) + + ctx.clipper.setupProgram(ctx, prog) + + self._draw(prog) + + if self.culling is not None: + gl.glDisable(gl.GL_CULL_FACE) + + +# Group ###################################################################### + +# TODO lighting, clipping as groups? +# group composition? + +class GroupDepthOffset(core.Group): + """A group using 2-pass rendering and glDepthRange to avoid Z-fighting""" + + def __init__(self, children=(), epsilon=None): + super(GroupDepthOffset, self).__init__(children) + self._epsilon = epsilon + self.isDepthRangeOn = True + + def prepareGL2(self, ctx): + if self._epsilon is None: + depthbits = gl.glGetInteger(gl.GL_DEPTH_BITS) + self._epsilon = 1. / (1 << (depthbits - 1)) + + def renderGL2(self, ctx): + if self.isDepthRangeOn: + self._renderGL2WithDepthRange(ctx) + else: + super(GroupDepthOffset, self).renderGL2(ctx) + + def _renderGL2WithDepthRange(self, ctx): + # gl.glDepthFunc(gl.GL_LESS) + with gl.enabled(gl.GL_CULL_FACE): + gl.glCullFace(gl.GL_BACK) + for child in self.children: + gl.glColorMask( + gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE) + gl.glDepthMask(gl.GL_TRUE) + gl.glDepthRange(self._epsilon, 1.) + + child.render(ctx) + + gl.glColorMask( + gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE) + gl.glDepthMask(gl.GL_FALSE) + gl.glDepthRange(0., 1. - self._epsilon) + + child.render(ctx) + + gl.glCullFace(gl.GL_FRONT) + for child in reversed(self.children): + gl.glColorMask( + gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE) + gl.glDepthMask(gl.GL_TRUE) + gl.glDepthRange(self._epsilon, 1.) + + child.render(ctx) + + gl.glColorMask( + gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE) + gl.glDepthMask(gl.GL_FALSE) + gl.glDepthRange(0., 1. - self._epsilon) + + child.render(ctx) + + gl.glDepthMask(gl.GL_TRUE) + gl.glDepthRange(0., 1.) + # gl.glDepthFunc(gl.GL_LEQUAL) + # TODO use epsilon for all rendering? + # TODO issue with picking in depth buffer! + + +class GroupBBox(core.PrivateGroup): + """A group displaying a bounding box around the children.""" + + def __init__(self, children=(), color=(1., 1., 1., 1.)): + super(GroupBBox, self).__init__() + self._group = core.Group(children) + + self._boxTransforms = transform.TransformList( + (transform.Translate(), transform.Scale())) + + self._boxWithAxes = BoxWithAxes(color) + self._boxWithAxes.smooth = False + self._boxWithAxes.transforms = self._boxTransforms + + self._children = [self._boxWithAxes, self._group] + + def _updateBoxAndAxes(self): + """Update bbox and axes position and size according to children.""" + bounds = self._group.bounds(dataBounds=True) + if bounds is not None: + origin = bounds[0] + scale = [(d if d != 0. else 1.) for d in bounds[1] - bounds[0]] + else: + origin, scale = (0., 0., 0.), (1., 1., 1.) + + self._boxTransforms[0].translation = origin + self._boxTransforms[1].scale = scale + + def _bounds(self, dataBounds=False): + self._updateBoxAndAxes() + return super(GroupBBox, self)._bounds(dataBounds) + + def prepareGL2(self, ctx): + self._updateBoxAndAxes() + super(GroupBBox, self).prepareGL2(ctx) + + # Give access to _group children + + @property + def children(self): + return self._group.children + + @children.setter + def children(self, iterable): + self._group.children = iterable + + # Give access to box color + + @property + def color(self): + """The RGBA color to use for the box: 4 float in [0, 1]""" + return self._boxWithAxes.color + + @color.setter + def color(self, color): + self._boxWithAxes.color = color + + +# Clipping Plane ############################################################## + +class ClipPlane(PlaneInGroup): + """A clipping plane attached to a box""" + + def renderGL2(self, ctx): + super(ClipPlane, self).renderGL2(ctx) + + if self.visible: + # Set-up clipping plane for following brothers + + # No need of perspective divide, no projection + point = ctx.objectToCamera.transformPoint(self.plane.point, + perspectiveDivide=False) + normal = ctx.objectToCamera.transformNormal(self.plane.normal) + ctx.setClipPlane(point, normal) + + def postRender(self, ctx): + if self.visible: + # Disable clip planes + ctx.setClipPlane() diff --git a/silx/gui/plot3d/scene/setup.py b/silx/gui/plot3d/scene/setup.py new file mode 100644 index 0000000..ff4c0a6 --- /dev/null +++ b/silx/gui/plot3d/scene/setup.py @@ -0,0 +1,41 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "25/07/2016" + +from numpy.distutils.misc_util import Configuration + + +def configuration(parent_package='', top_path=None): + config = Configuration('scene', parent_package, top_path) + config.add_subpackage('test') + return config + + +if __name__ == "__main__": + from numpy.distutils.core import setup + + setup(configuration=configuration) diff --git a/silx/gui/plot3d/scene/test/__init__.py b/silx/gui/plot3d/scene/test/__init__.py new file mode 100644 index 0000000..fc4621e --- /dev/null +++ b/silx/gui/plot3d/scene/test/__init__.py @@ -0,0 +1,43 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ + +from __future__ import absolute_import, division, unicode_literals + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "25/07/2016" + + +import unittest + +from .test_transform import suite as test_transform_suite +from .test_utils import suite as test_utils_suite + + +def suite(): + testsuite = unittest.TestSuite() + testsuite.addTest(test_transform_suite()) + testsuite.addTest(test_utils_suite()) + return testsuite diff --git a/silx/gui/plot3d/scene/test/test_transform.py b/silx/gui/plot3d/scene/test/test_transform.py new file mode 100644 index 0000000..9ea0af1 --- /dev/null +++ b/silx/gui/plot3d/scene/test/test_transform.py @@ -0,0 +1,91 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ + +from __future__ import absolute_import, division, unicode_literals + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "05/01/2017" + + +import numpy +import unittest + +from silx.gui.plot3d.scene import transform + + +class TestTransformList(unittest.TestCase): + + def assertSameArrays(self, a, b): + return self.assertTrue(numpy.allclose(a, b, atol=1e-06)) + + def testTransformList(self): + """Minimalistic test of TransformList""" + transforms = transform.TransformList() + refmatrix = numpy.identity(4, dtype=numpy.float32) + self.assertSameArrays(refmatrix, transforms.matrix) + + # Append translate + transforms.append(transform.Translate(1., 1., 1.)) + refmatrix = numpy.array(((1., 0., 0., 1.), + (0., 1., 0., 1.), + (0., 0., 1., 1.), + (0., 0., 0., 1.)), dtype=numpy.float32) + self.assertSameArrays(refmatrix, transforms.matrix) + + # Extend scale + transforms.extend([transform.Scale(0.1, 2., 1.)]) + refmatrix = numpy.dot(refmatrix, + numpy.array(((0.1, 0., 0., 0.), + (0., 2., 0., 0.), + (0., 0., 1., 0.), + (0., 0., 0., 1.)), + dtype=numpy.float32)) + self.assertSameArrays(refmatrix, transforms.matrix) + + # Insert rotate + transforms.insert(0, transform.Rotate(360.)) + self.assertSameArrays(refmatrix, transforms.matrix) + + # Update translate and check for listener called + self._callCount = 0 + + def listener(source): + self._callCount += 1 + transforms.addListener(listener) + + transforms[1].tx += 1 + self.assertEqual(self._callCount, 1) + + +def suite(): + testsuite = unittest.TestSuite() + testsuite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestTransformList)) + return testsuite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot3d/scene/test/test_utils.py b/silx/gui/plot3d/scene/test/test_utils.py new file mode 100644 index 0000000..65c2407 --- /dev/null +++ b/silx/gui/plot3d/scene/test/test_utils.py @@ -0,0 +1,275 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ + +from __future__ import absolute_import, division, unicode_literals + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "25/07/2016" + + +import unittest +from silx.test.utils import ParametricTestCase + +import numpy + +from silx.gui.plot3d.scene import utils + + +# angleBetweenVectors ######################################################### + +class TestAngleBetweenVectors(ParametricTestCase): + + TESTS = { # name: (refvector, vectors, norm, refangles) + 'single vector': + ((1., 0., 0.), (1., 0., 0.), (0., 0., 1.), 0.), + 'single vector, no norm': + ((1., 0., 0.), (1., 0., 0.), None, 0.), + + 'with orthogonal norm': + ((1., 0., 0.), + ((1., 0., 0.), (0., 1., 0.), (-1., 0., 0.), (0., -1., 0.)), + (0., 0., 1.), + (0., 90., 180., 270.)), + + 'with coplanar norm': # = similar to no norm + ((1., 0., 0.), + ((1., 0., 0.), (0., 1., 0.), (-1., 0., 0.), (0., -1., 0.)), + (1., 0., 0.), + (0., 90., 180., 90.)), + + 'without norm': + ((1., 0., 0.), + ((1., 0., 0.), (0., 1., 0.), (-1., 0., 0.), (0., -1., 0.)), + None, + (0., 90., 180., 90.)), + + 'not unit vectors': + ((2., 2., 0.), ((1., 1., 0.), (1., -1., 0.)), None, (0., 90.)), + } + + def testAngleBetweenVectorsFunction(self): + for name, params in self.TESTS.items(): + refvector, vectors, norm, refangles = params + with self.subTest(name): + refangles = numpy.radians(refangles) + + refvector = numpy.array(refvector) + vectors = numpy.array(vectors) + if norm is not None: + norm = numpy.array(norm) + + testangles = utils.angleBetweenVectors( + refvector, vectors, norm) + + self.assertTrue( + numpy.allclose(testangles, refangles, atol=1e-5)) + + +# Plane ####################################################################### + +class AssertNotificationContext(object): + """Context that checks if an event.Notifier is sending events.""" + + def __init__(self, notifier, count=1): + """Initializer. + + :param event.Notifier notifier: The notifier to test. + :param int count: The expected number of calls. + """ + self._notifier = notifier + self._callCount = None + self._count = count + + def __enter__(self): + self._callCount = 0 + self._notifier.addListener(self._callback) + + def __exit__(self, exc_type, exc_value, traceback): + # Do not return True so exceptions are propagated + self._notifier.removeListener(self._callback) + assert self._callCount == self._count + self._callCount = None + + def _callback(self, *args, **kwargs): + self._callCount += 1 + + +class TestPlaneParameters(ParametricTestCase): + """Test Plane.parameters read/write and notifications.""" + + PARAMETERS = { + 'unit normal': (1., 0., 0., 1.), + 'not unit normal': (1., 1., 0., 1.), + 'd = 0': (1., 0., 0., 0.) + } + + def testParameters(self): + """Check parameters read/write and notification.""" + plane = utils.Plane() + + for name, parameters in self.PARAMETERS.items(): + with self.subTest(name, parameters=parameters): + with AssertNotificationContext(plane): + plane.parameters = parameters + + # Plane parameters are converted to have a unit normal + normparams = parameters / numpy.linalg.norm(parameters[:3]) + self.assertTrue(numpy.allclose(plane.parameters, normparams)) + + ZEROS_PARAMETERS = ( + (0., 0., 0., 0.), + (0., 0., 0., 1.) + ) + + ZEROS = 0., 0., 0., 0. + + def testParametersNoPlane(self): + """Test Plane.parameters with ||normal|| == 0 .""" + plane = utils.Plane() + plane.parameters = self.ZEROS + + for parameters in self.ZEROS_PARAMETERS: + with self.subTest(parameters=parameters): + with AssertNotificationContext(plane, count=0): + plane.parameters = parameters + self.assertTrue( + numpy.allclose(plane.parameters, self.ZEROS, 0., 0.)) + + +# unindexArrays ############################################################### + +class TestUnindexArrays(ParametricTestCase): + """Test unindexArrays function.""" + + def testBasicModes(self): + """Test for modes: points, lines and triangles""" + indices = numpy.array((1, 2, 0)) + arrays = (numpy.array((0., 1., 2.)), + numpy.array(((0, 0), (1, 1), (2, 2)))) + refresults = (numpy.array((1., 2., 0.)), + numpy.array(((1, 1), (2, 2), (0, 0)))) + + for mode in ('points', 'lines', 'triangles'): + with self.subTest(mode=mode): + testresults = utils.unindexArrays(mode, indices, *arrays) + for ref, test in zip(refresults, testresults): + self.assertTrue(numpy.equal(ref, test).all()) + + def testPackedLines(self): + """Test for modes: line_strip, loop""" + indices = numpy.array((1, 2, 0)) + arrays = (numpy.array((0., 1., 2.)), + numpy.array(((0, 0), (1, 1), (2, 2)))) + results = { + 'line_strip': ( + numpy.array((1., 2., 2., 0.)), + numpy.array(((1, 1), (2, 2), (2, 2), (0, 0)))), + 'loop': ( + numpy.array((1., 2., 2., 0., 0., 1.)), + numpy.array(((1, 1), (2, 2), (2, 2), (0, 0), (0, 0), (1, 1)))), + } + + for mode, refresults in results.items(): + with self.subTest(mode=mode): + testresults = utils.unindexArrays(mode, indices, *arrays) + for ref, test in zip(refresults, testresults): + self.assertTrue(numpy.equal(ref, test).all()) + + def testPackedTriangles(self): + """Test for modes: triangle_strip, fan""" + indices = numpy.array((1, 2, 0, 3)) + arrays = (numpy.array((0., 1., 2., 3.)), + numpy.array(((0, 0), (1, 1), (2, 2), (3, 3)))) + results = { + 'triangle_strip': ( + numpy.array((1., 2., 0., 2., 0., 3.)), + numpy.array(((1, 1), (2, 2), (0, 0), (2, 2), (0, 0), (3, 3)))), + 'fan': ( + numpy.array((1., 2., 0., 1., 0., 3.)), + numpy.array(((1, 1), (2, 2), (0, 0), (1, 1), (0, 0), (3, 3)))), + } + + for mode, refresults in results.items(): + with self.subTest(mode=mode): + testresults = utils.unindexArrays(mode, indices, *arrays) + for ref, test in zip(refresults, testresults): + self.assertTrue(numpy.equal(ref, test).all()) + + def testBadIndices(self): + """Test with negative indices and indices higher than array length""" + arrays = numpy.array((0, 1)), numpy.array((0, 1, 2)) + + # negative indices + with self.assertRaises(AssertionError): + utils.unindexArrays('points', (-1, 0), *arrays) + + # Too high indices + with self.assertRaises(AssertionError): + utils.unindexArrays('points', (0, 10), *arrays) + + +# triangleNormals ############################################################# + +class TestTriangleNormals(ParametricTestCase): + """Test triangleNormals function.""" + + def test(self): + """Test for modes: points, lines and triangles""" + positions = numpy.array( + ((0., 0., 0.), (1., 0., 0.), (0., 1., 0.), # normal = Z + (1., 1., 1.), (1., 2., 3.), (4., 5., 6.), # Random triangle + # Degenerated triangles: + (0., 0., 0.), (1., 0., 0.), (2., 0., 0.), # Colinear points + (1., 1., 1.), (1., 1., 1.), (1., 1., 1.), # All same point + ), + dtype='float32') + + normals = numpy.array( + ((0., 0., 1.), + (-0.40824829, 0.81649658, -0.40824829), + (0., 0., 0.), + (0., 0., 0.)), + dtype='float32') + + testnormals = utils.trianglesNormal(positions) + self.assertTrue(numpy.allclose(testnormals, normals)) + + +# suite ####################################################################### + +def suite(): + testsuite = unittest.TestSuite() + for test in (TestAngleBetweenVectors, + TestPlaneParameters, + TestUnindexArrays, + TestTriangleNormals): + testsuite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(test)) + return testsuite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot3d/scene/text.py b/silx/gui/plot3d/scene/text.py new file mode 100644 index 0000000..903fc21 --- /dev/null +++ b/silx/gui/plot3d/scene/text.py @@ -0,0 +1,534 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Primitive displaying a text field in the scene.""" + +from __future__ import absolute_import, division, unicode_literals + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "17/10/2016" + + +import logging +import numpy + +from silx.gui.plot.Colors import rgba + +from ... import _glutils +from ..._glutils import gl + +from ..._glutils import font as _font +from ...plot._utils import ticklayout + +from . import event, primitives, core, transform + + +_logger = logging.getLogger(__name__) + + +class Font(event.Notifier): + """Description of a font. + + :param str name: Family of the font + :param int size: Size of the font in points + :param int weight: Font weight + :param bool italic: True for italic font, False (default) otherwise + """ + + def __init__(self, name=None, size=-1, weight=-1, italic=False): + self._name = name if name is not None else _font.getDefaultFontFamily() + self._size = size + self._weight = weight + self._italic = italic + super(Font, self).__init__() + + name = event.notifyProperty( + '_name', + doc="""Name of the font (str)""", + converter=str) + + size = event.notifyProperty( + '_size', + doc="""Font size in points (int)""", + converter=int) + + weight = event.notifyProperty( + '_weight', + doc="""Font size in points (int)""", + converter=int) + + italic = event.notifyProperty( + '_italic', + doc="""True for italic (bool)""", + converter=bool) + + +class Text2D(primitives.Geometry): + """Text field as a 2D texture displayed with bill-boarding + + :param str text: Text to display + :param Font font: The font to use + """ + + # Text anchor values + CENTER = 'center' + + LEFT = 'left' + RIGHT = 'right' + + TOP = 'top' + BASELINE = 'baseline' + BOTTOM = 'bottom' + + _ALIGN = LEFT, CENTER, RIGHT + _VALIGN = TOP, BASELINE, CENTER, BOTTOM + + _rasterTextCache = {} + """Internal cache storing already rasterized text""" + # TODO limit cache size and discard least recent used + + def __init__(self, text='', font=None): + self._dirtyTexture = True + self._dirtyAlign = True + self._baselineOffset = 0 + self._text = text + self._font = font if font is not None else Font() + self._foreground = 1., 1., 1., 1. + self._background = 0., 0., 0., 0. + self._overlay = False + self._align = 'left' + self._valign = 'baseline' + self._devicePixelRatio = 1.0 # Store it to check for changes + + self._texture = None + self._textureDirty = True + + super(Text2D, self).__init__( + 'triangle_strip', + copy=False, + # Keep an array for position as it is bound to attr 0 and MUST + # be active and an array at least on Mac OS X + position=numpy.zeros((4, 3), dtype=numpy.float32), + vertexID=numpy.arange(4., dtype=numpy.float32).reshape(4, 1), + offsetInViewportCoords=(0., 0.)) + + @property + def text(self): + """Text displayed by this primitive (str)""" + return self._text + + @text.setter + def text(self, text): + text = str(text) + if text != self._text: + self._dirtyTexture = True + self._text = text + self.notify() + + @property + def font(self): + """Font to use to raster text (Font)""" + return self._font + + @font.setter + def font(self, font): + self._font.removeListener(self._fontChanged) + self._font = font + self._font.addListener(self._fontChanged) + self._fontChanged(self) # Which calls notify and primitive as dirty + + def _fontChanged(self, source): + """Listen for font change""" + self._dirtyTexture = True + self.notify() + + foreground = event.notifyProperty( + '_foreground', doc="""RGBA color of the text: 4 float in [0, 1]""", + converter=rgba) + + background = event.notifyProperty( + '_background', + doc="RGBA background color of the text field: 4 float in [0, 1]", + converter=rgba) + + overlay = event.notifyProperty( + '_overlay', + doc="True to always display text on top of the scene (default: False)", + converter=bool) + + def _setAlign(self, align): + assert align in self._ALIGN + self._align = align + self._dirtyAlign = True + self.notify() + + align = property( + lambda self: self._align, + _setAlign, + doc="""Horizontal anchor position of the text field (str). + + Either 'left' (default), 'center' or 'right'.""") + + def _setVAlign(self, valign): + assert valign in self._VALIGN + self._valign = valign + self._dirtyAlign = True + self.notify() + + valign = property( + lambda self: self._valign, + _setVAlign, + doc="""Vertical anchor position of the text field (str). + + Either 'top', 'baseline' (default), 'center' or 'bottom'""") + + def _raster(self, devicePixelRatio): + """Raster current primitive to a bitmap + + :param float devicePixelRatio: + The ratio between device and device-independent pixels + :return: Corresponding image in grayscale and baseline offset from top + :rtype: (HxW numpy.ndarray of uint8, int) + """ + params = (self.text, + self.font.name, + self.font.size, + self.font.weight, + self.font.italic, + devicePixelRatio) + + if params not in self._rasterTextCache: # Add to cache + self._rasterTextCache[params] = _font.rasterText(*params) + + array, offset = self._rasterTextCache[params] + return array.copy(), offset + + def _bounds(self, dataBounds=False): + return None + + def prepareGL2(self, context): + # Check if devicePixelRatio has changed since last rendering + devicePixelRatio = context.glCtx.devicePixelRatio + if self._devicePixelRatio != devicePixelRatio: + self._devicePixelRatio = devicePixelRatio + self._dirtyTexture = True + + if self._dirtyTexture: + self._dirtyTexture = False + + if self._texture is not None: + self._texture.discard() + self._texture = None + self._baselineOffset = 0 + + if self.text: + image, self._baselineOffset = self._raster( + self._devicePixelRatio) + self._texture = _glutils.Texture( + gl.GL_R8, image, gl.GL_RED, + minFilter=gl.GL_NEAREST, + magFilter=gl.GL_NEAREST, + wrap=gl.GL_CLAMP_TO_EDGE) + self._dirtyAlign = True # To force update of offset + + if self._dirtyAlign: + self._dirtyAlign = False + + if self._texture is not None: + height, width = self._texture.shape + + if self._align == 'left': + ox = 0. + elif self._align == 'center': + ox = - width // 2 + elif self._align == 'right': + ox = - width + else: + _logger.error("Unsupported align: %s", self._align) + ox = 0. + + if self._valign == 'top': + oy = 0. + elif self._valign == 'baseline': + oy = self._baselineOffset + elif self._valign == 'center': + oy = height // 2 + elif self._valign == 'bottom': + oy = height + else: + _logger.error("Unsupported valign: %s", self._valign) + oy = 0. + + offsets = (ox, oy) + numpy.array( + ((0., 0.), (width, 0.), (0., -height), (width, -height)), + dtype=numpy.float32) + self.setAttribute('offsetInViewportCoords', offsets) + + super(Text2D, self).prepareGL2(context) + + def renderGL2(self, context): + if not self.text: + return # Nothing to render + + program = context.glCtx.prog(*self._shaders) + program.use() + + program.setUniformMatrix('matrix', context.objectToNDC.matrix) + gl.glUniform2f( + program.uniforms['viewportSize'], *context.viewport.size) + gl.glUniform4f(program.uniforms['foreground'], *self.foreground) + gl.glUniform4f(program.uniforms['background'], *self.background) + gl.glUniform1i(program.uniforms['texture'], self._texture.texUnit) + gl.glUniform1i(program.uniforms['isOverlay'], + 1 if self._overlay else 0) + + self._texture.bind() + + if not self._overlay or not gl.glGetBoolean(gl.GL_DEPTH_TEST): + self._draw(program) + else: # overlay and depth test currently enabled + gl.glDisable(gl.GL_DEPTH_TEST) + self._draw(program) + gl.glEnable(gl.GL_DEPTH_TEST) + + # TODO texture atlas + viewportSize as attribute to chain text rendering + + _shaders = ( + """ + attribute vec3 position; + attribute vec2 offsetInViewportCoords; /* Offset in pixels (y upward) */ + attribute float vertexID; /* Index of rectangle corner */ + + uniform mat4 matrix; + uniform vec2 viewportSize; /* Width, height of the viewport */ + uniform int isOverlay; + + varying vec2 texCoords; + + void main(void) + { + vec4 clipPos = matrix * vec4(position, 1.0); + vec4 ndcPos = clipPos / clipPos.w; /* Perspective divide */ + + /* Align ndcPos with pixels in viewport-like coords (origin useless) */ + vec2 viewportPos = floor((ndcPos.xy + vec2(1.0, 1.0)) * 0.5 * viewportSize); + + /* Apply offset in viewport coords */ + viewportPos += offsetInViewportCoords; + + /* Convert back to NDC */ + vec2 pointPos = 2.0 * viewportPos / viewportSize - vec2(1.0, 1.0); + float z = (isOverlay != 0) ? -1.0 : ndcPos.z; + gl_Position = vec4(pointPos, z, 1.0); + + /* Index : texCoords: + * 0: (0., 0.) + * 1: (1., 0.) + * 2: (0., 1.) + * 3: (1., 1.) + */ + texCoords = vec2(vertexID == 0.0 || vertexID == 2.0 ? 0.0 : 1.0, + vertexID < 1.5 ? 0.0 : 1.0); + } + """, # noqa + + """ + varying vec2 texCoords; + + uniform vec4 foreground; + uniform vec4 background; + uniform sampler2D texture; + + void main(void) + { + float value = texture2D(texture, texCoords).r; + + if (background.a != 0.0) { + gl_FragColor = mix(background, foreground, value); + } else { + gl_FragColor = foreground; + gl_FragColor.a *= value; + if (gl_FragColor.a <= 0.01) { + discard; + } + } + } + """) + + +class LabelledAxes(primitives.GroupBBox): + """A group displaying a bounding box with axes labels around its children. + """ + + def __init__(self): + super(LabelledAxes, self).__init__() + self._ticksForBounds = None + + self._font = Font() + + # TODO offset labels from anchor in pixels + + self._xlabel = Text2D(font=self._font) + self._xlabel.align = 'center' + self._xlabel.transforms = [self._boxTransforms, + transform.Translate(tx=0.5)] + self._children.append(self._xlabel) + + self._ylabel = Text2D(font=self._font) + self._ylabel.align = 'center' + self._ylabel.transforms = [self._boxTransforms, + transform.Translate(ty=0.5)] + self._children.append(self._ylabel) + + self._zlabel = Text2D(font=self._font) + self._zlabel.align = 'center' + self._zlabel.transforms = [self._boxTransforms, + transform.Translate(tz=0.5)] + self._children.append(self._zlabel) + + self._tickLines = primitives.Lines( # Init tick lines with dummy pos + positions=((0., 0., 0.), (0., 0., 0.)), + mode='lines') + self._tickLines.visible = False + self._children.append(self._tickLines) + + self._tickLabels = core.Group() + self._children.append(self._tickLabels) + + @property + def font(self): + """Font of axes text labels (Font)""" + return self._font + + @font.setter + def font(self, font): + self._font = font + self._xlabel.font = font + self._ylabel.font = font + self._zlabel.font = font + for label in self._tickLabels.children: + label.font = font + + @property + def xlabel(self): + """Text label of the X axis (str)""" + return self._xlabel.text + + @xlabel.setter + def xlabel(self, text): + self._xlabel.text = text + + @property + def ylabel(self): + """Text label of the Y axis (str)""" + return self._ylabel.text + + @ylabel.setter + def ylabel(self, text): + self._ylabel.text = text + + @property + def zlabel(self): + """Text label of the Z axis (str)""" + return self._zlabel.text + + @zlabel.setter + def zlabel(self, text): + self._zlabel.text = text + + def _updateTicks(self): + """Check if ticks need update and update them if needed.""" + bounds = self._group.bounds(transformed=False, dataBounds=True) + if bounds is None: # No content + if self._ticksForBounds is not None: + self._ticksForBounds = None + self._tickLines.visible = False + self._tickLabels.children = [] # Reset previous labels + + elif (self._ticksForBounds is None or + not numpy.all(numpy.equal(bounds, self._ticksForBounds))): + self._ticksForBounds = bounds + + # Update ticks + # TODO make ticks having a constant length on the screen + ticklength = numpy.abs(bounds[1] - bounds[0]) / 20. + + xticks, xlabels = ticklayout.ticks(*bounds[:, 0]) + yticks, ylabels = ticklayout.ticks(*bounds[:, 1]) + zticks, zlabels = ticklayout.ticks(*bounds[:, 2]) + + # Update tick lines + coords = numpy.empty( + ((len(xticks) + len(yticks) + len(zticks)), 4, 3), + dtype=numpy.float32) + coords[:, :, :] = bounds[0, :] # account for offset from origin + + xcoords = coords[:len(xticks)] + xcoords[:, :, 0] = numpy.asarray(xticks)[:, numpy.newaxis] + xcoords[:, 1, 1] += ticklength[1] # X ticks on XY plane + xcoords[:, 3, 2] += ticklength[2] # X ticks on XZ plane + + ycoords = coords[len(xticks):len(xticks) + len(yticks)] + ycoords[:, :, 1] = numpy.asarray(yticks)[:, numpy.newaxis] + ycoords[:, 1, 0] += ticklength[0] # Y ticks on XY plane + ycoords[:, 3, 2] += ticklength[2] # Y ticks on YZ plane + + zcoords = coords[len(xticks) + len(yticks):] + zcoords[:, :, 2] = numpy.asarray(zticks)[:, numpy.newaxis] + zcoords[:, 1, 0] += ticklength[0] # Z ticks on XZ plane + zcoords[:, 3, 1] += ticklength[1] # Z ticks on YZ plane + + self._tickLines.setAttribute('position', coords.reshape(-1, 3)) + self._tickLines.visible = True + + # Update labels + offsets = bounds[0] - ticklength + labels = [] + for tick, label in zip(xticks, xlabels): + text = Text2D(text=label, font=self.font) + text.align = 'center' + text.transforms = [transform.Translate( + tx=tick, ty=offsets[1], tz=offsets[2])] + labels.append(text) + + for tick, label in zip(yticks, ylabels): + text = Text2D(text=label, font=self.font) + text.align = 'center' + text.transforms = [transform.Translate( + tx=offsets[0], ty=tick, tz=offsets[2])] + labels.append(text) + + for tick, label in zip(zticks, zlabels): + text = Text2D(text=label, font=self.font) + text.align = 'center' + text.transforms = [transform.Translate( + tx=offsets[0], ty=offsets[1], tz=tick)] + labels.append(text) + + self._tickLabels.children = labels # Reset previous labels + + def prepareGL2(self, context): + self._updateTicks() + super(LabelledAxes, self).prepareGL2(context) diff --git a/silx/gui/plot3d/scene/transform.py b/silx/gui/plot3d/scene/transform.py new file mode 100644 index 0000000..71a6b74 --- /dev/null +++ b/silx/gui/plot3d/scene/transform.py @@ -0,0 +1,968 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides 4x4 matrix operation and classes to handle them.""" + +from __future__ import absolute_import, division, unicode_literals + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "25/07/2016" + + +import itertools +import numpy + +from . import event + + +# Functions ################################################################### + +# Projections + +def mat4LookAtDir(position, direction, up): + """Creates matrix to look in direction from position. + + :param position: Array-like 3 coordinates of the point of view position. + :param direction: Array-like 3 coordinates of the sight direction vector. + :param up: Array-like 3 coordinates of the upward direction + in the image plane. + :returns: Corresponding matrix. + :rtype: numpy.ndarray of shape (4, 4) + """ + assert len(position) == 3 + assert len(direction) == 3 + assert len(up) == 3 + + direction = numpy.array(direction, copy=True, dtype=numpy.float32) + dirnorm = numpy.linalg.norm(direction) + assert dirnorm != 0. + direction /= dirnorm + + side = numpy.cross(direction, + numpy.array(up, copy=False, dtype=numpy.float32)) + sidenorm = numpy.linalg.norm(side) + assert sidenorm != 0. + up = numpy.cross(side / sidenorm, direction) + upnorm = numpy.linalg.norm(up) + assert upnorm != 0. + up /= upnorm + + matrix = numpy.identity(4, dtype=numpy.float32) + matrix[0, :3] = side + matrix[1, :3] = up + matrix[2, :3] = -direction + return numpy.dot(matrix, + mat4Translate(-position[0], -position[1], -position[2])) + + +def mat4LookAt(position, center, up): + """Creates matrix to look at center from position. + + See gluLookAt. + + :param position: Array-like 3 coordinates of the point of view position. + :param center: Array-like 3 coordinates of the center of the scene. + :param up: Array-like 3 coordinates of the upward direction + in the image plane. + :returns: Corresponding matrix. + :rtype: numpy.ndarray of shape (4, 4) + """ + position = numpy.array(position, copy=False, dtype=numpy.float32) + center = numpy.array(center, copy=False, dtype=numpy.float32) + direction = center - position + return mat4LookAtDir(position, direction, up) + + +def mat4Frustum(left, right, bottom, top, near, far): + """Creates a frustum projection matrix. + + See glFrustum. + """ + return numpy.array(( + (2.*near / (right-left), 0., (right+left) / (right-left), 0.), + (0., 2.*near / (top-bottom), (top+bottom) / (top-bottom), 0.), + (0., 0., -(far+near) / (far-near), -2.*far*near / (far-near)), + (0., 0., -1., 0.)), dtype=numpy.float32) + + +def mat4Perspective(fovy, width, height, near, far): + """Creates a perspective projection matrix. + + Similar to gluPerspective. + + :param float fovy: Field of view angle in degrees in the y direction. + :param float width: Width of the viewport. + :param float height: Height of the viewport. + :param float near: Distance to the near plane (strictly positive). + :param float far: Distance to the far plane (strictly positive). + :return: Corresponding matrix. + :rtype: numpy.ndarray of shape (4, 4) + """ + assert fovy != 0 + assert height != 0 + assert width != 0 + assert near > 0. + assert far > near + aspectratio = width / height + f = 1. / numpy.tan(numpy.radians(fovy) / 2.) + return numpy.array(( + (f / aspectratio, 0., 0., 0.), + (0., f, 0., 0.), + (0., 0., (far + near) / (near - far), 2. * far * near / (near - far)), + (0., 0., -1., 0.)), dtype=numpy.float32) + + +def mat4Orthographic(left, right, bottom, top, near, far): + """Creates an orthographic (i.e., parallel) projection matrix. + + See glOrtho. + """ + return numpy.array(( + (2. / (right - left), 0., 0., - (right + left) / (right - left)), + (0., 2. / (top - bottom), 0., - (top + bottom) / (top - bottom)), + (0., 0., -2. / (far - near), - (far + near) / (far - near)), + (0., 0., 0., 1.)), dtype=numpy.float32) + + +# Affine + +def mat4Translate(tx, ty, tz): + """4x4 translation matrix.""" + return numpy.array(( + (1., 0., 0., tx), + (0., 1., 0., ty), + (0., 0., 1., tz), + (0., 0., 0., 1.)), dtype=numpy.float32) + + +def mat4Scale(sx, sy, sz): + """4x4 scale matrix.""" + return numpy.array(( + (sx, 0., 0., 0.), + (0., sy, 0., 0.), + (0., 0., sz, 0.), + (0., 0., 0., 1.)), dtype=numpy.float32) + + +def mat4RotateFromAngleAxis(angle, x=0., y=0., z=1.): + """4x4 rotation matrix from angle and axis. + + :param float angle: The rotation angle in radians. + :param float x: The rotation vector x coordinate. + :param float y: The rotation vector y coordinate. + :param float z: The rotation vector z coordinate. + """ + ca = numpy.cos(angle) + sa = numpy.sin(angle) + return numpy.array(( + ((1.-ca) * x*x + ca, (1.-ca) * x*y - sa*z, (1.-ca) * x*z + sa*y, 0.), + ((1.-ca) * x*y + sa*z, (1.-ca) * y*y + ca, (1.-ca) * y*z - sa*x, 0.), + ((1.-ca) * x*z - sa*y, (1.-ca) * y*z + sa*x, (1.-ca) * z*z + ca, 0.), + (0., 0., 0., 1.)), dtype=numpy.float32) + + +def mat4RotateFromQuaternion(quaternion): + """4x4 rotation matrix from quaternion. + + :param quaternion: Array-like unit quaternion stored as (x, y, z, w) + """ + quaternion = numpy.array(quaternion, copy=True) + quaternion /= numpy.linalg.norm(quaternion) + + qx, qy, qz, qw = quaternion + return numpy.array(( + (1. - 2.*(qy**2 + qz**2), 2.*(qx*qy - qw*qz), 2.*(qx*qz + qw*qy), 0.), + (2.*(qx*qy + qw*qz), 1. - 2.*(qx**2 + qz**2), 2.*(qy*qz - qw*qx), 0.), + (2.*(qx*qz - qw*qy), 2.*(qy*qz + qw*qx), 1. - 2.*(qx**2 + qy**2), 0.), + (0., 0., 0., 1.)), dtype=numpy.float32) + + +def mat4Shear(axis, sx=0., sy=0., sz=0.): + """4x4 shear matrix: Skew two axes relative to a third fixed one. + + shearFactor = tan(shearAngle) + + :param str axis: The axis to keep constant and shear against. + In 'x', 'y', 'z'. + :param float sx: The shear factor for the X axis relative to axis. + :param float sy: The shear factor for the Y axis relative to axis. + :param float sz: The shear factor for the Z axis relative to axis. + """ + assert axis in ('x', 'y', 'z') + + matrix = numpy.identity(4, dtype=numpy.float32) + + # Make the shear column + index = 'xyz'.find(axis) + shearcolumn = numpy.array((sx, sy, sz, 0.), dtype=numpy.float32) + shearcolumn[index] = 1. + matrix[:, index] = shearcolumn + return matrix + + +# Transforms ################################################################## + +class Transform(event.Notifier): + + def __init__(self, static=False): + """Base class for (row-major) 4x4 matrix transforms. + + :param bool static: False (default) to reset cache when changed, + True for static matrices. + """ + super(Transform, self).__init__() + self._matrix = None + self._inverse = None + if not static: + self.addListener(self._changed) # Listening self for changes + + def __repr__(self): + return '%s(%s)' % (self.__class__.__init__, + repr(self.getMatrix(copy=False))) + + def inverse(self): + """Return the Transform of the inverse. + + The returned Transform is static, it is not updated when this + Transform is modified. + + :return: A Transform which is the inverse of this Transform. + """ + return Inverse(self) + + # Matrix + + def _makeMatrix(self): + """Override to build matrix""" + return numpy.identity(4, dtype=numpy.float32) + + def _makeInverse(self): + """Override to build inverse matrix.""" + return numpy.linalg.inv(self.getMatrix(copy=False)) + + def getMatrix(self, copy=True): + """The 4x4 matrix of this transform. + + :param bool copy: True (the default) to get a copy of the matrix, + False to get the internal matrix, do not modify! + :return: 4x4 matrix of this transform. + """ + if self._matrix is None: + self._matrix = self._makeMatrix() + if copy: + return self._matrix.copy() + else: + return self._matrix + + matrix = property(getMatrix, doc="The 4x4 matrix of this transform.") + + def getInverseMatrix(self, copy=False): + """The 4x4 matrix of the inverse of this transform. + + :param bool copy: True (the default) to get a copy of the matrix, + False to get the internal matrix, do not modify! + :return: 4x4 matrix of the inverse of this transform. + """ + if self._inverse is None: + self._inverse = self._makeInverse() + if copy: + return self._inverse.copy() + else: + return self._inverse + + inverseMatrix = property( + getInverseMatrix, + doc="The 4x4 matrix of the inverse of this transform.") + + # Listener + + def _changed(self, source): + """Default self listener reseting matrix cache.""" + self._matrix = None + self._inverse = None + + # Multiplication with vectors + + @staticmethod + def _prepareVector(vector, w): + """Add 4th coordinate (w) to vector if missing.""" + assert len(vector) in (3, 4) + vector = numpy.array(vector, copy=False, dtype=numpy.float32) + if len(vector) == 3: + vector = numpy.append(vector, w) + return vector + + def transformPoint(self, point, direct=True, perspectiveDivide=False): + """Apply the transform to a point. + + If len(point) == 3, apply persective divide if possible. + + :param point: Array-like vector of 3 or 4 coordinates. + :param bool direct: Whether to apply the direct (True, the default) + or inverse (False) transform. + :param bool perspectiveDivide: Whether to apply the perspective divide + (True) or not (False, the default). + :return: The transformed point. + :rtype: numpy.ndarray of same length as point. + """ + if direct: + matrix = self.getMatrix(copy=False) + else: + matrix = self.getInverseMatrix(copy=False) + result = numpy.dot(matrix, self._prepareVector(point, 1.)) + + if perspectiveDivide and result[3] != 0.: + result /= result[3] + + if len(point) == 3: + return result[:3] + else: + return result + + def transformDir(self, direction, direct=True): + """Apply the transform to a direction. + + :param direction: Array-like vector of 3 coordinates. + :param bool direct: Whether to apply the direct (True, the default) + or inverse (False) transform. + :return: The transformed direction. + :rtype: numpy.ndarray of length 3. + """ + if direct: + matrix = self.getMatrix(copy=False) + else: + matrix = self.getInverseMatrix(copy=False) + return numpy.dot(matrix[:3, :3], direction[:3]) + + def transformNormal(self, normal, direct=True): + """Apply the transform to a normal: R = (M-1)t * V. + + :param normal: Array-like vector of 3 coordinates. + :param bool direct: Whether to apply the direct (True, the default) + or inverse (False) transform. + :return: The transformed normal. + :rtype: numpy.ndarray of length 3. + """ + if direct: + matrix = self.getInverseMatrix(copy=False).T + else: + matrix = self.getMatrix(copy=False).T + return numpy.dot(matrix[:3, :3], normal[:3]) + + _CUBE_CORNERS = numpy.array(list(itertools.product((0., 1.), repeat=3)), + dtype=numpy.float32) + """Unit cube corners used by :meth:`transformRectangularBox`""" + + def transformBounds(self, bounds, direct=True): + """Apply the transform to an axes-aligned rectangular box. + + :param bounds: Min and max coords of the box for each axes. + :type bounds: 2x3 numpy.ndarray + :param bool direct: Whether to apply the direct (True, the default) + or inverse (False) transform. + :return: Axes-aligned rectangular box including the transformed box. + :rtype: 2x3 numpy.ndarray of float32 + """ + corners = numpy.ones((8, 4), dtype=numpy.float32) + corners[:, :3] = bounds[0] + \ + self._CUBE_CORNERS * (bounds[1] - bounds[0]) + + if direct: + matrix = self.getMatrix(copy=False) + else: + matrix = self.getInverseMatrix(copy=False) + + # Transform corners + cornerstransposed = numpy.dot(matrix, corners.T) + cornerstransposed = cornerstransposed / cornerstransposed[3] + + # Get min/max for each axis + transformedbounds = numpy.empty((2, 3), dtype=numpy.float32) + transformedbounds[0] = cornerstransposed.T[:, :3].min(axis=0) + transformedbounds[1] = cornerstransposed.T[:, :3].max(axis=0) + + return transformedbounds + + +class Inverse(Transform): + """Transform which is the inverse of another one. + + Static: It never gets updated. + """ + + def __init__(self, transform): + """Initializer. + + :param Transform transform: The transform to invert. + """ + + super(Inverse, self).__init__(static=True) + self._matrix = transform.getInverseMatrix(copy=True) + self._inverse = transform.getMatrix(copy=True) + + +class TransformList(Transform, event.HookList): + """List of transforms.""" + + def __init__(self, iterable=()): + Transform.__init__(self) + event.HookList.__init__(self, iterable) + + def _listWillChangeHook(self, methodName, *args, **kwargs): + for item in self: + item.removeListener(self._transformChanged) + + def _listWasChangedHook(self, methodName, *args, **kwargs): + for item in self: + item.addListener(self._transformChanged) + self.notify() + + def _transformChanged(self, source): + """Listen to transform changes of the list and its items.""" + if source is not self: # Avoid infinite recursion + self.notify() + + def _makeMatrix(self): + matrix = numpy.identity(4, dtype=numpy.float32) + for transform in self: + matrix = numpy.dot(matrix, transform.getMatrix(copy=False)) + return matrix + + +class StaticTransformList(Transform): + """Transform that is a snapshot of a list of Transforms + + It does not keep reference to the list of Transforms. + + :param iterable: Iterable of Transform used for initialization + """ + + def __init__(self, iterable=()): + super(StaticTransformList, self).__init__(static=True) + matrix = numpy.identity(4, dtype=numpy.float32) + for transform in iterable: + matrix = numpy.dot(matrix, transform.getMatrix(copy=False)) + self._matrix = matrix # Init matrix once + + +# Affine ###################################################################### + +class Matrix(Transform): + + def __init__(self, matrix=None): + """4x4 Matrix. + + :param matrix: 4x4 array-like matrix or None for identity matrix. + """ + super(Matrix, self).__init__(static=True) + self.setMatrix(matrix) + + def setMatrix(self, matrix=None): + """Update the 4x4 Matrix. + + :param matrix: 4x4 array-like matrix or None for identity matrix. + """ + if matrix is None: + self._matrix = numpy.identity(4, dtype=numpy.float32) + else: + matrix = numpy.array(matrix, copy=True, dtype=numpy.float32) + assert matrix.shape == (4, 4) + self._matrix = matrix + # Reset cached inverse as Transform is declared static + self._inverse = None + self.notify() + + # Redefined here to add a setter + matrix = property(Transform.getMatrix, setMatrix, + doc="The 4x4 matrix of this transform.") + + +class Translate(Transform): + """4x4 translation matrix.""" + + def __init__(self, tx=0., ty=0., tz=0.): + super(Translate, self).__init__() + self._tx, self._ty, self._tz = 0., 0., 0. + self.setTranslate(tx, ty, tz) + + def _makeMatrix(self): + return mat4Translate(self.tx, self.ty, self.tz) + + def _makeInverse(self): + return mat4Translate(-self.tx, -self.ty, -self.tz) + + @property + def tx(self): + return self._tx + + @tx.setter + def tx(self, tx): + self.setTranslate(tx=tx) + + @property + def ty(self): + return self._ty + + @ty.setter + def ty(self, ty): + self.setTranslate(ty=ty) + + @property + def tz(self): + return self._tz + + @tz.setter + def tz(self, tz): + self.setTranslate(tz=tz) + + @property + def translation(self): + return numpy.array((self.tx, self.ty, self.tz), dtype=numpy.float32) + + @translation.setter + def translation(self, translations): + tx, ty, tz = translations + self.setTranslate(tx, ty, tz) + + def setTranslate(self, tx=None, ty=None, tz=None): + if tx is not None: + self._tx = tx + if ty is not None: + self._ty = ty + if tz is not None: + self._tz = tz + self.notify() + + +class Scale(Transform): + """4x4 scale matrix.""" + + def __init__(self, sx=1., sy=1., sz=1.): + super(Scale, self).__init__() + self._sx, self._sy, self._sz = 0., 0., 0. + self.setScale(sx, sy, sz) + + def _makeMatrix(self): + return mat4Scale(self.sx, self.sy, self.sz) + + def _makeInverse(self): + return mat4Scale(1. / self.sx, 1. / self.sy, 1. / self.sz) + + @property + def sx(self): + return self._sx + + @sx.setter + def sx(self, sx): + self.setScale(sx=sx) + + @property + def sy(self): + return self._sy + + @sy.setter + def sy(self, sy): + self.setScale(sy=sy) + + @property + def sz(self): + return self._sz + + @sz.setter + def sz(self, sz): + self.setScale(sz=sz) + + @property + def scale(self): + return numpy.array((self._sx, self._sy, self._sz), dtype=numpy.float32) + + @scale.setter + def scale(self, scales): + sx, sy, sz = scales + self.setScale(sx, sy, sz) + + def setScale(self, sx=None, sy=None, sz=None): + if sx is not None: + assert sx != 0. + self._sx = sx + if sy is not None: + assert sy != 0. + self._sy = sy + if sz is not None: + assert sz != 0. + self._sz = sz + self.notify() + + +class Rotate(Transform): + + def __init__(self, angle=0., ax=0., ay=0., az=1.): + """4x4 rotation matrix. + + :param float angle: The rotation angle in degrees. + :param float ax: The x coordinate of the rotation axis. + :param float ay: The y coordinate of the rotation axis. + :param float az: The z coordinate of the rotation axis. + """ + super(Rotate, self).__init__() + self._angle = 0. + self._axis = None + self.setAngleAxis(angle, (ax, ay, az)) + + @property + def angle(self): + """The rotation angle in degrees.""" + return self._angle + + @angle.setter + def angle(self, angle): + self.setAngleAxis(angle=angle) + + @property + def axis(self): + """The normalized rotation axis as a numpy.ndarray.""" + return self._axis.copy() + + @axis.setter + def axis(self, axis): + self.setAngleAxis(axis=axis) + + def setAngleAxis(self, angle=None, axis=None): + """Update the angle and/or axis of the rotation. + + :param float angle: The rotation angle in degrees. + :param axis: Array-like axis vector (3 coordinates). + """ + if angle is not None: + self._angle = angle + if axis is not None: + assert len(axis) == 3 + axis = numpy.array(axis, copy=True, dtype=numpy.float32) + assert axis.size == 3 + norm = numpy.linalg.norm(axis) + if norm == 0.: # No axis, set rotation angle to 0. + self._angle = 0. + self._axis = numpy.array((0., 0., 1.), dtype=numpy.float32) + else: + self._axis = axis / norm + + if angle is not None or axis is not None: + self.notify() + + @property + def quaternion(self): + """Rotation unit quaternion as (x, y, z, w). + + Where: ||(x, y, z)|| = sin(angle/2), w = cos(angle/2). + """ + if numpy.linalg.norm(self._axis) == 0.: + return numpy.array((0., 0., 0., 1.), dtype=numpy.float32) + + else: + quaternion = numpy.empty((4,), dtype=numpy.float32) + halfangle = 0.5 * numpy.radians(self.angle) + quaternion[0:3] = numpy.sin(halfangle) * self._axis + quaternion[3] = numpy.cos(halfangle) + return quaternion + + @quaternion.setter + def quaternion(self, quaternion): + assert len(quaternion) == 4 + + # Normalize quaternion + quaternion = numpy.array(quaternion, copy=True) + quaternion /= numpy.linalg.norm(quaternion) + + # Get angle + sinhalfangle = numpy.linalg.norm(quaternion[0:3]) + coshalfangle = quaternion[3] + angle = 2. * numpy.arctan2(sinhalfangle, coshalfangle) + + # Axis will be normalized in setAngleAxis + self.setAngleAxis(numpy.degrees(angle), quaternion[0:3]) + + def _makeMatrix(self): + angle = numpy.radians(self.angle, dtype=numpy.float32) + return mat4RotateFromAngleAxis(angle, *self.axis) + + def _makeInverse(self): + return numpy.array(self.getMatrix(copy=False).transpose(), + copy=True, order='C', + dtype=numpy.float32) + + +class Shear(Transform): + + def __init__(self, axis, sx=0., sy=0., sz=0.): + """4x4 shear/skew matrix of 2 axes relative to the third one. + + :param str axis: The axis to keep fixed, in 'x', 'y', 'z' + :param float sx: The shear factor for the x axis. + :param float sy: The shear factor for the y axis. + :param float sz: The shear factor for the z axis. + """ + assert axis in ('x', 'y', 'z') + super(Shear, self).__init__() + self._axis = axis + self._factors = sx, sy, sz + + @property + def axis(self): + """The axis against which other axes are skewed.""" + return self._axis + + @property + def factors(self): + """The shear factors: shearFactor = tan(shearAngle)""" + return self._factors + + def _makeMatrix(self): + return mat4Shear(self.axis, *self.factors) + + def _makeInverse(self): + sx, sy, sz = self.factors + return mat4Shear(self.axis, -sx, -sy, -sz) + + +# Projection ################################################################## + +class _Projection(Transform): + """Base class for projection matrix. + + Handles near and far clipping plane values. + Subclasses must implement :meth:`_makeMatrix`. + + :param float near: Distance to the near plane. + :param float far: Distance to the far plane. + :param bool checkDepthExtent: Toggle checks near > 0 and far > near. + :param size: Viewport's size used to compute the aspect ratio. + :type size: 2-tuple of float (width, height). + """ + + def __init__(self, near, far, checkDepthExtent=False, size=(1., 1.)): + super(_Projection, self).__init__() + self._checkDepthExtent = checkDepthExtent + self._depthExtent = 1, 10 + self.setDepthExtent(near, far) # set _depthExtent + self._size = 1., 1. + self.size = size # set _size + + def setDepthExtent(self, near=None, far=None): + """Set the extent of the visible area along the viewing direction. + + :param float near: The near clipping plane Z coord. + :param float far: The far clipping plane Z coord. + """ + near = float(near) if near is not None else self._depthExtent[0] + far = float(far) if far is not None else self._depthExtent[1] + + if self._checkDepthExtent: + assert near > 0. + assert far > near + + self._depthExtent = near, far + self.notify() + + @property + def near(self): + """Distance to the near plane.""" + return self._depthExtent[0] + + @near.setter + def near(self, near): + if near != self.near: + self.setDepthExtent(near=near) + + @property + def far(self): + """Distance to the far plane.""" + return self._depthExtent[1] + + @far.setter + def far(self, far): + if far != self.far: + self.setDepthExtent(far=far) + + @property + def size(self): + """Viewport size as a 2-tuple of float (width, height).""" + return self._size + + @size.setter + def size(self, size): + assert len(size) == 2 + self._size = tuple(size) + self.notify() + + +class Orthographic(_Projection): + """Orthographic (i.e., parallel) projection which keeps aspect ratio. + + Clipping planes are adjusted to match the aspect ratio of + the :attr:`size` attribute. + + The left, right, bottom and top parameters defines the area which must + always remain visible. + Effective clipping planes are adjusted to keep the aspect ratio. + + :param float left: Coord of the left clipping plane. + :param float right: Coord of the right clipping plane. + :param float bottom: Coord of the bottom clipping plane. + :param float top: Coord of the top clipping plane. + :param float near: Distance to the near plane. + :param float far: Distance to the far plane. + :param size: Viewport's size used to compute the aspect ratio. + :type size: 2-tuple of float (width, height). + """ + + def __init__(self, left=0., right=1., bottom=1., top=0., near=-1., far=1., + size=(1., 1.)): + self._left, self._right = left, right + self._bottom, self._top = bottom, top + super(Orthographic, self).__init__(near, far, checkDepthExtent=False, + size=size) + # _update called when setting size + + def _makeMatrix(self): + return mat4Orthographic( + self.left, self.right, self.bottom, self.top, self.near, self.far) + + def _update(self, left, right, bottom, top): + width, height = self.size + aspect = width / height + + orthoaspect = abs(left - right) / abs(bottom - top) + + if orthoaspect >= aspect: # Keep width, enlarge height + newheight = \ + numpy.sign(top - bottom) * abs(left - right) / aspect + bottom = 0.5 * (bottom + top) - 0.5 * newheight + top = bottom + newheight + + else: # Keep height, enlarge width + newwidth = \ + numpy.sign(right - left) * abs(bottom - top) * aspect + left = 0.5 * (left + right) - 0.5 * newwidth + right = left + newwidth + + # Store values + self._left, self._right = left, right + self._bottom, self._top = bottom, top + + def setClipping(self, left=None, right=None, bottom=None, top=None): + """Set the clipping planes of the projection. + + Parameters are adjusted to keep aspect ratio. + If a clipping plane coord is not provided, it uses its current value + + :param float left: Coord of the left clipping plane. + :param float right: Coord of the right clipping plane. + :param float bottom: Coord of the bottom clipping plane. + :param float top: Coord of the top clipping plane. + """ + left = float(left) if left is not None else self.left + right = float(right) if right is not None else self.right + bottom = float(bottom) if bottom is not None else self.bottom + top = float(top) if top is not None else self.top + + self._update(left, right, bottom, top) + self.notify() + + left = property(lambda self: self._left, + doc="Coord of the left clipping plane.") + + right = property(lambda self: self._right, + doc="Coord of the right clipping plane.") + + bottom = property(lambda self: self._bottom, + doc="Coord of the bottom clipping plane.") + + top = property(lambda self: self._top, + doc="Coord of the top clipping plane.") + + @property + def size(self): + """Viewport size as a 2-tuple of float (width, height) or None.""" + return self._size + + @size.setter + def size(self, size): + assert len(size) == 2 + self._size = float(size[0]), float(size[1]) + self._update(self.left, self.right, self.bottom, self.top) + self.notify() + + +class Ortho2DWidget(_Projection): + """Orthographic projection with pixel as unit. + + Provides same coordinates as widgets: + origin: top left, X axis goes left, Y axis goes down. + + :param float near: Z coordinate of the near clipping plane. + :param float far: Z coordinante of the far clipping plane. + :param size: Viewport's size used to compute the aspect ratio. + :type size: 2-tuple of float (width, height). + """ + + def __init__(self, near=-1., far=1., size=(1., 1.)): + + super(Ortho2DWidget, self).__init__(near, far, size) + + def _makeMatrix(self): + width, height = self.size + return mat4Orthographic(0., width, height, 0., self.near, self.far) + + +class Perspective(_Projection): + """Perspective projection matrix defined by FOV and aspect ratio. + + :param float fovy: Vertical field-of-view in degrees. + :param float near: The near clipping plane Z coord (stricly positive). + :param float far: The far clipping plane Z coord (> near). + :param size: Viewport's size used to compute the aspect ratio. + :type size: 2-tuple of float (width, height). + """ + + def __init__(self, fovy=90., near=0.1, far=1., size=(1., 1.)): + + super(Perspective, self).__init__(near, far, checkDepthExtent=True) + self._fovy = 90. + self.fovy = fovy # Set _fovy + self.size = size # Set _ size + + def _makeMatrix(self): + width, height = self.size + return mat4Perspective(self.fovy, width, height, self.near, self.far) + + @property + def fovy(self): + """Vertical field-of-view in degrees.""" + return self._fovy + + @fovy.setter + def fovy(self, fovy): + self._fovy = float(fovy) + self.notify() diff --git a/silx/gui/plot3d/scene/utils.py b/silx/gui/plot3d/scene/utils.py new file mode 100644 index 0000000..930a087 --- /dev/null +++ b/silx/gui/plot3d/scene/utils.py @@ -0,0 +1,516 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +""" +This module provides functions to generate indices, to check intersection +and to handle planes. +""" + +from __future__ import absolute_import, division, unicode_literals + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "25/07/2016" + + +import logging +import numpy + +from . import event + + +_logger = logging.getLogger(__name__) + + +# numpy ####################################################################### + +def _uniqueAlongLastAxis(a): + """Numpy unique on the last axis of a 2D array + + Implemented here as not in numpy as of writing. + + See adding axis parameter to numpy.unique: + https://github.com/numpy/numpy/pull/3584/files#r6225452 + + :param array_like a: Input array. + :return: Unique elements along the last axis. + :rtype: numpy.ndarray + """ + assert len(a.shape) == 2 + + # Construct a type over last array dimension to run unique on a 1D array + if a.dtype.char in numpy.typecodes['AllInteger']: + # Bit-wise comparison of the 2 indices of a line at once + # Expect a C contiguous array of shape N, 2 + uniquedt = numpy.dtype((numpy.void, a.itemsize * a.shape[-1])) + elif a.dtype.char in numpy.typecodes['Float']: + uniquedt = [('f{i}'.format(i=i), a.dtype) for i in range(a.shape[-1])] + else: + raise TypeError("Unsupported type {dtype}".format(dtype=a.dtype)) + + uniquearray = numpy.unique(numpy.ascontiguousarray(a).view(uniquedt)) + return uniquearray.view(a.dtype).reshape((-1, a.shape[-1])) + + +# conversions ################################################################# + +def triangleToLineIndices(triangleIndices, unicity=False): + """Generates lines indices from triangle indices. + + This is generating lines indices for the edges of the triangles. + + :param triangleIndices: The indices to draw a set of vertices as triangles. + :type triangleIndices: numpy.ndarray + :param bool unicity: If True remove duplicated lines, + else (the default) returns all lines. + :return: The indices to draw the edges of the triangles as lines. + :rtype: 1D numpy.ndarray of uint16 or uint32. + """ + # Makes sure indices ar packed by triangle + triangleIndices = triangleIndices.reshape(-1, 3) + + # Pack line indices by triangle and by edge + lineindices = numpy.empty((len(triangleIndices), 3, 2), + dtype=triangleIndices.dtype) + lineindices[:, 0] = triangleIndices[:, :2] # edge = t0, t1 + lineindices[:, 1] = triangleIndices[:, 1:] # edge =t1, t2 + lineindices[:, 2] = triangleIndices[:, ::2] # edge = t0, t2 + + if unicity: + lineindices = _uniqueAlongLastAxis(lineindices.reshape(-1, 2)) + + # Make sure it is 1D + lineindices.shape = -1 + + return lineindices + + +def verticesNormalsToLines(vertices, normals, scale=1.): + """Return vertices of lines representing normals at given positions. + + :param vertices: Positions of the points. + :type vertices: numpy.ndarray with shape: (nbPoints, 3) + :param normals: Corresponding normals at the points. + :type normals: numpy.ndarray with shape: (nbPoints, 3) + :param float scale: The scale factor to apply to normals. + :returns: Array of vertices to draw corresponding lines. + :rtype: numpy.ndarray with shape: (nbPoints * 2, 3) + """ + linevertices = numpy.empty((len(vertices) * 2, 3), dtype=vertices.dtype) + linevertices[0::2] = vertices + linevertices[1::2] = vertices + scale * normals + return linevertices + + +def unindexArrays(mode, indices, *arrays): + """Convert indexed GL primitives to unindexed ones. + + Given indices in arrays and the OpenGL primitive they represent, + return the unindexed equivalent. + + :param str mode: + Kind of primitive represented by indices. + In: points, lines, line_strip, loop, triangles, triangle_strip, fan. + :param indices: Indices in other arrays + :type indices: numpy.ndarray of dimension 1. + :param arrays: Remaining arguments are arrays to convert + :return: Converted arrays + :rtype: tuple of numpy.ndarray + """ + indices = numpy.array(indices, copy=False) + + assert mode in ('points', + 'lines', 'line_strip', 'loop', + 'triangles', 'triangle_strip', 'fan') + + if mode in ('lines', 'line_strip', 'loop'): + assert len(indices) >= 2 + elif mode in ('triangles', 'triangle_strip', 'fan'): + assert len(indices) >= 3 + + assert indices.min() >= 0 + max_index = indices.max() + for data in arrays: + assert len(data) >= max_index + + if mode == 'line_strip': + unpacked = numpy.empty((2 * (len(indices) - 1),), dtype=indices.dtype) + unpacked[0::2] = indices[:-1] + unpacked[1::2] = indices[1:] + indices = unpacked + + elif mode == 'loop': + unpacked = numpy.empty((2 * len(indices),), dtype=indices.dtype) + unpacked[0::2] = indices + unpacked[1:-1:2] = indices[1:] + unpacked[-1] = indices[0] + indices = unpacked + + elif mode == 'triangle_strip': + unpacked = numpy.empty((3 * (len(indices) - 2),), dtype=indices.dtype) + unpacked[0::3] = indices[:-2] + unpacked[1::3] = indices[1:-1] + unpacked[2::3] = indices[2:] + indices = unpacked + + elif mode == 'fan': + unpacked = numpy.empty((3 * (len(indices) - 2),), dtype=indices.dtype) + unpacked[0::3] = indices[0] + unpacked[1::3] = indices[1:-1] + unpacked[2::3] = indices[2:] + indices = unpacked + + return tuple(numpy.ascontiguousarray(data[indices]) for data in arrays) + + +def trianglesNormal(positions): + """Return normal for each triangle. + + :param positions: Serie of triangle's corners + :type positions: numpy.ndarray of shape (NbTriangles*3, 3) + :return: Normals corresponding to each position. + :rtype: numpy.ndarray of shape (NbTriangles, 3) + """ + assert positions.ndim == 2 + assert positions.shape[1] == 3 + + positions = numpy.array(positions, copy=False).reshape(-1, 3, 3) + + normals = numpy.cross(positions[:, 1] - positions[:, 0], + positions[:, 2] - positions[:, 0]) + + # Normalize normals + if numpy.version.version < '1.8.0': + # Debian 7 support: numpy.linalg.norm has no axis argument + norms = numpy.array(tuple(numpy.linalg.norm(vec) for vec in normals), + dtype=normals.dtype) + else: + norms = numpy.linalg.norm(normals, axis=1) + norms[norms == 0] = 1 + + return normals / norms.reshape(-1, 1) + + +# grid ######################################################################## + +def gridVertices(dim0Array, dim1Array, dtype): + """Generate an array of 2D positions from 2 arrays of 1D coordinates. + + :param dim0Array: 1D array-like of coordinates along the first dimension. + :param dim1Array: 1D array-like of coordinates along the second dimension. + :param numpy.dtype dtype: Data type of the output array. + :return: Array of grid coordinates. + :rtype: numpy.ndarray with shape: (len(dim0Array), len(dim1Array), 2) + """ + grid = numpy.empty((len(dim0Array), len(dim1Array), 2), dtype=dtype) + grid.T[0, :, :] = dim0Array + grid.T[1, :, :] = numpy.array(dim1Array, copy=False)[:, None] + return grid + + +def triangleStripGridIndices(dim0, dim1): + """Generate indices to draw a grid of vertices as a triangle strip. + + Vertices are expected to be stored as row-major (i.e., C contiguous). + + :param int dim0: The number of rows of vertices. + :param int dim1: The number of columns of vertices. + :return: The vertex indices + :rtype: 1D numpy.ndarray of uint32 + """ + assert dim0 >= 2 + assert dim1 >= 2 + + # Filling a row of squares + + # an index before and one after for degenerated triangles + indices = numpy.empty((dim0 - 1, 2 * (dim1 + 1)), dtype=numpy.uint32) + + # Init indices with minimum indices for each row of squares + indices[:] = (dim1 * numpy.arange(dim0 - 1, dtype=numpy.uint32))[:, None] + + # Update indices with offset per row of squares + offset = numpy.arange(dim1, dtype=numpy.uint32) + indices[:, 1:-1:2] += offset + offset += dim1 + indices[:, 2::2] += offset + indices[:, -1] += offset[-1] + + # Remove extra indices for degenerated triangles before returning + return indices.ravel()[1:-1] + + # Alternative: + # indices = numpy.zeros(2 * dim1 * (dim0 - 1) + 2 * (dim0 - 2), + # dtype=numpy.uint32) + # + # offset = numpy.arange(dim1, dtype=numpy.uint32) + # for d0Index in range(dim0 - 1): + # start = 2 * d0Index * (dim1 + 1) + # end = start + 2 * dim1 + # if d0Index != 0: + # indices[start - 2] = offset[-1] + # indices[start - 1] = offset[0] + # indices[start:end:2] = offset + # offset += dim1 + # indices[start + 1:end:2] = offset + # return indices + + +def linesGridIndices(dim0, dim1): + """Generate indices to draw a grid of vertices as lines. + + Vertices are expected to be stored as row-major (i.e., C contiguous). + + :param int dim0: The number of rows of vertices. + :param int dim1: The number of columns of vertices. + :return: The vertex indices. + :rtype: 1D numpy.ndarray of uint32 + """ + # Horizontal and vertical lines + nbsegmentalongdim1 = 2 * (dim1 - 1) + nbsegmentalongdim0 = 2 * (dim0 - 1) + + indices = numpy.empty(nbsegmentalongdim1 * dim0 + + nbsegmentalongdim0 * dim1, + dtype=numpy.uint32) + + # Line indices over dim0 + onedim1line = (numpy.arange(nbsegmentalongdim1, + dtype=numpy.uint32) + 1) // 2 + indices[:dim0 * nbsegmentalongdim1] = \ + (dim1 * numpy.arange(dim0, dtype=numpy.uint32)[:, None] + + onedim1line[None, :]).ravel() + + # Line indices over dim1 + onedim0line = (numpy.arange(nbsegmentalongdim0, + dtype=numpy.uint32) + 1) // 2 + indices[dim0 * nbsegmentalongdim1:] = \ + (numpy.arange(dim1, dtype=numpy.uint32)[:, None] + + dim1 * onedim0line[None, :]).ravel() + + return indices + + +# intersection ################################################################ + +def angleBetweenVectors(refVector, vectors, norm=None): + """Return the angle between 2 vectors. + + :param refVector: Coordinates of the reference vector. + :type refVector: numpy.ndarray of shape: (NCoords,) + :param vectors: Coordinates of the vector(s) to get angle from reference. + :type vectors: numpy.ndarray of shape: (NCoords,) or (NbVector, NCoords) + :param norm: A direction vector giving an orientation to the angles + or None. + :returns: The angles in radians in [0, pi] if norm is None + else in [0, 2pi]. + :rtype: float or numpy.ndarray of shape (NbVectors,) + """ + singlevector = len(vectors.shape) == 1 + if singlevector: # Make it a 2D array for the computation + vectors = vectors.reshape(1, -1) + + assert len(refVector.shape) == 1 + assert len(vectors.shape) == 2 + assert len(refVector) == vectors.shape[1] + + # Normalize vectors + refVector /= numpy.linalg.norm(refVector) + vectors = numpy.array([v / numpy.linalg.norm(v) for v in vectors]) + + dots = numpy.sum(refVector * vectors, axis=-1) + angles = numpy.arccos(numpy.clip(dots, -1., 1.)) + if norm is not None: + signs = numpy.sum(norm * numpy.cross(refVector, vectors), axis=-1) < 0. + angles[signs] = numpy.pi * 2. - angles[signs] + + return angles[0] if singlevector else angles + + +def segmentPlaneIntersect(s0, s1, planeNorm, planePt): + """Compute the intersection of a segment with a plane. + + :param s0: First end of the segment + :type s0: 1D numpy.ndarray-like of length 3 + :param s1: Second end of the segment + :type s1: 1D numpy.ndarray-like of length 3 + :param planeNorm: Normal vector of the plane. + :type planeNorm: numpy.ndarray of shape: (3,) + :param planePt: A point of the plane. + :type planePt: numpy.ndarray of shape: (3,) + :return: The intersection points. The number of points goes + from 0 (no intersection) to 2 (segment in the plane) + :rtype: list of numpy.ndarray + """ + s0, s1 = numpy.asarray(s0), numpy.asarray(s1) + + segdir = s1 - s0 + dotnormseg = numpy.dot(planeNorm, segdir) + if dotnormseg == 0: + # line and plane are parallels + if numpy.dot(planeNorm, planePt - s0) == 0: # segment is in plane + return [s0, s1] + else: # No intersection + return [] + + alpha = - numpy.dot(planeNorm, s0 - planePt) / dotnormseg + if 0. <= alpha <= 1.: # Intersection with segment + return [s0 + alpha * segdir] + else: # intersection outside segment + return [] + + +def boxPlaneIntersect(boxVertices, boxLineIndices, planeNorm, planePt): + """Return intersection points between a box and a plane. + + :param boxVertices: Position of the corners of the box. + :type boxVertices: numpy.ndarray with shape: (8, 3) + :param boxLineIndices: Indices of the box edges. + :type boxLineIndices: numpy.ndarray-like with shape: (12, 2) + :param planeNorm: Normal vector of the plane. + :type planeNorm: numpy.ndarray of shape: (3,) + :param planePt: A point of the plane. + :type planePt: numpy.ndarray of shape: (3,) + :return: The found intersection points + :rtype: numpy.ndarray with 2 dimensions + """ + segments = numpy.take(boxVertices, boxLineIndices, axis=0) + + points = set() # Gather unique intersection points + for seg in segments: + for point in segmentPlaneIntersect(seg[0], seg[1], planeNorm, planePt): + points.add(tuple(point)) + points = numpy.array(list(points)) + + if len(points) <= 2: + return numpy.array(()) + elif len(points) == 3: + return points + else: # len(points) > 3 + # Order point to have a polyline lying on the unit cube's faces + vectors = points - numpy.mean(points, axis=0) + angles = angleBetweenVectors(vectors[0], vectors, planeNorm) + points = numpy.take(points, numpy.argsort(angles), axis=0) + return points + + +# Plane ####################################################################### + +class Plane(event.Notifier): + """Object handling a plane and notifying plane changes. + + :param point: A point on the plane. + :type point: 3-tuple of float. + :param normal: Normal of the plane. + :type normal: 3-tuple of float. + """ + + def __init__(self, point=(0., 0., 0.), normal=(0., 0., 1.)): + super(Plane, self).__init__() + + assert len(point) == 3 + self._point = numpy.array(point, copy=True, dtype=numpy.float32) + assert len(normal) == 3 + self._normal = numpy.array(normal, copy=True, dtype=numpy.float32) + self.notify() + + def setPlane(self, point=None, normal=None): + """Set plane point and normal and notify. + + :param point: A point on the plane. + :type point: 3-tuple of float or None. + :param normal: Normal of the plane. + :type normal: 3-tuple of float or None. + """ + planechanged = False + + if point is not None: + assert len(point) == 3 + point = numpy.array(point, copy=True, dtype=numpy.float32) + if not numpy.all(numpy.equal(self._point, point)): + self._point = point + planechanged = True + + if normal is not None: + assert len(normal) == 3 + normal = numpy.array(normal, copy=True, dtype=numpy.float32) + + norm = numpy.linalg.norm(normal) + if norm != 0.: + normal /= norm + + if not numpy.all(numpy.equal(self._normal, normal)): + self._normal = normal + planechanged = True + + if planechanged: + _logger.debug('Plane updated:\n\tpoint: %s\n\tnormal: %s', + str(self._point), str(self._normal)) + self.notify() + + @property + def point(self): + """A point on the plane.""" + return self._point.copy() + + @point.setter + def point(self, point): + self.setPlane(point=point) + + @property + def normal(self): + """The (normalized) normal of the plane.""" + return self._normal.copy() + + @normal.setter + def normal(self, normal): + self.setPlane(normal=normal) + + @property + def parameters(self): + """Plane equation parameters: a*x + b*y + c*z + d = 0.""" + return numpy.append(self._normal, + - numpy.dot(self._point, self._normal)) + + @parameters.setter + def parameters(self, parameters): + assert len(parameters) == 4 + parameters = numpy.array(parameters, dtype=numpy.float32) + + # Normalize normal + norm = numpy.linalg.norm(parameters[:3]) + if norm != 0: + parameters /= norm + + normal = parameters[:3] + point = - parameters[3] * normal + self.setPlane(point, normal) + + @property + def isPlane(self): + """True if a plane is defined (i.e., ||normal|| != 0).""" + return numpy.any(self.normal != 0.) + + def move(self, step): + """Move the plane of step along the normal.""" + self.point += step * self.normal diff --git a/silx/gui/plot3d/scene/viewport.py b/silx/gui/plot3d/scene/viewport.py new file mode 100644 index 0000000..83cda43 --- /dev/null +++ b/silx/gui/plot3d/scene/viewport.py @@ -0,0 +1,492 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides a class to control a viewport on the rendering window. + +The :class:`Viewport` describes a Viewport rendering a scene. +The attribute :attr:`scene` is the root group of the scene tree. +:class:`RenderContext` handles the current state during rendering. +""" + +from __future__ import absolute_import, division, unicode_literals + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "25/07/2016" + + +import numpy + +from silx.gui.plot.Colors import rgba + +from ..._glutils import gl + +from . import camera +from . import event +from . import transform +from .function import DirectionalLight, ClippingPlane + + +class RenderContext(object): + """Handle a current rendering context. + + An instance of this class is passed to rendering method through + the scene during render. + + User should NEVER use an instance of this class beyond the method + it is passed to as an argument (i.e., do not keep a reference to it). + + :param Viewport viewport: The viewport doing the rendering. + :param Context glContext: The operating system OpenGL context in use. + """ + + def __init__(self, viewport, glContext): + self._viewport = viewport + self._glContext = glContext + self._transformStack = [viewport.camera.extrinsic] + self._clipPlane = ClippingPlane(normal=(0., 0., 0.)) + + @property + def viewport(self): + """Viewport doing the current rendering""" + return self._viewport + + @property + def glCtx(self): + """The OpenGL context in use""" + return self._glContext + + @property + def objectToCamera(self): + """The current transform from object to camera coords. + + Do not modify. + """ + return self._transformStack[-1] + + @property + def projection(self): + """Projection transform. + + Do not modify. + """ + return self.viewport.camera.intrinsic + + @property + def objectToNDC(self): + """The transform from object to NDC (this includes projection). + + Do not modify. + """ + return transform.StaticTransformList( + (self.projection, self.objectToCamera)) + + def pushTransform(self, transform_, multiply=True): + """Push a :class:`Transform` on the transform stack. + + :param Transform transform_: The transform to add to the stack. + :param bool multiply: + True (the default) to multiply with the top of the stack, + False to push the transform as is without multiplication. + """ + if multiply: + assert len(self._transformStack) >= 1 + transform_ = transform.StaticTransformList( + (self._transformStack[-1], transform_)) + + self._transformStack.append(transform_) + + def popTransform(self): + """Pop the transform on top of the stack. + + :return: The Transform that is popped from the stack. + """ + assert len(self._transformStack) > 1 + return self._transformStack.pop() + + @property + def clipper(self): + """The current clipping plane + """ + return self._clipPlane + + def setClipPlane(self, point=(0., 0., 0.), normal=(0., 0., 0.)): + """Set the clipping plane to use + + For now only handles a single clipping plane. + + :param point: A point of the plane + :type point: 3-tuple of float + :param normal: Normal vector of the plane or (0, 0, 0) for no clipping + :type normal: 3-tuple of float + """ + self._clipPlane = ClippingPlane(point, normal) + + +class Viewport(event.Notifier): + """Rendering a single scene through a camera in part of a framebuffer. + + :param int framebuffer: The framebuffer ID this viewport is rendering into + """ + + def __init__(self, framebuffer=0): + from . import Group # Here to avoid cyclic import + super(Viewport, self).__init__() + self._dirty = True + self._origin = 0, 0 + self._size = 1, 1 + self._framebuffer = int(framebuffer) + self.scene = Group() # The stuff to render, add overlaid scenes? + self.scene._setParent(self) + self.scene.addListener(self._changed) + self._background = 0., 0., 0., 1. + self._camera = camera.Camera(fovy=30., near=1., far=100., + position=(0., 0., 12.)) + self._camera.addListener(self._changed) + self._transforms = transform.TransformList([self._camera]) + + self._light = DirectionalLight(direction=(0., 0., -1.), + ambient=(0.3, 0.3, 0.3), + diffuse=(0.7, 0.7, 0.7)) + self._light.addListener(self._changed) + + @property + def transforms(self): + """Proxy of camera transforms. + + Do not modify the list. + """ + return self._transforms + + def _changed(self, *args, **kwargs): + """Callback handling scene updates""" + self._dirty = True + self.notify() + + @property + def dirty(self): + """True if scene is dirty and needs redisplay.""" + return self._dirty + + def resetDirty(self): + """Mark the scene as not being dirty. + + To call after rendering. + """ + self._dirty = False + + @property + def background(self): + """Background color of the viewport (4-tuple of float in [0, 1]""" + return self._background + + @background.setter + def background(self, color): + color = rgba(color) + if self._background != color: + self._background = color + self._changed() + + @property + def camera(self): + """The camera used to render the scene.""" + return self._camera + + @property + def light(self): + """The light used to render the scene.""" + return self._light + + @property + def origin(self): + """Origin (ox, oy) of the viewport in pixels""" + return self._origin + + @origin.setter + def origin(self, origin): + ox, oy = origin + origin = int(ox), int(oy) + if origin != self._origin: + self._origin = origin + self._changed() + + @property + def size(self): + """Size (width, height) of the viewport in pixels""" + return self._size + + @size.setter + def size(self, size): + w, h = size + size = int(w), int(h) + if size != self._size: + self._size = size + + self.camera.intrinsic.size = size + self._changed() + + @property + def shape(self): + """Shape (height, width) of the viewport in pixels. + + This is a convenient wrapper to the inverse of size. + """ + return self._size[1], self._size[0] + + @shape.setter + def shape(self, shape): + self.size = shape[1], shape[0] + + @property + def framebuffer(self): + """The framebuffer ID this viewport is rendering into (int)""" + return self._framebuffer + + @framebuffer.setter + def framebuffer(self, framebuffer): + self._framebuffer = int(framebuffer) + + def render(self, glContext): + """Perform the rendering of the viewport + + :param Context glContext: The context used for rendering""" + # Get a chance to run deferred delete + glContext.cleanGLGarbage() + + # OpenGL set-up: really need to be done once + ox, oy = self.origin + w, h = self.size + gl.glViewport(ox, oy, w, h) + + gl.glEnable(gl.GL_SCISSOR_TEST) + gl.glScissor(ox, oy, w, h) + + gl.glEnable(gl.GL_BLEND) + gl.glBlendFunc(gl.GL_SRC_ALPHA, gl.GL_ONE_MINUS_SRC_ALPHA) + + gl.glEnable(gl.GL_DEPTH_TEST) + gl.glDepthFunc(gl.GL_LEQUAL) + gl.glDepthRange(0., 1.) + + # gl.glEnable(gl.GL_POLYGON_OFFSET_FILL) + # gl.glPolygonOffset(1., 1.) + + gl.glHint(gl.GL_LINE_SMOOTH_HINT, gl.GL_NICEST) + gl.glEnable(gl.GL_LINE_SMOOTH) + + gl.glClearColor(*self.background) + + # Prepare OpenGL + gl.glClear(gl.GL_COLOR_BUFFER_BIT | + gl.GL_STENCIL_BUFFER_BIT | + gl.GL_DEPTH_BUFFER_BIT) + + ctx = RenderContext(self, glContext) + self.scene.render(ctx) + self.scene.postRender(ctx) + + def adjustCameraDepthExtent(self): + """Update camera depth extent to fit the scene bounds. + + Only near and far planes are updated. + The scene might still not be fully visible + (e.g., if spanning behind the viewpoint with perspective projection). + """ + bounds = self.scene.bounds(transformed=True) + bounds = self.camera.extrinsic.transformBounds(bounds) + + if isinstance(self.camera.intrinsic, transform.Perspective): + # This needs to be reworked + zbounds = - bounds[:, 2] + zextent = max(numpy.fabs(zbounds[0] - zbounds[1]), 0.0001) + near = max(zextent / 1000., 0.95 * zbounds[1]) + far = max(near + 0.1, 1.05 * zbounds[0]) + + self.camera.intrinsic.setDepthExtent(near, far) + elif isinstance(self.camera.intrinsic, transform.Orthographic): + # Makes sure z bounds are included + border = max(abs(bounds[:, 2])) + self.camera.intrinsic.setDepthExtent(-border, border) + else: + raise RuntimeError('Unsupported camera', self.camera.intrinsic) + + def resetCamera(self): + """Change camera to have the whole scene in the viewing frustum. + + It updates the camera position and depth extent. + Camera sight direction and up are not affected. + """ + self.camera.resetCamera(self.scene.bounds(transformed=True)) + + def orbitCamera(self, direction, angle=1.): + """Rotate the camera around center of the scene. + + :param str direction: Direction of movement relative to image plane. + In: 'up', 'down', 'left', 'right'. + :param float angle: he angle in degrees of the rotation. + """ + bounds = self.scene.bounds(transformed=True) + center = 0.5 * (bounds[0] + bounds[1]) + self.camera.orbit(direction, center, angle) + + def moveCamera(self, direction, step=0.1): + """Move the camera relative to the image plane. + + :param str direction: Direction relative to image plane. + One of: 'up', 'down', 'left', 'right', + 'forward', 'backward'. + :param float step: The ratio of data to step for each pan. + """ + bounds = self.scene.bounds(transformed=True) + bounds = self.camera.extrinsic.transformBounds(bounds) + center = 0.5 * (bounds[0] + bounds[1]) + ndcCenter = self.camera.intrinsic.transformPoint( + center, perspectiveDivide=True) + + step *= 2. # NDC has size 2 + + if direction == 'up': + ndcCenter[1] -= step + elif direction == 'down': + ndcCenter[1] += step + + elif direction == 'right': + ndcCenter[0] -= step + elif direction == 'left': + ndcCenter[0] += step + + elif direction == 'forward': + ndcCenter[2] += step + elif direction == 'backward': + ndcCenter[2] -= step + + else: + raise ValueError('Unsupported direction: %s' % direction) + + newCenter = self.camera.intrinsic.transformPoint( + ndcCenter, direct=False, perspectiveDivide=True) + + self.camera.move(direction, numpy.linalg.norm(newCenter - center)) + + def windowToNdc(self, winX, winY, checkInside=True): + """Convert position from window to normalized device coordinates. + + If window coordinates are int, they are moved half a pixel + to be positioned at the center of pixel. + + :param winX: X window coord, origin left. + :param winY: Y window coord, origin top. + :param bool checkInside: If True, returns None if position is + outside viewport. + :return: (x, y) Normalize device coordinates in [-1, 1] or None. + Origin center, x to the right, y goes upward. + """ + ox, oy = self._origin + width, height = self.size + + # If int, move it to the center of pixel + if isinstance(winX, int): + winX += 0.5 + if isinstance(winY, int): + winY += 0.5 + + x, y = winX - ox, winY - oy + + if checkInside and (x < 0. or x > width or y < 0. or y > height): + return None # Out of viewport + + ndcx = 2. * x / float(width) - 1. + ndcy = 1. - 2. * y / float(height) + return ndcx, ndcy + + def ndcToWindow(self, ndcX, ndcY, checkInside=True): + """Convert position from normalized device coordinates (NDC) to window. + + :param float ndcX: X NDC coord. + :param float ndcY: Y NDC coord. + :param bool checkInside: If True, returns None if position is + outside viewport. + :return: (x, y) window coordinates or None. + Origin top-left, x to the right, y goes downward. + """ + if (checkInside and + (ndcX < -1. or ndcX > 1. or ndcY < -1. or ndcY > 1.)): + return None # Outside viewport + + ox, oy = self._origin + width, height = self.size + + winx = ox + width * 0.5 * (ndcX + 1.) + winy = oy + height * 0.5 * (1. - ndcY) + return winx, winy + + def _pickNdcZGL(self, x, y): + """Retrieve depth from depth buffer and return corresponding NDC Z. + + :param int x: In pixels in window coordinates, origin left. + :param int y: In pixels in window coordinates, origin top. + :return: Normalize device Z coordinate of depth in [-1, 1] + or None if outside viewport. + :rtype: float or None + """ + ox, oy = self._origin + width, height = self.size + + x = int(x) + y = height - int(y) # Invert y coord + + if x < ox or x > ox + width or y < oy or y > oy + height: + # Outside viewport + return None + + # Get depth from depth buffer in [0., 1.] + # Bind used framebuffer to get depth + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.framebuffer) + depth = gl.glReadPixels( + x, y, 1, 1, gl.GL_DEPTH_COMPONENT, gl.GL_FLOAT)[0] + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0) + # This is not GL|ES friendly + + # Z in NDC in [-1., 1.] + return float(depth) * 2. - 1. + + def _getXZYGL(self, x, y): + ndc = self.windowToNdc(x, y) + if ndc is None: + return None # Outside viewport + ndcz = self._pickNdcZGL(x, y) + ndcpos = numpy.array((ndc[0], ndc[1], ndcz, 1.), dtype=numpy.float32) + + camerapos = self.camera.intrinsic.transformPoint( + ndcpos, direct=False, perspectiveDivide=True) + + scenepos = self.camera.extrinsic.transformPoint(camerapos, + direct=False) + return scenepos[:3] + + def pick(self, x, y): + pass + # ndcX, ndcY = self.windowToNdc(x, y) + # ndcNearPt = ndcX, ndcY, -1. + # ndcFarPT = ndcX, ndcY, 1. diff --git a/silx/gui/plot3d/scene/window.py b/silx/gui/plot3d/scene/window.py new file mode 100644 index 0000000..ad7e6e5 --- /dev/null +++ b/silx/gui/plot3d/scene/window.py @@ -0,0 +1,420 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides a class for Viewports rendering on the screen. + +The :class:`Window` renders a list of Viewports in the current framebuffer. +The rendering can be performed in an off-screen framebuffer that is only +updated when the scene has changed and not each time Qt is requiring a repaint. + +The :class:`Context` and :class:`ContextGL2` represent the operating system +OpenGL context and handle OpenGL resources. +""" + +from __future__ import absolute_import, division, unicode_literals + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "10/01/2017" + + +import weakref +import numpy + +from ..._glutils import gl +from ... import _glutils + +from . import event + + +class Context(object): + """Correspond to an operating system OpenGL context. + + User should NEVER use an instance of this class beyond the method + it is passed to as an argument (i.e., do not keep a reference to it). + + :param glContextHandle: System specific OpenGL context handle. + """ + + def __init__(self, glContextHandle): + self._context = glContextHandle + self._isCurrent = False + self._devicePixelRatio = 1.0 + + @property + def isCurrent(self): + """Whether this OpenGL context is the current one or not.""" + return self._isCurrent + + def setCurrent(self, isCurrent=True): + """Set the state of the OpenGL context to reflect OpenGL state. + + This should not be called from the scene graph, only in the + wrapper that handle the OpenGL context to reflect its state. + + :param bool isCurrent: The state of the system OpenGL context. + """ + self._isCurrent = bool(isCurrent) + + @property + def devicePixelRatio(self): + """Ratio between device and device independent pixels (float) + + This is useful for font rendering. + """ + return self._devicePixelRatio + + @devicePixelRatio.setter + def devicePixelRatio(self, ratio): + assert ratio > 0 + self._devicePixelRatio = float(ratio) + + def __enter__(self): + self.setCurrent(True) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.setCurrent(False) + + @property + def glContext(self): + """The handle to the OpenGL context provided by the system.""" + return self._context + + def cleanGLGarbage(self): + """This is releasing OpenGL resource that are no longer used.""" + pass + + +class ContextGL2(Context): + """Handle a system GL2 context. + + User should NEVER use an instance of this class beyond the method + it is passed to as an argument (i.e., do not keep a reference to it). + + :param glContextHandle: System specific OpenGL context handle. + """ + def __init__(self, glContextHandle): + super(ContextGL2, self).__init__(glContextHandle) + + self._programs = {} # GL programs already compiled + self._vbos = {} # GL Vbos already set + self._vboGarbage = [] # Vbos waiting to be discarded + + # programs + + def prog(self, vertexShaderSrc, fragmentShaderSrc): + """Cache program within context. + + WARNING: No clean-up. + """ + assert self.isCurrent + key = vertexShaderSrc, fragmentShaderSrc + prog = self._programs.get(key, None) + if prog is None: + prog = _glutils.Program(vertexShaderSrc, fragmentShaderSrc) + self._programs[key] = prog + return prog + + # VBOs + + def makeVbo(self, data=None, sizeInBytes=None, + usage=None, target=None): + """Create a VBO in this context with the data. + + Current limitations: + + - One array per VBO + - Do not support sharing VertexBuffer across VboAttrib + + Automatically discards the VBO when the returned + :class:`VertexBuffer` istance is deleted. + + :param numpy.ndarray data: 2D array of data to store in VBO or None. + :param int sizeInBytes: Size of the VBO or None. + It should be <= data.nbytes if both are given. + :param usage: OpenGL usage define in VertexBuffer._USAGES. + :param target: OpenGL target in VertexBuffer._TARGETS. + :return: The VertexBuffer created in this context. + """ + assert self.isCurrent + vbo = _glutils.VertexBuffer(data, sizeInBytes, usage, target) + vboref = weakref.ref(vbo, self._deadVbo) + # weakref is hashable as far as target is + self._vbos[vboref] = vbo.name + return vbo + + def makeVboAttrib(self, data, usage=None, target=None): + """Create a VBO from data and returns the associated VBOAttrib. + + Automatically discards the VBO when the returned + :class:`VBOAttrib` istance is deleted. + + :param numpy.ndarray data: 2D array of data to store in VBO or None. + :param usage: OpenGL usage define in VertexBuffer._USAGES. + :param target: OpenGL target in VertexBuffer._TARGETS. + :returns: A VBOAttrib instance created in this context. + """ + assert self.isCurrent + vbo = self.makeVbo(data, usage=usage, target=target) + + assert len(data.shape) <= 2 + dimension = 1 if len(data.shape) == 1 else data.shape[1] + + return _glutils.VertexBufferAttrib( + vbo, + type_=_glutils.numpyToGLType(data.dtype), + size=data.shape[0], + dimension=dimension, + offset=0, + stride=0) + + def _deadVbo(self, vboRef): + """Callback handling dead VBOAttribs.""" + vboid = self._vbos.pop(vboRef) + if self.isCurrent: + # Direct delete if context is active + gl.glDeleteBuffers(vboid) + else: + # Deferred VBO delete if context is not active + self._vboGarbage.append(vboid) + + def cleanGLGarbage(self): + """Delete OpenGL resources that are pending for destruction. + + This requires the associated OpenGL context to be active. + This is meant to be called before rendering. + """ + assert self.isCurrent + if self._vboGarbage: + vboids = self._vboGarbage + gl.glDeleteBuffers(vboids) + self._vboGarbage = [] + + +class Window(event.Notifier): + """OpenGL Framebuffer where to render viewports + + :param str mode: Rendering mode to use: + + - 'direct' to render everything for each render call + - 'framebuffer' to cache viewport rendering in a texture and + update the texture only when needed. + """ + + _position = numpy.array(((-1., -1., 0., 0.), + (1., -1., 1., 0.), + (-1., 1., 0., 1.), + (1., 1., 1., 1.)), + dtype=numpy.float32) + + _shaders = (""" + attribute vec4 position; + varying vec2 textureCoord; + + void main(void) { + gl_Position = vec4(position.x, position.y, 0., 1.); + textureCoord = position.zw; + } + """, + """ + uniform sampler2D texture; + varying vec2 textureCoord; + + void main(void) { + gl_FragColor = texture2D(texture, textureCoord); + } + """) + + def __init__(self, mode='framebuffer'): + super(Window, self).__init__() + self._dirty = True + self._size = 0, 0 + self._contexts = {} # To map system GL context id to Context objects + self._viewports = event.NotifierList() + self._viewports.addListener(self._updated) + self._framebufferid = 0 + self._framebuffers = {} # Cache of framebuffers + + assert mode in ('direct', 'framebuffer') + self._isframebuffer = mode == 'framebuffer' + + @property + def dirty(self): + """True if this object or any attached viewports is dirty.""" + for viewport in self._viewports: + if viewport.dirty: + return True + return self._dirty + + @property + def size(self): + """Size (width, height) of the window in pixels""" + return self._size + + @size.setter + def size(self, size): + w, h = size + size = int(w), int(h) + if size != self._size: + self._size = size + self._dirty = True + self.notify() + + @property + def shape(self): + """Shape (height, width) of the window in pixels. + + This is a convenient wrapper to the reverse of size. + """ + return self._size[1], self._size[0] + + @shape.setter + def shape(self, shape): + self.size = shape[1], shape[0] + + @property + def viewports(self): + """List of viewports to render in the corresponding framebuffer""" + return self._viewports + + @viewports.setter + def viewports(self, iterable): + self._viewports.removeListener(self._updated) + self._viewports = event.NotifierList(iterable) + self._viewports.addListener(self._updated) + self._dirty = True + + def _updated(self, source, *args, **kwargs): + if source is not self: + self._dirty = True + self.notify(*args, **kwargs) + + framebufferid = property(lambda self: self._framebufferid, + doc="Framebuffer ID used to perform rendering") + + def grab(self, glcontext): + """Returns the raster of the scene as an RGB numpy array + + :returns: OpenGL scene RGB bitmap + :rtype: numpy.ndarray of uint8 of dimension (height, width, 3) + """ + height, width = self.shape + image = numpy.empty((height, width, 3), dtype=numpy.uint8) + + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.framebufferid) + gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1) + gl.glReadPixels( + 0, 0, width, height, gl.GL_RGB, gl.GL_UNSIGNED_BYTE, image) + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0) + + # glReadPixels gives bottom to top, + # while images are stored as top to bottom + image = numpy.flipud(image) + + return numpy.array(image, copy=False, order='C') + + def render(self, glcontext, devicePixelRatio): + """Perform the rendering of attached viewports + + :param glcontext: System identifier of the OpenGL context + :param float devicePixelRatio: + Ratio between device and device-independent pixels + """ + if glcontext not in self._contexts: + self._contexts[glcontext] = ContextGL2(glcontext) # New context + + with self._contexts[glcontext] as context: + context.devicePixelRatio = devicePixelRatio + if self._isframebuffer: + self._renderWithOffscreenFramebuffer(context) + else: + self._renderDirect(context) + + self._dirty = False + + def _renderDirect(self, context): + """Perform the direct rendering of attached viewports + + :param Context context: Object wrapping OpenGL context + """ + for viewport in self._viewports: + viewport.framebuffer = self.framebufferid + viewport.render(context) + viewport.resetDirty() + + def _renderWithOffscreenFramebuffer(self, context): + """Renders viewports in a texture and render this texture on screen. + + The texture is updated only if viewport or size has changed. + + :param ContextGL2 context: Object wrappign OpenGL context + """ + if self.dirty or context not in self._framebuffers: + # Need to redraw framebuffer content + + if (context not in self._framebuffers or + self._framebuffers[context].shape != self.shape): + # Need to rebuild framebuffer + + if context in self._framebuffers: + self._framebuffers[context].discard() + + fbo = _glutils.FramebufferTexture(gl.GL_RGBA, + shape=self.shape, + minFilter=gl.GL_NEAREST, + magFilter=gl.GL_NEAREST, + wrap=gl.GL_CLAMP_TO_EDGE) + self._framebuffers[context] = fbo + self._framebufferid = fbo.name + + # Render in framebuffer + with self._framebuffers[context]: + self._renderDirect(context) + + # Render framebuffer texture to screen + fbo = self._framebuffers[context] + height, width = fbo.shape + + program = context.prog(*self._shaders) + program.use() + + gl.glViewport(0, 0, width, height) + gl.glDisable(gl.GL_BLEND) + gl.glDisable(gl.GL_DEPTH_TEST) + gl.glDisable(gl.GL_SCISSOR_TEST) + # gl.glScissor(0, 0, width, height) + gl.glClearColor(0., 0., 0., 0.) + gl.glClear(gl.GL_COLOR_BUFFER_BIT) + gl.glUniform1i(program.uniforms['texture'], fbo.texture.texUnit) + gl.glEnableVertexAttribArray(program.attributes['position']) + gl.glVertexAttribPointer(program.attributes['position'], + 4, + gl.GL_FLOAT, + gl.GL_FALSE, + 0, + self._position) + fbo.texture.bind() + gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(self._position)) + gl.glBindTexture(gl.GL_TEXTURE_2D, 0) diff --git a/silx/gui/plot3d/setup.py b/silx/gui/plot3d/setup.py new file mode 100644 index 0000000..b9d626f --- /dev/null +++ b/silx/gui/plot3d/setup.py @@ -0,0 +1,44 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "25/07/2016" + + +from numpy.distutils.misc_util import Configuration + + +def configuration(parent_package='', top_path=None): + config = Configuration('plot3d', parent_package, top_path) + config.add_subpackage('scene') + config.add_subpackage('test') + config.add_subpackage('utils') + return config + + +if __name__ == "__main__": + from numpy.distutils.core import setup + + setup(configuration=configuration) diff --git a/silx/gui/plot3d/test/__init__.py b/silx/gui/plot3d/test/__init__.py new file mode 100644 index 0000000..66a2f62 --- /dev/null +++ b/silx/gui/plot3d/test/__init__.py @@ -0,0 +1,62 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2015-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""plot3d test suite.""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "05/01/2017" + + +import logging +import os +import unittest + + +_logger = logging.getLogger(__name__) + + +def suite(): + test_suite = unittest.TestSuite() + + if os.environ.get('WITH_GL_TEST', 'True') == 'False': + # Explicitly disabled tests + _logger.warning( + "silx.gui.plot3d tests disabled (WITH_GL_TEST=False)") + + class SkipPlot3DTest(unittest.TestCase): + def runTest(self): + self.skipTest( + "silx.gui.plot3d tests disabled (WITH_GL_TEST=False)") + + test_suite.addTest(SkipPlot3DTest()) + return test_suite + + # Import here to avoid loading modules if tests are disabled + + from ..scene import test as test_scene + + test_suite = unittest.TestSuite() + test_suite.addTest(test_scene.suite()) + return test_suite diff --git a/silx/gui/plot3d/utils/__init__.py b/silx/gui/plot3d/utils/__init__.py new file mode 100644 index 0000000..99d3e08 --- /dev/null +++ b/silx/gui/plot3d/utils/__init__.py @@ -0,0 +1,28 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "18/10/2016" diff --git a/silx/gui/plot3d/utils/mng.py b/silx/gui/plot3d/utils/mng.py new file mode 100644 index 0000000..fe79a52 --- /dev/null +++ b/silx/gui/plot3d/utils/mng.py @@ -0,0 +1,121 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides basic writing Mulitple-image Network Graphics files. + +It only supports RGB888 images of the same shape stored as +MNG-VLC (very low complexity) format. +""" + +from __future__ import absolute_import + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "15/12/2016" + + +import logging +import struct +import zlib + +import numpy + +_logger = logging.getLogger(__name__) + + +def _png_chunk(name, data): + """Return a PNG chunk + + :param str name: Chunk type + :param byte data: Chunk payload + """ + length = struct.pack('>I', len(data)) + name = [char.encode('ascii') for char in name] + chunk = struct.pack('cccc', *name) + data + crc = struct.pack('>I', zlib.crc32(chunk) & 0xffffffff) + return length + chunk + crc + + +def convert(images, nb_images=0, fps=25): + """Convert RGB images to MNG-VLC format. + + See http://www.libpng.org/pub/mng/spec/ + See http://www.libpng.org/pub/png/book/ + See http://www.libpng.org/pub/png/spec/1.2/ + + :param images: iterator of RGB888 images + :type images: iterator of numpy.ndarray of dimension 3 + :param int nb_images: The number of images indicated in the MNG header + :param int fps: The frame rate indicated in the MNG header + :return: An iterator of MNG chunks as bytes + """ + first_image = True + + for image in images: + if first_image: + first_image = False + + height, width = image.shape[:2] + + # MNG signature + yield b'\x8aMNG\r\n\x1a\n' + + # MHDR chunk: File header + yield _png_chunk('MHDR', struct.pack( + ">IIIIIII", + width, + height, + fps, # ticks + nb_images + 1, # layer count + nb_images, # frame count + nb_images, # play time + 1)) # profile: MNG-VLC no alpha: only least significant bit 1 + + assert image.shape == (height, width, 3) + assert image.dtype == numpy.dtype('uint8') + + # IHDR chunk: Image header + depth = 8 # 8 bit per channel + color_type = 2 # 'truecolor' = RGB + interlace = 0 # No + yield _png_chunk('IHDR', struct.pack(">IIBBBBB", + width, + height, + depth, + color_type, + 0, 0, interlace)) + + # Add filter 'None' before each scanline + prepared_data = b'\x00' + b'\x00'.join( + line.tostring() for line in image) # TODO optimize that + compressed_data = zlib.compress(prepared_data, 8) + + # IDAT chunk: Payload + yield _png_chunk('IDAT', compressed_data) + + # IEND chunk: Image footer + yield _png_chunk('IEND', b'') + + # MEND chunk: footer + yield _png_chunk('MEND', b'') diff --git a/silx/gui/qt/__init__.py b/silx/gui/qt/__init__.py new file mode 100644 index 0000000..44daa94 --- /dev/null +++ b/silx/gui/qt/__init__.py @@ -0,0 +1,61 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Common wrapper over Python Qt bindings: + +- `PyQt5 <http://pyqt.sourceforge.net/Docs/PyQt5/>`_, +- `PyQt4 <http://pyqt.sourceforge.net/Docs/PyQt4/>`_ or +- `PySide <http://www.pyside.org>`_. + +If a Qt binding is already loaded, it will use it, otherwise the different +Qt bindings are tried in this order: PyQt4, PySide, PyQt5. + +The name of the loaded Qt binding is stored in the BINDING variable. + +This module provides a flat namespace over Qt bindings by importing +all symbols from **QtCore** and **QtGui** packages and if available +from **QtOpenGL** and **QtSvg** packages. +For **PyQt5**, it also imports all symbols from **QtWidgets** and +**QtPrintSupport** packages. + +Example of using :mod:`silx.gui.qt` module: + +>>> from silx.gui import qt +>>> app = qt.QApplication([]) +>>> widget = qt.QWidget() + +For an alternative solution providing a structured namespace, +see `qtpy <https://pypi.python.org/pypi/QtPy/>`_ which +provides the namespace of PyQt5 over PyQt4 and PySide. +""" + +import sys +from ._qt import * # noqa +from ._utils import * # noqa + + +if sys.platform == "darwin": + if BINDING in ["PySide", "PyQt4"]: + from . import _macosx + _macosx.patch_QUrl_toLocalFile() diff --git a/silx/gui/qt/_macosx.py b/silx/gui/qt/_macosx.py new file mode 100644 index 0000000..07f3143 --- /dev/null +++ b/silx/gui/qt/_macosx.py @@ -0,0 +1,68 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +""" +Patches for Mac OS X +""" + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "30/11/2016" + + +def patch_QUrl_toLocalFile(): + """Apply a monkey-patch on qt.QUrl to allow to reach filename when the URL + come from a MIME data from a file drop. Without, `QUrl.toLocalName` with + some version of Mac OS X returns a path which looks like + `/.file/id=180.112`. + + Qt5 is or will be patch, but Qt4 and PySide are not. + + This fix uses the file URL and use an subprocess with an + AppleScript. The script convert the URI into a posix path. + The interpreter (osascript) is available on default OS X installs. + + See https://bugreports.qt.io/browse/QTBUG-40449 + """ + from ._qt import QUrl + import subprocess + + def QUrl_toLocalFile(self): + path = QUrl._oldToLocalFile(self) + if not path.startswith("/.file/id="): + return path + + url = self.toString() + script = 'get posix path of my posix file \"%s\" -- kthxbai' % url + try: + p = subprocess.Popen(["osascript", "-e", script], stdout=subprocess.PIPE) + out, _err = p.communicate() + if p.returncode == 0: + return out.strip() + except OSError: + pass + return path + + QUrl._oldToLocalFile = QUrl.toLocalFile + QUrl.toLocalFile = QUrl_toLocalFile diff --git a/silx/gui/qt/_pyside_dynamic.py b/silx/gui/qt/_pyside_dynamic.py new file mode 100644 index 0000000..a9246b9 --- /dev/null +++ b/silx/gui/qt/_pyside_dynamic.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- + +# Taken from: https://gist.github.com/cpbotha/1b42a20c8f3eb9bb7cb8 + +# Copyright (c) 2011 Sebastian Wiesner <lunaryorn@gmail.com> +# Modifications by Charl Botha <cpbotha@vxlabs.com> +# * customWidgets support (registerCustomWidget() causes segfault in +# pyside 1.1.2 on Ubuntu 12.04 x86_64) +# * workingDirectory support in loadUi + +# found this here: +# https://github.com/lunaryorn/snippets/blob/master/qt4/designer/pyside_dynamic.py + +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +""" + How to load a user interface dynamically with PySide. + + .. moduleauthor:: Sebastian Wiesner <lunaryorn@gmail.com> +""" + +from __future__ import (print_function, division, unicode_literals, + absolute_import) + +import logging + +from PySide.QtCore import QMetaObject +from PySide.QtUiTools import QUiLoader +from PySide.QtGui import QMainWindow + + +_logger = logging.getLogger(__name__) + + +class UiLoader(QUiLoader): + """ + Subclass :class:`~PySide.QtUiTools.QUiLoader` to create the user interface + in a base instance. + + Unlike :class:`~PySide.QtUiTools.QUiLoader` itself this class does not + create a new instance of the top-level widget, but creates the user + interface in an existing instance of the top-level class. + + This mimics the behaviour of :func:`PyQt4.uic.loadUi`. + """ + + def __init__(self, baseinstance, customWidgets=None): + """ + Create a loader for the given ``baseinstance``. + + The user interface is created in ``baseinstance``, which must be an + instance of the top-level class in the user interface to load, or a + subclass thereof. + + ``customWidgets`` is a dictionary mapping from class name to class + object for widgets that you've promoted in the Qt Designer + interface. Usually, this should be done by calling + registerCustomWidget on the QUiLoader, but + with PySide 1.1.2 on Ubuntu 12.04 x86_64 this causes a segfault. + + ``parent`` is the parent object of this loader. + """ + + QUiLoader.__init__(self, baseinstance) + self.baseinstance = baseinstance + self.customWidgets = customWidgets + + def createWidget(self, class_name, parent=None, name=''): + """ + Function that is called for each widget defined in ui file, + overridden here to populate baseinstance instead. + """ + + if parent is None and self.baseinstance: + # supposed to create the top-level widget, return the base instance + # instead + return self.baseinstance + + else: + if class_name in self.availableWidgets(): + # create a new widget for child widgets + widget = QUiLoader.createWidget(self, class_name, parent, name) + + else: + # if not in the list of availableWidgets, + # must be a custom widget + # this will raise KeyError if the user has not supplied the + # relevant class_name in the dictionary, or TypeError, if + # customWidgets is None + try: + widget = self.customWidgets[class_name](parent) + + except (TypeError, KeyError): + raise Exception('No custom widget ' + class_name + + ' found in customWidgets param of' + + 'UiLoader __init__.') + + if self.baseinstance: + # set an attribute for the new child widget on the base + # instance, just like PyQt4.uic.loadUi does. + setattr(self.baseinstance, name, widget) + + # this outputs the various widget names, e.g. + # sampleGraphicsView, dockWidget, samplesTableView etc. + # print(name) + + return widget + + +def loadUi(uifile, baseinstance=None, package=None, resource_suffix=None): + """ + Dynamically load a user interface from the given ``uifile``. + + ``uifile`` is a string containing a file name of the UI file to load. + + If ``baseinstance`` is ``None``, the a new instance of the top-level widget + will be created. Otherwise, the user interface is created within the given + ``baseinstance``. In this case ``baseinstance`` must be an instance of the + top-level widget class in the UI file to load, or a subclass thereof. In + other words, if you've created a ``QMainWindow`` interface in the designer, + ``baseinstance`` must be a ``QMainWindow`` or a subclass thereof, too. You + cannot load a ``QMainWindow`` UI file with a plain + :class:`~PySide.QtGui.QWidget` as ``baseinstance``. + + :method:`~PySide.QtCore.QMetaObject.connectSlotsByName()` is called on the + created user interface, so you can implemented your slots according to its + conventions in your widget class. + + Return ``baseinstance``, if ``baseinstance`` is not ``None``. Otherwise + return the newly created instance of the user interface. + """ + if package is not None: + _logger.warning( + "loadUi package parameter not implemented with PySide") + if resource_suffix is not None: + _logger.warning( + "loadUi resource_suffix parameter not implemented with PySide") + + loader = UiLoader(baseinstance) + widget = loader.load(uifile) + QMetaObject.connectSlotsByName(widget) + return widget diff --git a/silx/gui/qt/_pyside_missing.py b/silx/gui/qt/_pyside_missing.py new file mode 100644 index 0000000..a7e2781 --- /dev/null +++ b/silx/gui/qt/_pyside_missing.py @@ -0,0 +1,274 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +""" +Python implementation of classes which are not provided by default by PySide. +""" + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "17/01/2017" + + +from PySide.QtGui import QAbstractProxyModel +from PySide.QtCore import QModelIndex +from PySide.QtCore import Qt +from PySide.QtGui import QItemSelection +from PySide.QtGui import QItemSelectionRange + + +class QIdentityProxyModel(QAbstractProxyModel): + """Python translation of the source code of Qt c++ file""" + + def __init__(self, parent=None): + super(QIdentityProxyModel, self).__init__(parent) + self.__ignoreNextLayoutAboutToBeChanged = False + self.__ignoreNextLayoutChanged = False + self.__persistentIndexes = [] + + def columnCount(self, parent): + parent = self.mapToSource(parent) + return self.sourceModel().columnCount(parent) + + def dropMimeData(self, data, action, row, column, parent): + parent = self.mapToSource(parent) + return self.sourceModel().dropMimeData(data, action, row, column, parent) + + def index(self, row, column, parent=QModelIndex()): + parent = self.mapToSource(parent) + i = self.sourceModel().index(row, column, parent) + return self.mapFromSource(i) + + def insertColumns(self, column, count, parent=QModelIndex()): + parent = self.mapToSource(parent) + return self.sourceModel().insertColumns(column, count, parent) + + def insertRows(self, row, count, parent=QModelIndex()): + parent = self.mapToSource(parent) + return self.sourceModel().insertRows(row, count, parent) + + def mapFromSource(self, sourceIndex): + if self.sourceModel() is None or not sourceIndex.isValid(): + return QModelIndex() + index = self.createIndex(sourceIndex.row(), sourceIndex.column(), sourceIndex.internalPointer()) + return index + + def mapSelectionFromSource(self, sourceSelection): + proxySelection = QItemSelection() + if self.sourceModel() is None: + return proxySelection + + cursor = sourceSelection.constBegin() + end = sourceSelection.constEnd() + while cursor != end: + topLeft = self.mapFromSource(cursor.topLeft()) + bottomRight = self.mapFromSource(cursor.bottomRight()) + proxyRange = QItemSelectionRange(topLeft, bottomRight) + proxySelection.append(proxyRange) + cursor += 1 + return proxySelection + + def mapSelectionToSource(self, proxySelection): + sourceSelection = QItemSelection() + if self.sourceModel() is None: + return sourceSelection + + cursor = proxySelection.constBegin() + end = proxySelection.constEnd() + while cursor != end: + topLeft = self.mapToSource(cursor.topLeft()) + bottomRight = self.mapToSource(cursor.bottomRight()) + sourceRange = QItemSelectionRange(topLeft, bottomRight) + sourceSelection.append(sourceRange) + cursor += 1 + return sourceSelection + + def mapToSource(self, proxyIndex): + if self.sourceModel() is None or not proxyIndex.isValid(): + return QModelIndex() + return self.sourceModel().createIndex(proxyIndex.row(), proxyIndex.column(), proxyIndex.internalPointer()) + + def match(self, start, role, value, hits=1, flags=Qt.MatchFlags(Qt.MatchStartsWith | Qt.MatchWrap)): + if self.sourceModel() is None: + return [] + + start = self.mapToSource(start) + sourceList = self.sourceModel().match(start, role, value, hits, flags) + proxyList = [] + for cursor in sourceList: + proxyList.append(self.mapFromSource(cursor)) + return proxyList + + def parent(self, child): + sourceIndex = self.mapToSource(child) + sourceParent = sourceIndex.parent() + index = self.mapFromSource(sourceParent) + return index + + def removeColumns(self, column, count, parent=QModelIndex()): + parent = self.mapToSource(parent) + return self.sourceModel().removeColumns(column, count, parent) + + def removeRows(self, row, count, parent=QModelIndex()): + parent = self.mapToSource(parent) + return self.sourceModel().removeRows(row, count, parent) + + def rowCount(self, parent=QModelIndex()): + parent = self.mapToSource(parent) + return self.sourceModel().rowCount(parent) + + def setSourceModel(self, newSourceModel): + """Bind and unbind the source model events""" + self.beginResetModel() + + sourceModel = self.sourceModel() + if sourceModel is not None: + sourceModel.rowsAboutToBeInserted.disconnect(self.__rowsAboutToBeInserted) + sourceModel.rowsInserted.disconnect(self.__rowsInserted) + sourceModel.rowsAboutToBeRemoved.disconnect(self.__rowsAboutToBeRemoved) + sourceModel.rowsRemoved.disconnect(self.__rowsRemoved) + sourceModel.rowsAboutToBeMoved.disconnect(self.__rowsAboutToBeMoved) + sourceModel.rowsMoved.disconnect(self.__rowsMoved) + sourceModel.columnsAboutToBeInserted.disconnect(self.__columnsAboutToBeInserted) + sourceModel.columnsInserted.disconnect(self.__columnsInserted) + sourceModel.columnsAboutToBeRemoved.disconnect(self.__columnsAboutToBeRemoved) + sourceModel.columnsRemoved.disconnect(self.__columnsRemoved) + sourceModel.columnsAboutToBeMoved.disconnect(self.__columnsAboutToBeMoved) + sourceModel.columnsMoved.disconnect(self.__columnsMoved) + sourceModel.modelAboutToBeReset.disconnect(self.__modelAboutToBeReset) + sourceModel.modelReset.disconnect(self.__modelReset) + sourceModel.dataChanged.disconnect(self.__dataChanged) + sourceModel.headerDataChanged.disconnect(self.__headerDataChanged) + sourceModel.layoutAboutToBeChanged.disconnect(self.__layoutAboutToBeChanged) + sourceModel.layoutChanged.disconnect(self.__layoutChanged) + + super(QIdentityProxyModel, self).setSourceModel(newSourceModel) + + sourceModel = self.sourceModel() + if sourceModel is not None: + sourceModel.rowsAboutToBeInserted.connect(self.__rowsAboutToBeInserted) + sourceModel.rowsInserted.connect(self.__rowsInserted) + sourceModel.rowsAboutToBeRemoved.connect(self.__rowsAboutToBeRemoved) + sourceModel.rowsRemoved.connect(self.__rowsRemoved) + sourceModel.rowsAboutToBeMoved.connect(self.__rowsAboutToBeMoved) + sourceModel.rowsMoved.connect(self.__rowsMoved) + sourceModel.columnsAboutToBeInserted.connect(self.__columnsAboutToBeInserted) + sourceModel.columnsInserted.connect(self.__columnsInserted) + sourceModel.columnsAboutToBeRemoved.connect(self.__columnsAboutToBeRemoved) + sourceModel.columnsRemoved.connect(self.__columnsRemoved) + sourceModel.columnsAboutToBeMoved.connect(self.__columnsAboutToBeMoved) + sourceModel.columnsMoved.connect(self.__columnsMoved) + sourceModel.modelAboutToBeReset.connect(self.__modelAboutToBeReset) + sourceModel.modelReset.connect(self.__modelReset) + sourceModel.dataChanged.connect(self.__dataChanged) + sourceModel.headerDataChanged.connect(self.__headerDataChanged) + sourceModel.layoutAboutToBeChanged.connect(self.__layoutAboutToBeChanged) + sourceModel.layoutChanged.connect(self.__layoutChanged) + + self.endResetModel() + + def __columnsAboutToBeInserted(self, parent, start, end): + parent = self.mapFromSource(parent) + self.beginInsertColumns(parent, start, end) + + def __columnsAboutToBeMoved(self, sourceParent, sourceStart, sourceEnd, destParent, dest): + sourceParent = self.mapFromSource(sourceParent) + destParent = self.mapFromSource(destParent) + self.beginMoveColumns(sourceParent, sourceStart, sourceEnd, destParent, dest) + + def __columnsAboutToBeRemoved(self, parent, start, end): + parent = self.mapFromSource(parent) + self.beginRemoveColumns(parent, start, end) + + def __columnsInserted(self, parent, start, end): + self.endInsertColumns() + + def __columnsMoved(self, sourceParent, sourceStart, sourceEnd, destParent, dest): + self.endMoveColumns() + + def __columnsRemoved(self, parent, start, end): + self.endRemoveColumns() + + def __dataChanged(self, topLeft, bottomRight): + topLeft = self.mapFromSource(topLeft) + bottomRight = self.mapFromSource(bottomRight) + self.dataChanged(topLeft, bottomRight) + + def __headerDataChanged(self, orientation, first, last): + self.headerDataChanged(orientation, first, last) + + def __layoutAboutToBeChanged(self): + """Store persistent indexes""" + if self.__ignoreNextLayoutAboutToBeChanged: + return + + for proxyPersistentIndex in self.persistentIndexList(): + self.__proxyIndexes.append() + sourcePersistentIndex = self.mapToSource(proxyPersistentIndex) + mapping = proxyPersistentIndex, sourcePersistentIndex + self.__persistentIndexes.append(mapping) + + self.layoutAboutToBeChanged() + + def __layoutChanged(self): + """Restore persistent indexes""" + if self.__ignoreNextLayoutChanged: + return + + for mapping in self.__persistentIndexes: + proxyIndex, sourcePersistentIndex = mapping + sourcePersistentIndex = self.mapFromSource(sourcePersistentIndex) + self.changePersistentIndex(proxyIndex, sourcePersistentIndex) + + self.__persistentIndexes = [] + + self.layoutChanged() + + def __modelAboutToBeReset(self): + self.beginResetModel() + + def __modelReset(self): + self.endResetModel() + + def __rowsAboutToBeInserted(self, parent, start, end): + parent = self.mapFromSource(parent) + self.beginInsertRows(parent, start, end) + + def __rowsAboutToBeMoved(self, sourceParent, sourceStart, sourceEnd, destParent, dest): + sourceParent = self.mapFromSource(sourceParent) + destParent = self.mapFromSource(destParent) + self.beginMoveRows(sourceParent, sourceStart, sourceEnd, destParent, dest) + + def __rowsAboutToBeRemoved(self, parent, start, end): + parent = self.mapFromSource(parent) + self.beginRemoveRows(parent, start, end) + + def __rowsInserted(self, parent, start, end): + self.endInsertRows() + + def __rowsMoved(self, sourceParent, sourceStart, sourceEnd, destParent, dest): + self.endMoveRows() + + def __rowsRemoved(self, parent, start, end): + self.endRemoveRows() diff --git a/silx/gui/qt/_qt.py b/silx/gui/qt/_qt.py new file mode 100644 index 0000000..0962c21 --- /dev/null +++ b/silx/gui/qt/_qt.py @@ -0,0 +1,229 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Common wrapper over Python Qt bindings: + +- `PyQt5 <http://pyqt.sourceforge.net/Docs/PyQt5/>`_, +- `PyQt4 <http://pyqt.sourceforge.net/Docs/PyQt4/>`_ or +- `PySide <http://www.pyside.org>`_. + +If a Qt binding is already loaded, it will use it, otherwise the different +Qt bindings are tried in this order: PyQt4, PySide, PyQt5. + +The name of the loaded Qt binding is stored in the BINDING variable. + +For an alternative solution providing a structured namespace, +see `qtpy <https://pypi.python.org/pypi/QtPy/>`_ which +provides the namespace of PyQt5 over PyQt4 and PySide. +""" + +__authors__ = ["V.A. Sole - ESRF Data Analysis"] +__license__ = "MIT" +__date__ = "17/01/2017" + + +import logging +import sys +import traceback + + +_logger = logging.getLogger(__name__) + + +BINDING = None +"""The name of the Qt binding in use: 'PyQt5', 'PyQt4' or 'PySide'.""" + +QtBinding = None # noqa +"""The Qt binding module in use: PyQt5, PyQt4 or PySide.""" + +HAS_SVG = False +"""True if Qt provides support for Scalable Vector Graphics (QtSVG).""" + +HAS_OPENGL = False +"""True if Qt provides support for OpenGL (QtOpenGL).""" + +# First check for an already loaded wrapper +if 'PySide.QtCore' in sys.modules: + BINDING = 'PySide' + +elif 'PyQt5.QtCore' in sys.modules: + BINDING = 'PyQt5' + +elif 'PyQt4.QtCore' in sys.modules: + BINDING = 'PyQt4' + +else: # Then try Qt bindings + try: + import PyQt4 # noqa + except ImportError: + try: + import PySide # noqa + except ImportError: + try: + import PyQt5 # noqa + except ImportError: + raise ImportError( + 'No Qt wrapper found. Install PyQt4, PyQt5 or PySide.') + else: + BINDING = 'PyQt5' + else: + BINDING = 'PySide' + else: + BINDING = 'PyQt4' + + +if BINDING == 'PyQt4': + _logger.debug('Using PyQt4 bindings') + + if sys.version < "3.0.0": + try: + import sip + + sip.setapi("QString", 2) + sip.setapi("QVariant", 2) + except: + _logger.warning("Cannot set sip API") + + import PyQt4 as QtBinding # noqa + + from PyQt4.QtCore import * # noqa + from PyQt4.QtGui import * # noqa + + try: + from PyQt4.QtOpenGL import * # noqa + except ImportError: + _logger.info("PyQt4.QtOpenGL not available") + HAS_OPENGL = False + else: + HAS_OPENGL = True + + try: + from PyQt4.QtSvg import * # noqa + except ImportError: + _logger.info("PyQt4.QtSvg not available") + HAS_SVG = False + else: + HAS_SVG = True + + from PyQt4.uic import loadUi # noqa + + Signal = pyqtSignal + + Property = pyqtProperty + + Slot = pyqtSlot + +elif BINDING == 'PySide': + _logger.debug('Using PySide bindings') + + import PySide as QtBinding # noqa + + from PySide.QtCore import * # noqa + from PySide.QtGui import * # noqa + + try: + from PySide.QtOpenGL import * # noqa + except ImportError: + _logger.info("PySide.QtOpenGL not available") + HAS_OPENGL = False + else: + HAS_OPENGL = True + + try: + from PySide.QtSvg import * # noqa + except ImportError: + _logger.info("PySide.QtSvg not available") + HAS_SVG = False + else: + HAS_SVG = True + + pyqtSignal = Signal + + # Import loadUi wrapper for PySide + from ._pyside_dynamic import loadUi # noqa + + # Import missing classes + if not hasattr(locals(), "QIdentityProxyModel"): + from ._pyside_missing import QIdentityProxyModel # noqa + +elif BINDING == 'PyQt5': + _logger.debug('Using PyQt5 bindings') + + import PyQt5 as QtBinding # noqa + + from PyQt5.QtCore import * # noqa + from PyQt5.QtGui import * # noqa + from PyQt5.QtWidgets import * # noqa + from PyQt5.QtPrintSupport import * # noqa + + try: + from PyQt5.QtOpenGL import * # noqa + except ImportError: + _logger.info("PySide.QtOpenGL not available") + HAS_OPENGL = False + else: + HAS_OPENGL = True + + try: + from PyQt5.QtSvg import * # noqa + except ImportError: + _logger.info("PyQt5.QtSvg not available") + HAS_SVG = False + else: + HAS_SVG = True + + from PyQt5.uic import loadUi # noqa + + Signal = pyqtSignal + + Property = pyqtProperty + + Slot = pyqtSlot + +else: + raise ImportError('No Qt wrapper found. Install PyQt4, PyQt5 or PySide') + +# provide a exception handler but not implement it by default +def exceptionHandler(type_, value, trace): + """ + This exception handler prevents quitting to the command line when there is + an unhandled exception while processing a Qt signal. + + The script/application willing to use it should implement code similar to: + + .. code-block:: python + + if __name__ == "__main__": + sys.excepthook = qt.exceptionHandler + + """ + _logger.error("%s %s %s", type_, value, ''.join(traceback.format_tb(trace))) + msg = QMessageBox() + msg.setWindowTitle("Unhandled exception") + msg.setIcon(QMessageBox.Critical) + msg.setInformativeText("%s %s\nPlease report details" % (type_, value)) + msg.setDetailedText(("%s " % value) + ''.join(traceback.format_tb(trace))) + msg.raise_() + msg.exec_() + diff --git a/silx/gui/qt/_utils.py b/silx/gui/qt/_utils.py new file mode 100644 index 0000000..0aa3ef1 --- /dev/null +++ b/silx/gui/qt/_utils.py @@ -0,0 +1,44 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides convenient functions related to Qt. +""" + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "30/11/2016" + +import sys +from ._qt import BINDING, QImageReader + + +def supportedImageFormats(): + """Return a set of string of file format extensions supported by the + Qt runtime.""" + if sys.version_info[0] < 3 or BINDING == 'PySide': + convert = str + else: + convert = lambda data: str(data, 'ascii') + formats = QImageReader.supportedImageFormats() + return set([convert(data) for data in formats]) diff --git a/silx/gui/setup.py b/silx/gui/setup.py new file mode 100644 index 0000000..fbe9058 --- /dev/null +++ b/silx/gui/setup.py @@ -0,0 +1,51 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "16/01/2017" + + +from numpy.distutils.misc_util import Configuration + + +def configuration(parent_package='', top_path=None): + config = Configuration('gui', parent_package, top_path) + config.add_subpackage('_glutils') + config.add_subpackage('qt') + config.add_subpackage('plot') + config.add_subpackage('fit') + config.add_subpackage('hdf5') + config.add_subpackage('widgets') + config.add_subpackage('test') + config.add_subpackage('plot3d') + config.add_subpackage('data') + + return config + + +if __name__ == "__main__": + from numpy.distutils.core import setup + + setup(configuration=configuration) diff --git a/silx/gui/test/__init__.py b/silx/gui/test/__init__.py new file mode 100644 index 0000000..7449860 --- /dev/null +++ b/silx/gui/test/__init__.py @@ -0,0 +1,108 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["T. Vincent", "P. Knobel"] +__license__ = "MIT" +__date__ = "05/01/2017" + + +import logging +import os +import sys +import unittest + + +_logger = logging.getLogger(__name__) + + +def suite(): + + test_suite = unittest.TestSuite() + + if sys.platform.startswith('linux') and not os.environ.get('DISPLAY', ''): + # On Linux and no DISPLAY available (e.g., ssh without -X) + _logger.warning('silx.gui tests disabled (DISPLAY env. variable not set)') + + class SkipGUITest(unittest.TestCase): + def runTest(self): + self.skipTest( + 'silx.gui tests disabled (DISPLAY env. variable not set)') + + test_suite.addTest(SkipGUITest()) + return test_suite + + elif os.environ.get('WITH_QT_TEST', 'True') == 'False': + # Explicitly disabled tests + _logger.warning( + "silx.gui tests disabled (env. variable WITH_QT_TEST=False)") + + class SkipGUITest(unittest.TestCase): + def runTest(self): + self.skipTest( + "silx.gui tests disabled (env. variable WITH_QT_TEST=False)") + + test_suite.addTest(SkipGUITest()) + return test_suite + + # Import here to avoid loading QT if tests are disabled + + from ..plot import test as test_plot + from ..fit import test as test_fit + from ..hdf5 import test as test_hdf5 + from ..widgets import test as test_widgets + from ..data import test as test_data + from . import test_qt + # Console tests disabled due to corruption of python environment + # (see issue #538 on github) + # from . import test_console + from . import test_icons + from . import test_utils + + try: + from ..plot3d.test import suite as test_plot3d_suite + + except ImportError: + _logger.warning( + 'silx.gui.plot3d tests disabled ' + '(PyOpenGL or QtOpenGL not installed)') + + class SkipPlot3DTest(unittest.TestCase): + def runTest(self): + self.skipTest('silx.gui.plot3d tests disabled ' + '(PyOpenGL or QtOpenGL not installed)') + + test_plot3d_suite = SkipPlot3DTest + + + test_suite.addTest(test_qt.suite()) + test_suite.addTest(test_plot.suite()) + test_suite.addTest(test_fit.suite()) + test_suite.addTest(test_hdf5.suite()) + test_suite.addTest(test_widgets.suite()) + # test_suite.addTest(test_console.suite()) # see issue #538 on github + test_suite.addTest(test_icons.suite()) + test_suite.addTest(test_data.suite()) + test_suite.addTest(test_utils.suite()) + test_suite.addTest(test_plot3d_suite()) + return test_suite diff --git a/silx/gui/test/test_console.py b/silx/gui/test/test_console.py new file mode 100644 index 0000000..7c25372 --- /dev/null +++ b/silx/gui/test/test_console.py @@ -0,0 +1,91 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic tests for IPython console widget""" + +from __future__ import print_function + +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "05/12/2016" + + +import unittest + +from silx.gui.test.utils import TestCaseQt + +from silx.gui import qt +try: + from silx.gui.console import IPythonDockWidget +except ImportError: + console_missing = True +else: + console_missing = False + + +# dummy objects to test pushing variables to the interactive namespace +_a = 1 + + +def _f(): + print("Hello World!") + + +@unittest.skipIf(console_missing, "Could not import Ipython and/or qtconsole") +class TestConsole(TestCaseQt): + """Basic test for ``module.IPythonDockWidget``""" + + def setUp(self): + super(TestConsole, self).setUp() + self.console = IPythonDockWidget( + available_vars={"a": _a, "f": _f}, + custom_banner="Welcome!\n") + self.console.show() + self.qWaitForWindowExposed(self.console) + + def tearDown(self): + self.console.setAttribute(qt.Qt.WA_DeleteOnClose) + self.console.close() + del self.console + super(TestConsole, self).tearDown() + + def testShow(self): + pass + + def testInteract(self): + self.mouseClick(self.console, qt.Qt.LeftButton) + self.keyClicks(self.console, 'import silx') + self.keyClick(self.console, qt.Qt.Key_Enter) + self.qapp.processEvents() + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestConsole)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/test/test_icons.py b/silx/gui/test/test_icons.py new file mode 100644 index 0000000..f363c43 --- /dev/null +++ b/silx/gui/test/test_icons.py @@ -0,0 +1,116 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic test of Qt icons module.""" + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "26/04/2017" + + +import gc +import unittest +import weakref + +from silx.gui import qt +from silx.gui.test.utils import TestCaseQt +from silx.gui import icons + + +class TestIcons(TestCaseQt): + """Test to check that icons module.""" + + def testSvgIcon(self): + if "svg" not in qt.supportedImageFormats(): + self.skipTest("SVG not supported") + icon = icons.getQIcon("test-svg") + self.assertIsNotNone(icon) + + def testPngIcon(self): + icon = icons.getQIcon("test-png") + self.assertIsNotNone(icon) + + def testUnexistingIcon(self): + self.assertRaises(ValueError, icons.getQIcon, "not-exists") + + def testExistingQPixmap(self): + icon = icons.getQPixmap("crop") + self.assertIsNotNone(icon) + + def testUnexistingQPixmap(self): + self.assertRaises(ValueError, icons.getQPixmap, "not-exists") + + def testCache(self): + icon1 = icons.getQIcon("crop") + icon2 = icons.getQIcon("crop") + self.assertIs(icon1, icon2) + + def testCacheReleased(self): + icon = icons.getQIcon("crop") + icon_ref = weakref.ref(icon) + del icon + gc.collect() + self.assertIsNone(icon_ref()) + + +class TestAnimatedIcons(TestCaseQt): + """Test to check that icons module.""" + + def testProcessWorking(self): + icon = icons.getWaitIcon() + self.assertIsNotNone(icon) + + def testProcessWorkingCache(self): + icon1 = icons.getWaitIcon() + icon2 = icons.getWaitIcon() + self.assertIs(icon1, icon2) + + def testMovieIconExists(self): + if "mng" not in qt.supportedImageFormats(): + self.skipTest("MNG not supported") + icon = icons.MovieAnimatedIcon("process-working") + self.assertIsNotNone(icon) + + def testMovieIconNotExists(self): + self.assertRaises(ValueError, icons.MovieAnimatedIcon, "not-exists") + + def testMultiImageIconExists(self): + icon = icons.MultiImageAnimatedIcon("process-working") + self.assertIsNotNone(icon) + + def testMultiImageIconNotExists(self): + self.assertRaises(ValueError, icons.MultiImageAnimatedIcon, "not-exists") + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestIcons)) + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestAnimatedIcons)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/test/test_qt.py b/silx/gui/test/test_qt.py new file mode 100644 index 0000000..3a89a33 --- /dev/null +++ b/silx/gui/test/test_qt.py @@ -0,0 +1,144 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Basic test of Qt bindings wrapper.""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "05/12/2016" + + +import os.path +import unittest + +from silx.test.utils import temp_dir +from silx.gui.test.utils import TestCaseQt + +from silx.gui import qt + + +class TestQtWrapper(unittest.TestCase): + """Minimalistic test to check that Qt has been loaded.""" + + def testQObject(self): + """Test that QObject is there.""" + obj = qt.QObject() + self.assertTrue(obj is not None) + + +class TestLoadUi(TestCaseQt): + """Test loadUi function""" + + TEST_UI = """<?xml version="1.0" encoding="UTF-8"?> + <ui version="4.0"> + <class>MainWindow</class> + <widget class="QMainWindow" name="MainWindow"> + <property name="geometry"> + <rect> + <x>0</x> + <y>0</y> + <width>293</width> + <height>296</height> + </rect> + </property> + <property name="windowTitle"> + <string>Test loadUi</string> + </property> + <widget class="QWidget" name="centralwidget"> + <widget class="QPushButton" name="pushButton"> + <property name="geometry"> + <rect> + <x>10</x> + <y>10</y> + <width>89</width> + <height>27</height> + </rect> + </property> + <property name="text"> + <string>Button 1</string> + </property> + </widget> + <widget class="QPushButton" name="pushButton_2"> + <property name="geometry"> + <rect> + <x>10</x> + <y>50</y> + <width>89</width> + <height>27</height> + </rect> + </property> + <property name="text"> + <string>Button 2</string> + </property> + </widget> + </widget> + <widget class="QMenuBar" name="menubar"> + <property name="geometry"> + <rect> + <x>0</x> + <y>0</y> + <width>293</width> + <height>25</height> + </rect> + </property> + </widget> + <widget class="QStatusBar" name="statusbar"/> + </widget> + <resources/> + <connections/> + </ui> + """ + + def testLoadUi(self): + """Create a QMainWindow from an ui file""" + with temp_dir() as tmp: + uifile = os.path.join(tmp, "test.ui") + + # write file + with open(uifile, mode='w') as f: + f.write(self.TEST_UI) + + class TestMainWindow(qt.QMainWindow): + def __init__(self, parent=None): + super(TestMainWindow, self).__init__(parent) + qt.loadUi(uifile, self) + + testMainWindow = TestMainWindow() + testMainWindow.show() + self.qWaitForWindowExposed(testMainWindow) + + testMainWindow.setAttribute(qt.Qt.WA_DeleteOnClose) + testMainWindow.close() + + +def suite(): + test_suite = unittest.TestSuite() + for TestCaseCls in (TestQtWrapper, TestLoadUi): + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestCaseCls)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/test/test_utils.py b/silx/gui/test/test_utils.py new file mode 100644 index 0000000..4625969 --- /dev/null +++ b/silx/gui/test/test_utils.py @@ -0,0 +1,77 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Test of utils module.""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "16/01/2017" + + +import unittest + +import numpy + +from silx.gui import qt +from silx.gui.test.utils import TestCaseQt + +from silx.gui import _utils + + +class TestQImageConversion(TestCaseQt): + """Tests conversion of QImage to/from numpy array.""" + + def testConvertArrayToQImage(self): + """Test conversion of numpy array to QImage""" + image = numpy.ones((3, 3, 3), dtype=numpy.uint8) + qimage = _utils.convertArrayToQImage(image) + + self.assertEqual(qimage.height(), image.shape[0]) + self.assertEqual(qimage.width(), image.shape[1]) + self.assertEqual(qimage.format(), qt.QImage.Format_RGB888) + + color = qt.QColor(1, 1, 1).rgb() + self.assertEqual(qimage.pixel(1, 1), color) + + def testConvertQImageToArray(self): + """Test conversion of QImage to numpy array""" + qimage = qt.QImage(3, 3, qt.QImage.Format_RGB888) + qimage.fill(0x010101) + image = _utils.convertQImageToArray(qimage) + + self.assertEqual(qimage.height(), image.shape[0]) + self.assertEqual(qimage.width(), image.shape[1]) + self.assertEqual(image.shape[2], 3) + self.assertTrue(numpy.all(numpy.equal(image, 1))) + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase( + TestQImageConversion)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/test/utils.py b/silx/gui/test/utils.py new file mode 100644 index 0000000..50cf7bf --- /dev/null +++ b/silx/gui/test/utils.py @@ -0,0 +1,428 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Helper class to write Qt widget unittests.""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "11/04/2017" + + +import gc +import logging +import unittest +import time +import functools +import sys + +logging.basicConfig() +_logger = logging.getLogger(__name__) + +from silx.gui import qt + +if qt.BINDING == 'PySide': + from PySide.QtTest import QTest +elif qt.BINDING == 'PyQt5': + from PyQt5.QtTest import QTest +elif qt.BINDING == 'PyQt4': + from PyQt4.QtTest import QTest +else: + raise ImportError('Unsupported Qt bindings') + +# Qt4/Qt5 compatibility wrapper +if qt.BINDING in ('PySide', 'PyQt4'): + _logger.info("QTest.qWaitForWindowExposed not available," + + "using QTest.qWaitForWindowShown instead.") + + def qWaitForWindowExposed(window, timeout=None): + """Mimic QTest.qWaitForWindowExposed for Qt4.""" + QTest.qWaitForWindowShown(window) + return True +else: + qWaitForWindowExposed = QTest.qWaitForWindowExposed + + +def qWaitForWindowExposedAndActivate(window, timeout=None): + """Waits until the window is shown in the screen. + + It also activates the window and raises it. + + See QTest.qWaitForWindowExposed for details. + """ + if timeout is None: + result = qWaitForWindowExposed(window) + else: + result = qWaitForWindowExposed(window, timeout) + + if result: + # Makes sure window is active and on top + window.activateWindow() + window.raise_() + + return result + + +# Placeholder for QApplication +_qapp = None + + +class TestCaseQt(unittest.TestCase): + """Base class to write test for Qt stuff. + + It creates a QApplication before running the tests. + WARNING: The QApplication is shared by all tests, which might have side + effects. + + After each test, this class is checking for widgets remaining alive. + To allow some widgets to remain alive at the end of a test, set the + allowedLeakingWidgets attribute to the number of widgets that can remain + alive at the end of the test. + With PySide, this test is not run for now as it seems PySide + is leaking widgets internally. + + All keyboard and mouse event simulation methods call qWait(20) after + simulating the event (as QTest does on Mac OSX). + This was introduced to fix issues with continuous integration tests + running with Xvfb on Linux. + """ + + DEFAULT_TIMEOUT_WAIT = 100 + """Default timeout for qWait""" + + TIMEOUT_WAIT = 0 + """Extra timeout in millisecond to add to qSleep, qWait and + qWaitForWindowExposed. + + Intended purpose is for debugging, to add extra time to waits in order to + allow to view the tested widgets. + """ + + @classmethod + def exceptionHandler(cls, exceptionClass, exception, stack): + import traceback + message = (''.join(traceback.format_tb(stack))) + template = 'Traceback (most recent call last):\n{2}{0}: {1}' + message = template.format(exceptionClass.__name__, exception, message) + cls._exceptions.append(message) + + @classmethod + def setUpClass(cls): + """Makes sure Qt is inited""" + cls._oldExceptionHook = sys.excepthook + sys.excepthook = cls.exceptionHandler + + global _qapp + if _qapp is None: + # Makes sure a QApplication exists and do it once for all + _qapp = qt.QApplication.instance() or qt.QApplication([]) + + # Create/delate a QWidget to make sure init of QDesktopWidget + _dummyWidget = qt.QWidget() + _dummyWidget.setAttribute(qt.Qt.WA_DeleteOnClose) + _dummyWidget.show() + _dummyWidget.close() + _qapp.processEvents() + + @classmethod + def tearDownClass(cls): + sys.excepthook = cls._oldExceptionHook + + def setUp(self): + """Get the list of existing widgets.""" + self.allowedLeakingWidgets = 0 + self.__previousWidgets = self.qapp.allWidgets() + self.__class__._exceptions = [] + + def _currentTestSucceeded(self): + if hasattr(self, '_outcome'): + # For Python >= 3.4 + result = self.defaultTestResult() # these 2 methods have no side effects + self._feedErrorsToResult(result, self._outcome.errors) + else: + # For Python < 3.4 + result = getattr(self, '_outcomeForDoCleanups', self._resultForDoCleanups) + + error = self.id() in [case.id() for case, _ in result.errors] + failure = self.id() in [case.id() for case, _ in result.failures] + return not error and not failure + + def _checkForUnreleasedWidgets(self): + """Test fixture checking that no more widgets exists.""" + gc.collect() + + widgets = [widget for widget in self.qapp.allWidgets() + if widget not in self.__previousWidgets] + del self.__previousWidgets + + if qt.BINDING == 'PySide': + return # Do not test for leaking widgets with PySide + + allowedLeakingWidgets = self.allowedLeakingWidgets + self.allowedLeakingWidgets = 0 + + if widgets and len(widgets) <= allowedLeakingWidgets: + _logger.info( + '%s: %d remaining widgets after test' % (self.id(), + len(widgets))) + + if len(widgets) > allowedLeakingWidgets: + raise RuntimeError( + "Test ended with widgets alive: %s" % str(widgets)) + + def tearDown(self): + if len(self.__class__._exceptions) > 0: + messages = "\n".join(self.__class__._exceptions) + raise AssertionError("Exception occured in Qt thread:\n" + messages) + + if self._currentTestSucceeded(): + self._checkForUnreleasedWidgets() + + @property + def qapp(self): + """The QApplication currently running.""" + return qt.QApplication.instance() + + # Proxy to QTest + + Press = QTest.Press + """Key press action code""" + + Release = QTest.Release + """Key release action code""" + + Click = QTest.Click + """Key click action code""" + + QTest = property(lambda self: QTest, + doc="""The Qt QTest class from the used Qt binding.""") + + def keyClick(self, widget, key, modifier=qt.Qt.NoModifier, delay=-1): + """Simulate clicking a key. + + See QTest.keyClick for details. + """ + QTest.keyClick(widget, key, modifier, delay) + self.qWait(20) + + def keyClicks(self, widget, sequence, modifier=qt.Qt.NoModifier, delay=-1): + """Simulate clicking a sequence of keys. + + See QTest.keyClick for details. + """ + QTest.keyClicks(widget, sequence, modifier, delay) + self.qWait(20) + + def keyEvent(self, action, widget, key, + modifier=qt.Qt.NoModifier, delay=-1): + """Sends a Qt key event. + + See QTest.keyEvent for details. + """ + QTest.keyEvent(action, widget, key, modifier, delay) + self.qWait(20) + + def keyPress(self, widget, key, modifier=qt.Qt.NoModifier, delay=-1): + """Sends a Qt key press event. + + See QTest.keyPress for details. + """ + QTest.keyPress(widget, key, modifier, delay) + self.qWait(20) + + def keyRelease(self, widget, key, modifier=qt.Qt.NoModifier, delay=-1): + """Sends a Qt key release event. + + See QTest.keyRelease for details. + """ + QTest.keyRelease(widget, key, modifier, delay) + self.qWait(20) + + def mouseClick(self, widget, button, modifier=None, pos=None, delay=-1): + """Simulate clicking a mouse button. + + See QTest.mouseClick for details. + """ + if modifier is None: + modifier = qt.Qt.KeyboardModifiers() + pos = qt.QPoint(pos[0], pos[1]) if pos is not None else qt.QPoint() + QTest.mouseClick(widget, button, modifier, pos, delay) + self.qWait(20) + + def mouseDClick(self, widget, button, modifier=None, pos=None, delay=-1): + """Simulate double clicking a mouse button. + + See QTest.mouseDClick for details. + """ + if modifier is None: + modifier = qt.Qt.KeyboardModifiers() + pos = qt.QPoint(pos[0], pos[1]) if pos is not None else qt.QPoint() + QTest.mouseDClick(widget, button, modifier, pos, delay) + self.qWait(20) + + def mouseMove(self, widget, pos=None, delay=-1): + """Simulate moving the mouse. + + See QTest.mouseMove for details. + """ + pos = qt.QPoint(pos[0], pos[1]) if pos is not None else qt.QPoint() + QTest.mouseMove(widget, pos, delay) + self.qWait(20) + + def mousePress(self, widget, button, modifier=None, pos=None, delay=-1): + """Simulate pressing a mouse button. + + See QTest.mousePress for details. + """ + if modifier is None: + modifier = qt.Qt.KeyboardModifiers() + pos = qt.QPoint(pos[0], pos[1]) if pos is not None else qt.QPoint() + QTest.mousePress(widget, button, modifier, pos, delay) + self.qWait(20) + + def mouseRelease(self, widget, button, modifier=None, pos=None, delay=-1): + """Simulate releasing a mouse button. + + See QTest.mouseRelease for details. + """ + if modifier is None: + modifier = qt.Qt.KeyboardModifiers() + pos = qt.QPoint(pos[0], pos[1]) if pos is not None else qt.QPoint() + QTest.mouseRelease(widget, button, modifier, pos, delay) + self.qWait(20) + + def qSleep(self, ms): + """Sleep for ms milliseconds, blocking the execution of the test. + + See QTest.qSleep for details. + """ + QTest.qSleep(ms + self.TIMEOUT_WAIT) + + def qWait(self, ms=None): + """Waits for ms milliseconds, events will be processed. + + See QTest.qWait for details. + """ + if ms is None: + ms = self.DEFAULT_TIMEOUT_WAIT + + if qt.BINDING == 'PySide': + # PySide has no qWait, provide a replacement + timeout = int(ms) + endTimeMS = int(time.time() * 1000) + timeout + while timeout > 0: + self.qapp.processEvents(qt.QEventLoop.AllEvents, + maxtime=timeout) + timeout = endTimeMS - int(time.time() * 1000) + else: + QTest.qWait(ms + self.TIMEOUT_WAIT) + + def qWaitForWindowExposed(self, window, timeout=None): + """Waits until the window is shown in the screen. + + See QTest.qWaitForWindowExposed for details. + """ + result = qWaitForWindowExposedAndActivate(window, timeout) + + if self.TIMEOUT_WAIT: + QTest.qWait(self.TIMEOUT_WAIT) + + return result + + +class SignalListener(): + """Util to listen a Qt event and store parameters + """ + + def __init__(self): + self.__calls = [] + + def __call__(self, *args, **kargs): + self.__calls.append((args, kargs)) + + def clear(self): + """Clear stored data""" + self.__calls = [] + + def callCount(self): + """ + Returns how many times the listener was called. + + :rtype: int + """ + return len(self.__calls) + + def arguments(self, callIndex=None, argumentIndex=None): + """Returns positional arguments optionally filtered by call count id + or argument index. + + :param int callIndex: Index of the called data + :param int argumentIndex: Index of the positional argument. + """ + if callIndex is not None: + result = self.__calls[callIndex][0] + if argumentIndex is not None: + result = result[argumentIndex] + else: + result = [x[0] for x in self.__calls] + if argumentIndex is not None: + result = [x[argumentIndex] for x in result] + return result + + def karguments(self, callIndex=None, argumentName=None): + """Returns positional arguments optionally filtered by call count id + or name of the keyword argument. + + :param int callIndex: Index of the called data + :param int argumentName: Name of the keyword argument. + """ + if callIndex is not None: + result = self.__calls[callIndex][1] + if argumentName is not None: + result = result[argumentName] + else: + result = [x[1] for x in self.__calls] + if argumentName is not None: + result = [x[argumentName] for x in result] + return result + + def partial(self, *args, **kargs): + """Returns a new partial object which when called will behave like this + listener called with the positional arguments args and keyword + arguments keywords. If more arguments are supplied to the call, they + are appended to args. If additional keyword arguments are supplied, + they extend and override keywords. + """ + return functools.partial(self, *args, **kargs) + + +def getQToolButtonFromAction(action): + """Return a QToolButton corresponding to a QAction. + + :param QAction action: The QAction from which to get QToolButton. + :return: A QToolButton associated to action or None. + """ + for widget in action.associatedWidgets(): + if isinstance(widget, qt.QToolButton): + return widget + return None diff --git a/silx/gui/widgets/FrameBrowser.py b/silx/gui/widgets/FrameBrowser.py new file mode 100644 index 0000000..783a70a --- /dev/null +++ b/silx/gui/widgets/FrameBrowser.py @@ -0,0 +1,307 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module defines two main classes: + + - :class:`FrameBrowser`: a widget with 4 buttons (first, previous, next, + last) to browse between frames and a text entry to access a specific frame + by typing it's number) + - :class:`HorizontalSliderWithBrowser`: a FrameBrowser with an additional + slider. This class inherits :class:`qt.QAbstractSlider`. + +""" +from silx.gui import qt +from silx.gui import icons + +__authors__ = ["V.A. Sole", "P. Knobel"] +__license__ = "MIT" +__date__ = "16/01/2017" + + +class FrameBrowser(qt.QWidget): + """Frame browser widget, with 4 buttons/icons and a line edit to provide + a way of selecting a frame index in a stack of images. + + It can be used in more generic case to select an integer within a range. + + :param QWidget parent: Parent widget + :param int n: Number of frames. This will set the range + of frame indices to 0--n-1. + If None, the range is initialized to the default QSlider range (0--99).""" + sigIndexChanged = qt.pyqtSignal(object) + + def __init__(self, parent=None, n=None): + qt.QWidget.__init__(self, parent) + + # Use the font size as the icon size to avoid to create bigger buttons + fontMetric = self.fontMetrics() + iconSize = qt.QSize(fontMetric.height(), fontMetric.height()) + + self.mainLayout = qt.QHBoxLayout(self) + self.mainLayout.setContentsMargins(0, 0, 0, 0) + self.mainLayout.setSpacing(0) + self.firstButton = qt.QPushButton(self) + self.firstButton.setIcon(icons.getQIcon("first")) + self.firstButton.setIconSize(iconSize) + self.previousButton = qt.QPushButton(self) + self.previousButton.setIcon(icons.getQIcon("previous")) + self.previousButton.setIconSize(iconSize) + self._lineEdit = qt.QLineEdit(self) + + self._label = qt.QLabel(self) + self.nextButton = qt.QPushButton(self) + self.nextButton.setIcon(icons.getQIcon("next")) + self.nextButton.setIconSize(iconSize) + self.lastButton = qt.QPushButton(self) + self.lastButton.setIcon(icons.getQIcon("last")) + self.lastButton.setIconSize(iconSize) + + self.mainLayout.addWidget(self.firstButton) + self.mainLayout.addWidget(self.previousButton) + self.mainLayout.addWidget(self._lineEdit) + self.mainLayout.addWidget(self._label) + self.mainLayout.addWidget(self.nextButton) + self.mainLayout.addWidget(self.lastButton) + + if n is None: + first = qt.QSlider().minimum() + last = qt.QSlider().maximum() + else: + first, last = 0, n + + self._lineEdit.setFixedWidth(self._lineEdit.fontMetrics().width('%05d' % last)) + validator = qt.QIntValidator(first, last, self._lineEdit) + self._lineEdit.setValidator(validator) + self._lineEdit.setText("%d" % first) + self._label.setText("of %d" % last) + + self._index = first + """0-based index""" + + self.firstButton.clicked.connect(self._firstClicked) + self.previousButton.clicked.connect(self._previousClicked) + self.nextButton.clicked.connect(self._nextClicked) + self.lastButton.clicked.connect(self._lastClicked) + self._lineEdit.editingFinished.connect(self._textChangedSlot) + + def lineEdit(self): + """Returns the line edit provided by this widget. + + :rtype: qt.QLineEdit + """ + return self._lineEdit + + def limitWidget(self): + """Returns the widget displaying axes limits. + + :rtype: qt.QLabel + """ + return self._label + + def _firstClicked(self): + """Select first/lowest frame number""" + self._lineEdit.setText("%d" % self._lineEdit.validator().bottom()) + self._textChangedSlot() + + def _previousClicked(self): + """Select previous frame number""" + if self._index > self._lineEdit.validator().bottom(): + self._lineEdit.setText("%d" % (self._index - 1)) + self._textChangedSlot() + + def _nextClicked(self): + """Select next frame number""" + if self._index < (self._lineEdit.validator().top()): + self._lineEdit.setText("%d" % (self._index + 1)) + self._textChangedSlot() + + def _lastClicked(self): + """Select last/highest frame number""" + self._lineEdit.setText("%d" % self._lineEdit.validator().top()) + self._textChangedSlot() + + def _textChangedSlot(self): + """Select frame number typed in the line edit widget""" + txt = self._lineEdit.text() + if not len(txt): + self._lineEdit.setText("%d" % self._index) + return + new_value = int(txt) + if new_value == self._index: + return + ddict = { + "event": "indexChanged", + "old": self._index, + "new": new_value, + "id": id(self) + } + self._index = new_value + self.sigIndexChanged.emit(ddict) + + def setRange(self, first, last): + """Set minimum and maximum frame indices + Initialize the frame index to *first*. + Update the label text to *" limits: first, last"* + + :param int first: Minimum frame index + :param int last: Maximum frame index""" + return self.setLimits(first, last) + + def setLimits(self, first, last): + """Set minimum and maximum frame indices. + Initialize the frame index to *first*. + Update the label text to *" limits: first, last"* + + :param int first: Minimum frame index + :param int last: Maximum frame index""" + bottom = min(first, last) + top = max(first, last) + self._lineEdit.validator().setTop(top) + self._lineEdit.validator().setBottom(bottom) + self._index = bottom + self._lineEdit.setText("%d" % self._index) + self._label.setText(" limits: %d, %d " % (bottom, top)) + + def setNFrames(self, nframes): + """Set minimum=0 and maximum=nframes-1 frame numbers. + Initialize the frame index to 0. + Update the label text to *"1 of nframes"* + + :param int nframes: Number of frames""" + bottom = 0 + top = nframes - 1 + self._lineEdit.validator().setTop(top) + self._lineEdit.validator().setBottom(bottom) + self._index = bottom + self._lineEdit.setText("%d" % self._index) + # display 1-based index in label + self._label.setText(" %d of %d " % (self._index + 1, top + 1)) + + def getCurrentIndex(self): + """Get 0-based frame index + """ + return self._index + + def setValue(self, value): + """Set 0-based frame index + + :param int value: Frame number""" + self._lineEdit.setText("%d" % value) + self._textChangedSlot() + + +class HorizontalSliderWithBrowser(qt.QAbstractSlider): + """ + Slider widget combining a :class:`QSlider` and a :class:`FrameBrowser`. + + The data model is an integer within a range. + + The default value is the default :class:`QSlider` value (0), + and the default range is the default QSlider range (0 -- 99) + + The signal emitted when the value is changed is the usual QAbstractSlider + signal :attr:`valueChanged`. The signal carries the value (as an integer). + + :param QWidget parent: Optional parent widget + """ + sigIndexChanged = qt.pyqtSignal(object) + + def __init__(self, parent=None): + qt.QAbstractSlider.__init__(self, parent) + self.setOrientation(qt.Qt.Horizontal) + + self.mainLayout = qt.QHBoxLayout(self) + self.mainLayout.setContentsMargins(0, 0, 0, 0) + self.mainLayout.setSpacing(2) + + self._slider = qt.QSlider(self) + self._slider.setOrientation(qt.Qt.Horizontal) + + self._browser = FrameBrowser(self) + + self.mainLayout.addWidget(self._slider, 1) + self.mainLayout.addWidget(self._browser) + + self._slider.valueChanged[int].connect(self._sliderSlot) + self._browser.sigIndexChanged.connect(self._browserSlot) + + def lineEdit(self): + """Returns the line edit provided by this widget. + + :rtype: qt.QLineEdit + """ + return self._browser.lineEdit() + + def limitWidget(self): + """Returns the widget displaying axes limits. + + :rtype: qt.QLabel + """ + return self._browser.limitWidget() + + def setMinimum(self, value): + """Set minimum value + + :param int value: Minimum value""" + self._slider.setMinimum(value) + maximum = self._slider.maximum() + self._browser.setRange(value, maximum) + + def setMaximum(self, value): + """Set maximum value + + :param int value: Maximum value + """ + self._slider.setMaximum(value) + minimum = self._slider.minimum() + self._browser.setRange(minimum, value) + + def setRange(self, first, last): + """Set minimum/maximum values + + :param int first: Minimum value + :param int last: Maximum value""" + self._slider.setRange(first, last) + self._browser.setRange(first, last) + + def _sliderSlot(self, value): + """Emit selected value when slider is activated + """ + self._browser.setValue(value) + self.valueChanged.emit(value) + + def _browserSlot(self, ddict): + """Emit selected value when browser state is changed""" + self._slider.setValue(ddict['new']) + + def setValue(self, value): + """Set value + + :param int value: value""" + self._slider.setValue(value) + self._browser.setValue(value) + + def value(self): + """Get selected value""" + return self._slider.value() diff --git a/silx/gui/widgets/HierarchicalTableView.py b/silx/gui/widgets/HierarchicalTableView.py new file mode 100644 index 0000000..3ccf4c7 --- /dev/null +++ b/silx/gui/widgets/HierarchicalTableView.py @@ -0,0 +1,172 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +""" +This module define a hierarchical table view and model. + +It allows to define many headers in the middle of a table. + +The implementation hide the default header and allows to custom each cells +to became a header. + +Row and column span is a concept of the view in a QTableView. +This implementation also provide a span property as part of the model of the +cell. A role is define to custom this information. +The view is updated everytime the model is reset to take care of the +changes of this information. + +A default item delegate is used to redefine the paint of the cells. +""" +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "07/04/2017" + +from silx.gui import qt + + +class HierarchicalTableModel(qt.QAbstractTableModel): + """ + Abstract table model to provide more custom on row and column span and + headers. + + Default headers are ignored and each cells can define IsHeaderRole and + SpanRole using the `data` function. + """ + + SpanRole = qt.Qt.UserRole + 0 + """Role returning a tuple for number of row span then column span. + + None and (1, 1) are neutral for the rendering. + """ + + IsHeaderRole = qt.Qt.UserRole + 1 + """Role returning True is the identified cell is a header.""" + + UserRole = qt.Qt.UserRole + 2 + """First index of user defined roles""" + + def headerData(self, section, orientation, role=qt.Qt.DisplayRole): + """Returns the 0-based row or column index, for display in the + horizontal and vertical headers + + In this case the headers are just ignored. Header information is part + of each cells. + """ + return None + + +class HierarchicalItemDelegate(qt.QStyledItemDelegate): + """ + Delegate item to take care of the rendering of the default table cells and + also the header cells. + """ + + def __init__(self, parent=None): + """ + Constructor + + :param qt.QObject parent: Parent of the widget + """ + qt.QStyledItemDelegate.__init__(self, parent) + + def paint(self, painter, option, index): + """Override the paint function to inject the style of the header. + + :param qt.QPainter painter: Painter context used to displayed the cell + :param qt.QStyleOptionViewItem option: Control how the editor is shown + :param qt.QIndex index: Index of the data to display + """ + isHeader = index.data(role=HierarchicalTableModel.IsHeaderRole) + if isHeader: + span = index.data(role=HierarchicalTableModel.SpanRole) + span = 1 if span is None else span[1] + columnCount = index.model().columnCount() + if span == columnCount: + mainTitle = True + position = qt.QStyleOptionHeader.OnlyOneSection + else: + mainTitle = False + col = index.column() + if col == 0: + position = qt.QStyleOptionHeader.Beginning + elif col < columnCount - 1: + position = qt.QStyleOptionHeader.Middle + else: + position = qt.QStyleOptionHeader.End + opt = qt.QStyleOptionHeader() + opt.direction = option.direction + opt.text = index.data() + opt.textAlignment = qt.Qt.AlignCenter if mainTitle else qt.Qt.AlignVCenter + opt.direction = option.direction + opt.fontMetrics = option.fontMetrics + opt.palette = option.palette + opt.rect = option.rect + opt.state = option.state + opt.position = position + margin = -1 + style = qt.QApplication.instance().style() + opt.rect = opt.rect.adjusted(margin, margin, -margin, -margin) + style.drawControl(qt.QStyle.CE_HeaderSection, opt, painter, None) + margin = 3 + opt.rect = opt.rect.adjusted(margin, margin, -margin, -margin) + style.drawControl(qt.QStyle.CE_HeaderLabel, opt, painter, None) + else: + qt.QStyledItemDelegate.paint(self, painter, option, index) + + +class HierarchicalTableView(qt.QTableView): + """A TableView which allow to display a `HierarchicalTableModel`.""" + + def __init__(self, parent=None): + """ + Constructor + + :param qt.QWidget parent: Parent of the widget + """ + super(HierarchicalTableView, self).__init__(parent) + self.setItemDelegate(HierarchicalItemDelegate(self)) + self.verticalHeader().setVisible(False) + self.horizontalHeader().setVisible(False) + + def setModel(self, model): + """Override the default function to connect the model to update + function""" + if self.model() is not None: + model.modelReset.disconnect(self.__modelReset) + super(HierarchicalTableView, self).setModel(model) + if self.model() is not None: + model.modelReset.connect(self.__modelReset) + self.__modelReset() + + def __modelReset(self): + """Update the model to take care of the changes of the span + information""" + self.clearSpans() + model = self.model() + for row in range(model.rowCount()): + for column in range(model.columnCount()): + index = model.index(row, column, qt.QModelIndex()) + span = model.data(index, HierarchicalTableModel.SpanRole) + if span is not None and span != (1, 1): + self.setSpan(row, column, span[0], span[1]) diff --git a/silx/gui/widgets/MedianFilterDialog.py b/silx/gui/widgets/MedianFilterDialog.py new file mode 100644 index 0000000..3eddff3 --- /dev/null +++ b/silx/gui/widgets/MedianFilterDialog.py @@ -0,0 +1,74 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +""" MedianFilterDialog +Classes +------- + +Widgets: + + - :class:`MedianFilterDialog` +""" + +__authors__ = ["H. Payno"] +__license__ = "MIT" +__date__ = "14/02/2017" + +from silx.gui import qt + +class MedianFilterDialog(qt.QDialog): + """QDialog window featuring a :class:`BackgroundWidget`""" + sigFilterOptChanged = qt.Signal(int, bool) + + def __init__(self, parent=None): + qt.QDialog.__init__(self, parent) + + self.setWindowTitle("Median filter options") + self.mainLayout = qt.QHBoxLayout(self) + self.setLayout(self.mainLayout) + + # filter width GUI + self.mainLayout.addWidget(qt.QLabel('filter width:', parent = self)) + self._filterWidth = qt.QSpinBox(parent=self) + self._filterWidth.setMinimum(1) + self._filterWidth.setValue(1) + self._filterWidth.setSingleStep(2); + widthTooltip = """radius width of the pixel including in the filter + for each pixel""" + self._filterWidth.setToolTip(widthTooltip) + self._filterWidth.valueChanged.connect(self._filterOptionChanged) + self.mainLayout.addWidget(self._filterWidth) + + # filter option GUI + self._filterOption = qt.QCheckBox('conditional', parent=self) + conditionalTooltip = """if check, implement a conditional filter""" + self._filterOption.stateChanged.connect(self._filterOptionChanged) + self.mainLayout.addWidget(self._filterOption) + + def _filterOptionChanged(self): + """Call back used when the filter values are changed""" + if self._filterWidth.value()%2 == 0: + logging.warning('median filter only accept odd values') + else: + self.sigFilterOptChanged.emit(self._filterWidth.value(), self._filterOption.isChecked())
\ No newline at end of file diff --git a/silx/gui/widgets/PeriodicTable.py b/silx/gui/widgets/PeriodicTable.py new file mode 100644 index 0000000..2f1ca78 --- /dev/null +++ b/silx/gui/widgets/PeriodicTable.py @@ -0,0 +1,825 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Periodic table widgets + +Classes +------- + +Widgets: + + - :class:`PeriodicTable` + - :class:`PeriodicList` + - :class:`PeriodicCombo` + +Data model: + + - :class:`PeriodicTableItem` + - :class:`ColoredPeriodicTableItem` + + +Example of usage +---------------- + +This example uses the widgets with the standard builtin elements list. + +.. code-block:: python + + from silx.gui import qt + from silx.gui.widgets.PeriodicTable import PeriodicTable, \ + PeriodicCombo, PeriodicList + + a = qt.QApplication([]) + + w = qt.QTabWidget() + + ptable = PeriodicTable(w, selectable=True) + pcombo = PeriodicCombo(w) + plist = PeriodicList(w) + + w.addTab(ptable, "PeriodicTable") + w.addTab(plist, "PeriodicList") + w.addTab(pcombo, "PeriodicCombo") + + ptable.setSelection(['H', 'Fe', 'Si']) + plist.setSelectedElements(['H', 'Be', 'F']) + pcombo.setSelection("Li") + + def change_list(items): + print("New list selection:", [item.symbol for item in items]) + + def change_combo(item): + print("New combo selection:", item.symbol) + + def click_table(item): + print("New table click:", item.symbol) + + def change_table(items): + print("New table selection:", [item.symbol for item in items]) + + ptable.sigElementClicked.connect(click_table) + ptable.sigSelectionChanged.connect(change_table) + plist.sigSelectionChanged.connect(change_list) + pcombo.sigSelectionChanged.connect(change_combo) + + w.show() + a.exec_() + + +The second example explains how to define custom elements. + +.. code-block:: python + + from silx.gui import qt + from silx.gui.widgets.PeriodicTable import PeriodicTable, \ + PeriodicCombo, PeriodicList + from silx.gui.widgets.PeriodicTable import PeriodicTableItem + + # subclass PeriodicTableItem + class MyPeriodicTableItem(PeriodicTableItem): + "New item with added mass number and number of protons" + def __init__(self, symbol, Z, A, col, row, name, mass, + subcategory=""): + PeriodicTableItem.__init__( + self, symbol, Z, col, row, name, mass, + subcategory) + + self.A = A + "Mass number (neutrons + protons)" + + self.num_neutrons = A - Z + "Number of neutrons" + + # build your list of elements + my_elements = [MyPeriodicTableItem("H", 1, 1, 1, 1, "hydrogen", + 1.00800, "diatomic nonmetal"), + MyPeriodicTableItem("He", 2, 4, 18, 1, "helium", + 4.0030, "noble gas"), + # etc ... + ] + + app = qt.QApplication([]) + + ptable = PeriodicTable(elements=my_elements, selectable=True) + ptable.show() + + def click_table(item): + "Callback function printing the mass number of clicked element" + print("New table click, mass number:", item.A) + + ptable.sigElementClicked.connect(click_table) + app.exec_() + +""" + +__authors__ = ["E. Papillon", "V.A. Sole", "P. Knobel"] +__license__ = "MIT" +__date__ = "26/01/2017" + +from collections import OrderedDict +import logging +from silx.gui import qt + +_logger = logging.getLogger(__name__) + +# Symbol Atomic Number col row name mass subcategory +_elements = [("H", 1, 1, 1, "hydrogen", 1.00800, "diatomic nonmetal"), + ("He", 2, 18, 1, "helium", 4.0030, "noble gas"), + ("Li", 3, 1, 2, "lithium", 6.94000, "alkali metal"), + ("Be", 4, 2, 2, "beryllium", 9.01200, "alkaline earth metal"), + ("B", 5, 13, 2, "boron", 10.8110, "metalloid"), + ("C", 6, 14, 2, "carbon", 12.0100, "polyatomic nonmetal"), + ("N", 7, 15, 2, "nitrogen", 14.0080, "diatomic nonmetal"), + ("O", 8, 16, 2, "oxygen", 16.0000, "diatomic nonmetal"), + ("F", 9, 17, 2, "fluorine", 19.0000, "diatomic nonmetal"), + ("Ne", 10, 18, 2, "neon", 20.1830, "noble gas"), + ("Na", 11, 1, 3, "sodium", 22.9970, "alkali metal"), + ("Mg", 12, 2, 3, "magnesium", 24.3200, "alkaline earth metal"), + ("Al", 13, 13, 3, "aluminium", 26.9700, "post transition metal"), + ("Si", 14, 14, 3, "silicon", 28.0860, "metalloid"), + ("P", 15, 15, 3, "phosphorus", 30.9750, "polyatomic nonmetal"), + ("S", 16, 16, 3, "sulphur", 32.0660, "polyatomic nonmetal"), + ("Cl", 17, 17, 3, "chlorine", 35.4570, "diatomic nonmetal"), + ("Ar", 18, 18, 3, "argon", 39.9440, "noble gas"), + ("K", 19, 1, 4, "potassium", 39.1020, "alkali metal"), + ("Ca", 20, 2, 4, "calcium", 40.0800, "alkaline earth metal"), + ("Sc", 21, 3, 4, "scandium", 44.9600, "transition metal"), + ("Ti", 22, 4, 4, "titanium", 47.9000, "transition metal"), + ("V", 23, 5, 4, "vanadium", 50.9420, "transition metal"), + ("Cr", 24, 6, 4, "chromium", 51.9960, "transition metal"), + ("Mn", 25, 7, 4, "manganese", 54.9400, "transition metal"), + ("Fe", 26, 8, 4, "iron", 55.8500, "transition metal"), + ("Co", 27, 9, 4, "cobalt", 58.9330, "transition metal"), + ("Ni", 28, 10, 4, "nickel", 58.6900, "transition metal"), + ("Cu", 29, 11, 4, "copper", 63.5400, "transition metal"), + ("Zn", 30, 12, 4, "zinc", 65.3800, "transition metal"), + ("Ga", 31, 13, 4, "gallium", 69.7200, "post transition metal"), + ("Ge", 32, 14, 4, "germanium", 72.5900, "metalloid"), + ("As", 33, 15, 4, "arsenic", 74.9200, "metalloid"), + ("Se", 34, 16, 4, "selenium", 78.9600, "polyatomic nonmetal"), + ("Br", 35, 17, 4, "bromine", 79.9200, "diatomic nonmetal"), + ("Kr", 36, 18, 4, "krypton", 83.8000, "noble gas"), + ("Rb", 37, 1, 5, "rubidium", 85.4800, "alkali metal"), + ("Sr", 38, 2, 5, "strontium", 87.6200, "alkaline earth metal"), + ("Y", 39, 3, 5, "yttrium", 88.9050, "transition metal"), + ("Zr", 40, 4, 5, "zirconium", 91.2200, "transition metal"), + ("Nb", 41, 5, 5, "niobium", 92.9060, "transition metal"), + ("Mo", 42, 6, 5, "molybdenum", 95.9500, "transition metal"), + ("Tc", 43, 7, 5, "technetium", 99.0000, "transition metal"), + ("Ru", 44, 8, 5, "ruthenium", 101.0700, "transition metal"), + ("Rh", 45, 9, 5, "rhodium", 102.9100, "transition metal"), + ("Pd", 46, 10, 5, "palladium", 106.400, "transition metal"), + ("Ag", 47, 11, 5, "silver", 107.880, "transition metal"), + ("Cd", 48, 12, 5, "cadmium", 112.410, "transition metal"), + ("In", 49, 13, 5, "indium", 114.820, "post transition metal"), + ("Sn", 50, 14, 5, "tin", 118.690, "post transition metal"), + ("Sb", 51, 15, 5, "antimony", 121.760, "metalloid"), + ("Te", 52, 16, 5, "tellurium", 127.600, "metalloid"), + ("I", 53, 17, 5, "iodine", 126.910, "diatomic nonmetal"), + ("Xe", 54, 18, 5, "xenon", 131.300, "noble gas"), + ("Cs", 55, 1, 6, "caesium", 132.910, "alkali metal"), + ("Ba", 56, 2, 6, "barium", 137.360, "alkaline earth metal"), + ("La", 57, 3, 6, "lanthanum", 138.920, "lanthanide"), + ("Ce", 58, 4, 9, "cerium", 140.130, "lanthanide"), + ("Pr", 59, 5, 9, "praseodymium", 140.920, "lanthanide"), + ("Nd", 60, 6, 9, "neodymium", 144.270, "lanthanide"), + ("Pm", 61, 7, 9, "promethium", 147.000, "lanthanide"), + ("Sm", 62, 8, 9, "samarium", 150.350, "lanthanide"), + ("Eu", 63, 9, 9, "europium", 152.000, "lanthanide"), + ("Gd", 64, 10, 9, "gadolinium", 157.260, "lanthanide"), + ("Tb", 65, 11, 9, "terbium", 158.930, "lanthanide"), + ("Dy", 66, 12, 9, "dysprosium", 162.510, "lanthanide"), + ("Ho", 67, 13, 9, "holmium", 164.940, "lanthanide"), + ("Er", 68, 14, 9, "erbium", 167.270, "lanthanide"), + ("Tm", 69, 15, 9, "thulium", 168.940, "lanthanide"), + ("Yb", 70, 16, 9, "ytterbium", 173.040, "lanthanide"), + ("Lu", 71, 17, 9, "lutetium", 174.990, "lanthanide"), + ("Hf", 72, 4, 6, "hafnium", 178.500, "transition metal"), + ("Ta", 73, 5, 6, "tantalum", 180.950, "transition metal"), + ("W", 74, 6, 6, "tungsten", 183.920, "transition metal"), + ("Re", 75, 7, 6, "rhenium", 186.200, "transition metal"), + ("Os", 76, 8, 6, "osmium", 190.200, "transition metal"), + ("Ir", 77, 9, 6, "iridium", 192.200, "transition metal"), + ("Pt", 78, 10, 6, "platinum", 195.090, "transition metal"), + ("Au", 79, 11, 6, "gold", 197.200, "transition metal"), + ("Hg", 80, 12, 6, "mercury", 200.610, "transition metal"), + ("Tl", 81, 13, 6, "thallium", 204.390, "post transition metal"), + ("Pb", 82, 14, 6, "lead", 207.210, "post transition metal"), + ("Bi", 83, 15, 6, "bismuth", 209.000, "post transition metal"), + ("Po", 84, 16, 6, "polonium", 209.000, "post transition metal"), + ("At", 85, 17, 6, "astatine", 210.000, "metalloid"), + ("Rn", 86, 18, 6, "radon", 222.000, "noble gas"), + ("Fr", 87, 1, 7, "francium", 223.000, "alkali metal"), + ("Ra", 88, 2, 7, "radium", 226.000, "alkaline earth metal"), + ("Ac", 89, 3, 7, "actinium", 227.000, "actinide"), + ("Th", 90, 4, 10, "thorium", 232.000, "actinide"), + ("Pa", 91, 5, 10, "proactinium", 231.03588, "actinide"), + ("U", 92, 6, 10, "uranium", 238.070, "actinide"), + ("Np", 93, 7, 10, "neptunium", 237.000, "actinide"), + ("Pu", 94, 8, 10, "plutonium", 239.100, "actinide"), + ("Am", 95, 9, 10, "americium", 243, "actinide"), + ("Cm", 96, 10, 10, "curium", 247, "actinide"), + ("Bk", 97, 11, 10, "berkelium", 247, "actinide"), + ("Cf", 98, 12, 10, "californium", 251, "actinide"), + ("Es", 99, 13, 10, "einsteinium", 252, "actinide"), + ("Fm", 100, 14, 10, "fermium", 257, "actinide"), + ("Md", 101, 15, 10, "mendelevium", 258, "actinide"), + ("No", 102, 16, 10, "nobelium", 259, "actinide"), + ("Lr", 103, 17, 10, "lawrencium", 262, "actinide"), + ("Rf", 104, 4, 7, "rutherfordium", 261, "transition metal"), + ("Db", 105, 5, 7, "dubnium", 262, "transition metal"), + ("Sg", 106, 6, 7, "seaborgium", 266, "transition metal"), + ("Bh", 107, 7, 7, "bohrium", 264, "transition metal"), + ("Hs", 108, 8, 7, "hassium", 269, "transition metal"), + ("Mt", 109, 9, 7, "meitnerium", 268)] + + +class PeriodicTableItem(object): + """Periodic table item, used as generic item in :class:`PeriodicTable`, + :class:`PeriodicCombo` and :class:`PeriodicList`. + + This implementation stores the minimal amount of information needed by the + widgets: + + - atomic symbol + - atomic number + - element name + - atomic mass + - column of element in periodic table + - row of element in periodic table + + You can subclass this class to add additional information. + + :param str symbol: Atomic symbol (e.g. H, He, Li...) + :param int Z: Proton number + :param int col: 1-based column index of element in periodic table + :param int row: 1-based row index of element in periodic table + :param str name: PeriodicTableItem name ("hydrogen", ...) + :param float mass: Atomic mass (gram per mol) + :param str subcategory: Subcategory, based on physical properties + (e.g. "alkali metal", "noble gas"...) + """ + def __init__(self, symbol, Z, col, row, name, mass, + subcategory=""): + self.symbol = symbol + """Atomic symbol (e.g. H, He, Li...)""" + self.Z = Z + """Atomic number (Proton number)""" + self.col = col + """1-based column index of element in periodic table""" + self.row = row + """1-based row index of element in periodic table""" + self.name = name + """PeriodicTableItem name ("hydrogen", ...)""" + self.mass = mass + """Atomic mass (gram per mol)""" + self.subcategory = subcategory + """Subcategory, based on physical properties + (e.g. "alkali metal", "noble gas"...)""" + + # pymca compatibility (elements used to be stored as a list of lists) + def __getitem__(self, idx): + if idx == 6: + _logger.warning("density not implemented in silx, returning 0.") + + ret = [self.symbol, self.Z, + self.col, self.row, + self.name, self.mass, + 0.] + return ret[idx] + + def __len__(self): + return 6 + + +class ColoredPeriodicTableItem(PeriodicTableItem): + """:class:`PeriodicTableItem` with an added :attr:`bgcolor`. + The background color can be passed as a parameter to the constructor. + If it is not specified, it will be defined based on + :attr:`subcategory`. + + :param str bgcolor: Custom background color for element in + periodic table, as a RGB string *#RRGGBB*""" + COLORS = { + "diatomic nonmetal": "#7FFF00", # chartreuse + "noble gas": "#00FFFF", # cyan + "alkali metal": "#FFE4B5", # Moccasin + "alkaline earth metal": "#FFA500", # orange + "polyatomic nonmetal": "#7FFFD4", # aquamarine + "transition metal": "#FFA07A", # light salmon + "metalloid": "#8FBC8F", # Dark Sea Green + "post transition metal": "#D3D3D3", # light gray + "lanthanide": "#FFB6C1", # light pink + "actinide": "#F08080", # Light Coral + "": "#FFFFFF" # white + } + """Dictionary defining RGB colors for each subcategory.""" + + def __init__(self, symbol, Z, col, row, name, mass, + subcategory="", bgcolor=None): + PeriodicTableItem.__init__(self, symbol, Z, col, row, name, mass, + subcategory) + + self.bgcolor = self.COLORS.get(subcategory, "#FFFFFF") + """Background color of element in the periodic table, + based on its subcategory. This should be a string of a hexadecimal + RGB code, with the format *#RRGGBB*. + If the subcategory is unknown, use white (*#FFFFFF*) + """ + + # possible custom color + if bgcolor is not None: + self.bgcolor = bgcolor + + +_defaultTableItems = [ColoredPeriodicTableItem(*info) for info in _elements] + + +class _ElementButton(qt.QPushButton): + """Atomic element button, used as a cell in the periodic table + """ + sigElementEnter = qt.pyqtSignal(object) + """Signal emitted as the cursor enters the widget""" + sigElementLeave = qt.pyqtSignal(object) + """Signal emitted as the cursor leaves the widget""" + sigElementClicked = qt.pyqtSignal(object) + """Signal emitted when the widget is clicked""" + + def __init__(self, item, parent=None): + """ + + :param parent: Parent widget + :param PeriodicTableItem item: :class:`PeriodicTableItem` object + """ + qt.QPushButton.__init__(self, parent) + + self.item = item + """:class:`PeriodicTableItem` object represented by this button""" + + self.setText(item.symbol) + self.setFlat(1) + self.setCheckable(0) + + self.setSizePolicy(qt.QSizePolicy(qt.QSizePolicy.Expanding, + qt.QSizePolicy.Expanding)) + + self.selected = False + self.current = False + + # selection colors + self.selected_color = qt.QColor(qt.Qt.yellow) + self.current_color = qt.QColor(qt.Qt.gray) + self.selected_current_color = qt.QColor(qt.Qt.darkYellow) + + # element colors + + if hasattr(item, "bgcolor"): + self.bgcolor = qt.QColor(item.bgcolor) + else: + self.bgcolor = qt.QColor("#FFFFFF") + + self.brush = qt.QBrush() + self.__setBrush() + + self.clicked.connect(self.clickedSlot) + + def sizeHint(self): + return qt.QSize(40, 40) + + def setCurrent(self, b): + """Set this element button as current. + Multiple buttons can be selected. + + :param b: boolean + """ + self.current = b + self.__setBrush() + + def isCurrent(self): + """ + :return: True if element button is current + """ + return self.current + + def isSelected(self): + """ + :return: True if element button is selected + """ + return self.selected + + def setSelected(self, b): + """Set this element button as selected. + Only a single button can be selected. + + :param b: boolean + """ + self.selected = b + self.__setBrush() + + def __setBrush(self): + """Selected cells are yellow when not current. + The current cell is dark yellow when selected or grey when not + selected. + Other cells have no bg color by default, unless specified at + instantiation (:attr:`bgcolor`)""" + palette = self.palette() + # if self.current and self.selected: + # self.brush = qt.QBrush(self.selected_current_color) + # el + if self.selected: + self.brush = qt.QBrush(self.selected_color) + # elif self.current: + # self.brush = qt.QBrush(self.current_color) + elif self.bgcolor is not None: + self.brush = qt.QBrush(self.bgcolor) + else: + self.brush = qt.QBrush() + palette.setBrush(self.backgroundRole(), + self.brush) + self.setPalette(palette) + self.update() + + def paintEvent(self, pEvent): + # get button geometry + widgGeom = self.rect() + paintGeom = qt.QRect(widgGeom.left() + 1, + widgGeom.top() + 1, + widgGeom.width() - 2, + widgGeom.height() - 2) + + # paint background color + painter = qt.QPainter(self) + if self.brush is not None: + painter.fillRect(paintGeom, self.brush) + # paint frame + pen = qt.QPen(qt.Qt.black) + pen.setWidth(1 if not self.isCurrent() else 5) + painter.setPen(pen) + painter.drawRect(paintGeom) + painter.end() + qt.QPushButton.paintEvent(self, pEvent) + + def enterEvent(self, e): + """Emit a :attr:`sigElementEnter` signal and send a + :class:`PeriodicTableItem` object""" + self.sigElementEnter.emit(self.item) + + def leaveEvent(self, e): + """Emit a :attr:`sigElementLeave` signal and send a + :class:`PeriodicTableItem` object""" + self.sigElementLeave.emit(self.item) + + def clickedSlot(self): + """Emit a :attr:`sigElementClicked` signal and send a + :class:`PeriodicTableItem` object""" + self.sigElementClicked.emit(self.item) + + +class PeriodicTable(qt.QWidget): + """Periodic Table widget + + The following example shows how to connect clicking to selection:: + + from silx.gui import qt + from silx.gui.widgets.PeriodicTable import PeriodicTable + app = qt.QApplication([]) + pt = PeriodicTable() + pt.sigElementClicked.connect(pt.elementToggle) + pt.show() + app.exec_() + + To print all selected elements each time a new element is selected:: + + def my_slot(item): + pt.elementToggle(item) + selected_elements = pt.getSelection() + for e in selected_elements: + print(e.symbol) + + pt.sigElementClicked.connect(my_slot) + + """ + sigElementClicked = qt.pyqtSignal(object) + """When any element is clicked in the table, the widget emits + this signal and sends a :class:`PeriodicTableItem` object. + """ + + sigSelectionChanged = qt.pyqtSignal(object) + """When any element is selected/unselected in the table, the widget emits + this signal and sends a list of :class:`PeriodicTableItem` objects. + + .. note:: + + To enable selection of elements, you must set *selectable=True* + when you instantiate the widget. Alternatively, you can also connect + :attr:`sigElementClicked` to :meth:`elementToggle` manually:: + + pt = PeriodicTable() + pt.sigElementClicked.connect(pt.elementToggle) + + + :param parent: parent QWidget + :param str name: Widget window title + :param elements: List of items (:class:`PeriodicTableItem` objects) to + be represented in the table. By default, take elements from + a predefined list with minimal information (symbol, atomic number, + name, mass). + :param bool selectable: If *True*, multiple elements can be + selected by clicking with the mouse. If *False* (default), + selection is only possible with method :meth:`setSelection`. + """ + + def __init__(self, parent=None, name="PeriodicTable", elements=None, + selectable=False): + self.selectable = selectable + qt.QWidget.__init__(self, parent) + self.setWindowTitle(name) + self.gridLayout = qt.QGridLayout(self) + self.gridLayout.setContentsMargins(0, 0, 0, 0) + self.gridLayout.addItem(qt.QSpacerItem(0, 5), 7, 0) + + for idx in range(10): + self.gridLayout.setRowStretch(idx, 3) + # row 8 (above lanthanoids is empty) + self.gridLayout.setRowStretch(7, 2) + + # Element information displayed when cursor enters a cell + self.eltLabel = qt.QLabel(self) + f = self.eltLabel.font() + f.setBold(1) + self.eltLabel.setFont(f) + self.eltLabel.setAlignment(qt.Qt.AlignHCenter) + self.gridLayout.addWidget(self.eltLabel, 1, 1, 3, 10) + + self._eltCurrent = None + """Current :class:`_ElementButton` (last clicked)""" + + self._eltButtons = OrderedDict() + """Dictionary of all :class:`_ElementButton`. Keys are the symbols + ("H", "He", "Li"...)""" + + if elements is None: + elements = _defaultTableItems + # fill cells with elements + for elmt in elements: + self.__addElement(elmt) + + def __addElement(self, elmt): + """Add one :class:`_ElementButton` widget into the grid, + connect its signals to interact with the cursor""" + b = _ElementButton(elmt, self) + b.setAutoDefault(False) + + self._eltButtons[elmt.symbol] = b + self.gridLayout.addWidget(b, elmt.row, elmt.col) + + b.sigElementEnter.connect(self.elementEnter) + b.sigElementLeave.connect(self._elementLeave) + b.sigElementClicked.connect(self._elementClicked) + + def elementEnter(self, item): + """Update label with element info (e.g. "Nb(41) - niobium") + when mouse cursor hovers an element. + + :param PeriodicTableItem item: Element entered by cursor + """ + self.eltLabel.setText("%s(%d) - %s" % (item.symbol, item.Z, item.name)) + + def _elementLeave(self, item): + """Clear label when the cursor leaves the cell + + :param PeriodicTableItem item: Element left + """ + self.eltLabel.setText("") + + def _elementClicked(self, item): + """Emit :attr:`sigElementClicked`, + toggle selected state of element + + :param PeriodicTableItem item: Element clicked + """ + if self._eltCurrent is not None: + self._eltCurrent.setCurrent(False) + self._eltButtons[item.symbol].setCurrent(True) + self._eltCurrent = self._eltButtons[item.symbol] + if self.selectable: + self.elementToggle(item) + self.sigElementClicked.emit(item) + + def getSelection(self): + """Return a list of selected elements, as a list of :class:`PeriodicTableItem` + objects. + + :return: Selected items + :rtype: list(PeriodicTableItem) + """ + return [b.item for b in self._eltButtons.values() if b.isSelected()] + + def setSelection(self, symbols): + """Set selected elements. + + This causes the sigSelectionChanged signal + to be emitted, even if the selection didn't actually change. + + :param list(str) symbols: List of symbols of elements to be selected + (e.g. *["Fe", "Hg", "Li"]*) + """ + # accept list of PeriodicTableItems as input, because getSelection + # returns these objects and it makes sense to have getter and setter + # use same type of data + if isinstance(symbols[0], PeriodicTableItem): + symbols = [elmt.symbol for elmt in symbols] + + for (e, b) in self._eltButtons.items(): + b.setSelected(e in symbols) + self.sigSelectionChanged.emit(self.getSelection()) + + def setElementSelected(self, symbol, state): + """Modify *selected* status of a single element (select or unselect) + + :param str symbol: PeriodicTableItem symbol to be selected + :param bool state: *True* to select, *False* to unselect + """ + self._eltButtons[symbol].setSelected(state) + self.sigSelectionChanged.emit(self.getSelection()) + + def isElementSelected(self, symbol): + """Return *True* if element is selected, else *False* + + :param str symbol: PeriodicTableItem symbol + :return: *True* if element is selected, else *False* + """ + return self._eltButtons[symbol].isSelected() + + def elementToggle(self, item): + """Toggle selected/unselected state for element + + :param item: PeriodicTableItem object + """ + b = self._eltButtons[item.symbol] + b.setSelected(not b.isSelected()) + self.sigSelectionChanged.emit(self.getSelection()) + + +class PeriodicCombo(qt.QComboBox): + """ + Combo list with all atomic elements of the periodic table + + :param bool detailed: True (default) display element symbol, Z and name. + False display only element symbol and Z. + :param elements: List of items (:class:`PeriodicTableItem` objects) to + be represented in the table. By default, take elements from + a predefined list with minimal information (symbol, atomic number, + name, mass). + """ + sigSelectionChanged = qt.pyqtSignal(object) + """Signal emitted when the selection changes. Send + :class:`PeriodicTableItem` object representing selected + element + """ + + def __init__(self, parent=None, detailed=True, elements=None): + qt.QComboBox.__init__(self, parent) + + # add all elements from global list + if elements is None: + elements = _defaultTableItems + for i, elmt in enumerate(elements): + if detailed: + txt = "%2s (%d) - %s" % (elmt.symbol, elmt.Z, elmt.name) + else: + txt = "%2s (%d)" % (elmt.symbol, elmt.Z) + self.insertItem(i, txt) + + self.currentIndexChanged[int].connect(self.__selectionChanged) + + def __selectionChanged(self, idx): + """Emit :attr:`sigSelectionChanged`""" + self.sigSelectionChanged.emit(_defaultTableItems[idx]) + + def getSelection(self): + """Get selected element + + :return: Selected element + :rtype: PeriodicTableItem + """ + return _defaultTableItems[self.currentIndex()] + + def setSelection(self, symbol): + """Set selected item in combobox by giving the atomic symbol + + :param symbol: Symbol of element to be selected + """ + # accept PeriodicTableItem for getter/setter consistency + if isinstance(symbol, PeriodicTableItem): + symbol = symbol.symbol + symblist = [elmt.symbol for elmt in _defaultTableItems] + self.setCurrentIndex(symblist.index(symbol)) + + +class PeriodicList(qt.QTreeWidget): + """List of atomic elements in a :class:`QTreeView` + + :param QWidget parent: Parent widget + :param bool detailed: True (default) display element symbol, Z and name. + False display only element symbol and Z. + :param single: *True* for single element selection with mouse click, + *False* for multiple element selection mode. + """ + sigSelectionChanged = qt.pyqtSignal(object) + """When any element is selected/unselected in the widget, it emits + this signal and sends a list of currently selected + :class:`PeriodicTableItem` objects. + """ + + def __init__(self, parent=None, detailed=True, single=False, elements=None): + qt.QTreeWidget.__init__(self, parent) + + self.detailed = detailed + + headers = ["Z", "Symbol"] + if detailed: + headers.append("Name") + self.setColumnCount(3) + else: + self.setColumnCount(2) + self.setHeaderLabels(headers) + self.header().setStretchLastSection(False) + + self.setRootIsDecorated(0) + self.itemClicked.connect(self.__selectionChanged) + self.setSelectionMode(qt.QAbstractItemView.SingleSelection if single + else qt.QAbstractItemView.ExtendedSelection) + self.__fill_widget(elements) + self.resizeColumnToContents(0) + self.resizeColumnToContents(1) + if detailed: + self.resizeColumnToContents(2) + + def __fill_widget(self, elements): + """Fill tree widget with elements """ + if elements is None: + elements = _defaultTableItems + + self.tree_items = [] + + previous_item = None + for elmt in elements: + if previous_item is None: + item = qt.QTreeWidgetItem(self) + else: + item = qt.QTreeWidgetItem(self, previous_item) + item.setText(0, str(elmt.Z)) + item.setText(1, elmt.symbol) + if self.detailed: + item.setText(2, elmt.name) + self.tree_items.append(item) + previous_item = item + + def __selectionChanged(self, treeItem, column): + """Emit a :attr:`sigSelectionChanged` and send a list of + :class:`PeriodicTableItem` objects.""" + self.sigSelectionChanged.emit(self.getSelection()) + + def getSelection(self): + """Get a list of selected elements, as a list of :class:`PeriodicTableItem` + objects. + + :return: Selected elements + :rtype: list(PeriodicTableItem)""" + return [_defaultTableItems[idx] for idx in range(len(self.tree_items)) + if self.tree_items[idx].isSelected()] + + # setSelection is a bad name (name of a QTreeWidget method) + def setSelectedElements(self, symbolList): + """ + + :param symbolList: List of atomic symbols ["H", "He", "Li"...] + to be selected in the widget + """ + # accept PeriodicTableItem for getter/setter consistency + if isinstance(symbolList[0], PeriodicTableItem): + symbolList = [elmt.symbol for elmt in symbolList] + for idx in range(len(self.tree_items)): + self.tree_items[idx].setSelected(_defaultTableItems[idx].symbol in symbolList) diff --git a/silx/gui/widgets/TableWidget.py b/silx/gui/widgets/TableWidget.py new file mode 100644 index 0000000..fad80ee --- /dev/null +++ b/silx/gui/widgets/TableWidget.py @@ -0,0 +1,488 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides table widgets handling cut, copy and paste for +multiple cell selections. These actions can be triggered using keyboard +shortcuts or through a context menu (right-click). + +:class:`TableView` is a subclass of :class:`QTableView`. The added features +are made available to users after a model is added to the widget, using +:meth:`TableView.setModel`. + +:class:`TableWidget` is a subclass of :class:`qt.QTableWidget`, a table view +with a built-in standard data model. The added features are available as soon as +the widget is initialized. + +The cut, copy and paste actions are implemented as QActions: + + - :class:`CopySelectedCellsAction` (*Ctrl+C*) + - :class:`CopyAllCellsAction` + - :class:`CutSelectedCellsAction` (*Ctrl+X*) + - :class:`CutAllCellsAction` + - :class:`PasteCellsAction` (*Ctrl+V*) + +The copy actions are enabled by default. The cut and paste actions must be +explicitly enabled, by passing parameters ``cut=True, paste=True`` when +creating the widgets, or later by calling their :meth:`enableCut` and +:meth:`enablePaste` methods. +""" + +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "26/01/2017" + + +import sys +from .. import qt + + +if sys.platform.startswith("win"): + row_separator = "\r\n" +else: + row_separator = "\n" + +col_separator = "\t" + + +class CopySelectedCellsAction(qt.QAction): + """QAction to copy text from selected cells in a :class:`QTableWidget` + into the clipboard. + + If multiple cells are selected, the copied text will be a concatenation + of the texts in all selected cells, tabulated with tabulation and + newline characters. + + If the cells are sparsely selected, the structure is preserved by + representing the unselected cells as empty strings in between two + tabulation characters. + Beware of pasting this data in another table widget, because depending + on how the paste is implemented, the empty cells may cause data in the + target table to be deleted, even though you didn't necessarily select the + corresponding cell in the origin table. + + :param table: :class:`QTableView` to which this action belongs. + """ + def __init__(self, table): + if not isinstance(table, qt.QTableView): + raise ValueError('CopySelectedCellsAction must be initialised ' + + 'with a QTableWidget.') + super(CopySelectedCellsAction, self).__init__(table) + self.setText("Copy selection") + self.setToolTip("Copy selected cells into the clipboard.") + self.setShortcut(qt.QKeySequence.Copy) + self.setShortcutContext(qt.Qt.WidgetShortcut) + self.triggered.connect(self.copyCellsToClipboard) + self.table = table + self.cut = False + """:attr:`cut` can be set to True by classes inheriting this action, + to do a cut action.""" + + def copyCellsToClipboard(self): + """Concatenate the text content of all selected cells into a string + using tabulations and newlines to keep the table structure. + Put this text into the clipboard. + """ + selected_idx = self.table.selectedIndexes() + selected_idx_tuples = [(idx.row(), idx.column()) for idx in selected_idx] + + selected_rows = [idx[0] for idx in selected_idx_tuples] + selected_columns = [idx[1] for idx in selected_idx_tuples] + + data_model = self.table.model() + + copied_text = "" + for row in range(min(selected_rows), max(selected_rows) + 1): + for col in range(min(selected_columns), max(selected_columns) + 1): + index = data_model.index(row, col) + cell_text = data_model.data(index) + flags = data_model.flags(index) + + if (row, col) in selected_idx_tuples and cell_text is not None: + copied_text += cell_text + if self.cut and (flags & qt.Qt.ItemIsEditable): + data_model.setData(index, "") + copied_text += col_separator + # remove the right-most tabulation + copied_text = copied_text[:-len(col_separator)] + # add a newline + copied_text += row_separator + # remove final newline + copied_text = copied_text[:-len(row_separator)] + + # put this text into clipboard + qapp = qt.QApplication.instance() + qapp.clipboard().setText(copied_text) + + +class CopyAllCellsAction(qt.QAction): + """QAction to copy text from all cells in a :class:`QTableWidget` + into the clipboard. + + The copied text will be a concatenation + of the texts in all cells, tabulated with tabulation and + newline characters. + + :param table: :class:`QTableView` to which this action belongs. + """ + def __init__(self, table): + if not isinstance(table, qt.QTableView): + raise ValueError('CopyAllCellsAction must be initialised ' + + 'with a QTableWidget.') + super(CopyAllCellsAction, self).__init__(table) + self.setText("Copy all") + self.setToolTip("Copy all cells into the clipboard.") + self.triggered.connect(self.copyCellsToClipboard) + self.table = table + self.cut = False + + def copyCellsToClipboard(self): + """Concatenate the text content of all cells into a string + using tabulations and newlines to keep the table structure. + Put this text into the clipboard. + """ + data_model = self.table.model() + copied_text = "" + for row in range(data_model.rowCount()): + for col in range(data_model.columnCount()): + index = data_model.index(row, col) + cell_text = data_model.data(index) + flags = data_model.flags(index) + if cell_text is not None: + copied_text += cell_text + if self.cut and (flags & qt.Qt.ItemIsEditable): + data_model.setData(index, "") + copied_text += col_separator + # remove the right-most tabulation + copied_text = copied_text[:-len(col_separator)] + # add a newline + copied_text += row_separator + # remove final newline + copied_text = copied_text[:-len(row_separator)] + + # put this text into clipboard + qapp = qt.QApplication.instance() + qapp.clipboard().setText(copied_text) + + +class CutSelectedCellsAction(CopySelectedCellsAction): + """QAction to cut text from selected cells in a :class:`QTableWidget` + into the clipboard. + + The text is deleted from the original table widget + (use :class:`CopySelectedCellsAction` to preserve the original data). + + If multiple cells are selected, the cut text will be a concatenation + of the texts in all selected cells, tabulated with tabulation and + newline characters. + + If the cells are sparsely selected, the structure is preserved by + representing the unselected cells as empty strings in between two + tabulation characters. + Beware of pasting this data in another table widget, because depending + on how the paste is implemented, the empty cells may cause data in the + target table to be deleted, even though you didn't necessarily select the + corresponding cell in the origin table. + + :param table: :class:`QTableView` to which this action belongs.""" + def __init__(self, table): + super(CutSelectedCellsAction, self).__init__(table) + self.setText("Cut selection") + self.setShortcut(qt.QKeySequence.Cut) + self.setShortcutContext(qt.Qt.WidgetShortcut) + # cutting is already implemented in CopySelectedCellsAction (but + # it is disabled), we just need to enable it + self.cut = True + + +class CutAllCellsAction(CopyAllCellsAction): + """QAction to cut text from all cells in a :class:`QTableWidget` + into the clipboard. + + The text is deleted from the original table widget + (use :class:`CopyAllCellsAction` to preserve the original data). + + The cut text will be a concatenation + of the texts in all cells, tabulated with tabulation and + newline characters. + + :param table: :class:`QTableView` to which this action belongs.""" + def __init__(self, table): + super(CutAllCellsAction, self).__init__(table) + self.setText("Cut all") + self.setToolTip("Cut all cells into the clipboard.") + self.cut = True + + +def _parseTextAsTable(text, row_separator=row_separator, col_separator=col_separator): + """Parse text into list of lists (2D sequence). + + The input text must be tabulated using tabulation characters and + newlines to separate columns and rows. + + :param text: text to be parsed + :param record_separator: String, or single character, to be interpreted + as a record/row separator. + :param field_separator: String, or single character, to be interpreted + as a field/column separator. + :return: 2DÂ sequence of strings + """ + rows = text.split(row_separator) + table_data = [row.split(col_separator) for row in rows] + return table_data + + +class PasteCellsAction(qt.QAction): + """QAction to paste text from the clipboard into the table. + + If the text contains tabulations and + newlines, they are interpreted as column and row separators. + In such a case, the text is split into multiple texts to be pasted + into multiple cells. + + If a cell content is an empty string in the original text, it is + ignored: the destination cell's text will not be deleted. + + :param table: :class:`QTableView` to which this action belongs. + """ + def __init__(self, table): + if not isinstance(table, qt.QTableView): + raise ValueError('PasteCellsAction must be initialised ' + + 'with a QTableWidget.') + super(PasteCellsAction, self).__init__(table) + self.table = table + self.setText("Paste") + self.setShortcut(qt.QKeySequence.Paste) + self.setShortcutContext(qt.Qt.WidgetShortcut) + self.setToolTip("Paste data. The selected cell is the top-left" + + "corner of the paste area.") + self.triggered.connect(self.pasteCellFromClipboard) + + def pasteCellFromClipboard(self): + """Paste text from clipboard into the table. + + :return: *True* in case of success, *False* if pasting data failed. + """ + selected_idx = self.table.selectedIndexes() + if len(selected_idx) != 1: + msgBox = qt.QMessageBox(parent=self.table) + msgBox.setText("A single cell must be selected to paste data") + msgBox.exec_() + return False + + data_model = self.table.model() + + selected_row = selected_idx[0].row() + selected_col = selected_idx[0].column() + + qapp = qt.QApplication.instance() + clipboard_text = qapp.clipboard().text() + table_data = _parseTextAsTable(clipboard_text) + + protected_cells = 0 + out_of_range_cells = 0 + + # paste table data into cells, using selected cell as origin + for row_offset in range(len(table_data)): + for col_offset in range(len(table_data[row_offset])): + target_row = selected_row + row_offset + target_col = selected_col + col_offset + + if target_row >= data_model.rowCount() or\ + target_col >= data_model.columnCount(): + out_of_range_cells += 1 + continue + + index = data_model.index(target_row, target_col) + flags = data_model.flags(index) + + # ignore empty strings + if table_data[row_offset][col_offset] != "": + if not flags & qt.Qt.ItemIsEditable: + protected_cells += 1 + continue + data_model.setData(index, table_data[row_offset][col_offset]) + # item.setText(table_data[row_offset][col_offset]) + + if protected_cells or out_of_range_cells: + msgBox = qt.QMessageBox(parent=self.table) + msg = "Some data could not be inserted, " + msg += "due to out-of-range or write-protected cells." + msgBox.setText(msg) + msgBox.exec_() + return False + return True + + +class TableWidget(qt.QTableWidget): + """:class:`QTableWidget` with a context menu displaying up to 5 actions: + + - :class:`CopySelectedCellsAction` + - :class:`CopyAllCellsAction` + - :class:`CutSelectedCellsAction` + - :class:`CutAllCellsAction` + - :class:`PasteCellsAction` + + These actions interact with the clipboard and can be used to copy data + to or from an external application, or another widget. + + The cut and paste actions are disabled by default, due to the risk of + overwriting data (no *Undo* action is available). Use :meth:`enablePaste` + and :meth:`enableCut` to activate them. + + :param parent: Parent QWidget + :param bool cut: Enable cut action + :param bool paste: Enable paste action + """ + def __init__(self, parent=None, cut=False, paste=False): + super(TableWidget, self).__init__(parent) + self.addAction(CopySelectedCellsAction(self)) + self.addAction(CopyAllCellsAction(self)) + if cut: + self.enableCut() + if paste: + self.enablePaste() + + self.setContextMenuPolicy(qt.Qt.ActionsContextMenu) + + def enablePaste(self): + """Enable paste action, to paste data from the clipboard into the + table. + + .. warning:: + + This action can cause data to be overwritten. + There is currently no *Undo* action to retrieve lost data. + """ + self.addAction(PasteCellsAction(self)) + + def enableCut(self): + """Enable cut action. + + .. warning:: + + This action can cause data to be deleted. + There is currently no *Undo* action to retrieve lost data.""" + self.addAction(CutSelectedCellsAction(self)) + self.addAction(CutAllCellsAction(self)) + + +class TableView(qt.QTableView): + """:class:`QTableView` with a context menu displaying up to 5 actions: + + - :class:`CopySelectedCellsAction` + - :class:`CopyAllCellsAction` + - :class:`CutSelectedCellsAction` + - :class:`CutAllCellsAction` + - :class:`PasteCellsAction` + + These actions interact with the clipboard and can be used to copy data + to or from an external application, or another widget. + + The cut and paste actions are disabled by default, due to the risk of + overwriting data (no *Undo* action is available). Use :meth:`enablePaste` + and :meth:`enableCut` to activate them. + + .. note:: + + These actions will be available only after a model is associated + with this view, using :meth:`setModel`. + + :param parent: Parent QWidget + :param bool cut: Enable cut action + :param bool paste: Enable paste action + """ + def __init__(self, parent=None, cut=False, paste=False): + super(TableView, self).__init__(parent) + self.cut = cut + self.paste = paste + + def setModel(self, model): + """Set the data model for the table view, activate the actions + and the context menu. + + :param model: :class:`qt.QAbstractItemModel` object + """ + super(TableView, self).setModel(model) + + self.addAction(CopySelectedCellsAction(self)) + self.addAction(CopyAllCellsAction(self)) + if self.cut: + self.enableCut() + if self.paste: + self.enablePaste() + + self.setContextMenuPolicy(qt.Qt.ActionsContextMenu) + + def enablePaste(self): + """Enable paste action, to paste data from the clipboard into the + table. + + .. warning:: + + This action can cause data to be overwritten. + There is currently no *Undo* action to retrieve lost data. + """ + self.addAction(PasteCellsAction(self)) + + def enableCut(self): + """Enable cut action. + + .. warning:: + + This action can cause data to be deleted. + There is currently no *Undo* action to retrieve lost data. + """ + self.addAction(CutSelectedCellsAction(self)) + self.addAction(CutAllCellsAction(self)) + + def addAction(self, action): + # ensure the actions are not added multiple times: + # compare action type and parent widget with those of existing actions + for existing_action in self.actions(): + if type(action) == type(existing_action): + if hasattr(action, "table") and\ + action.table is existing_action.table: + return None + super(TableView, self).addAction(action) + +if __name__ == "__main__": + app = qt.QApplication([]) + + tablewidget = TableWidget() + tablewidget.setWindowTitle("TableWidget") + tablewidget.setColumnCount(10) + tablewidget.setRowCount(7) + tablewidget.enableCut() + tablewidget.enablePaste() + tablewidget.show() + + tableview = TableView(cut=True, paste=True) + tableview.setWindowTitle("TableView") + model = qt.QStandardItemModel() + model.setColumnCount(10) + model.setRowCount(7) + tableview.setModel(model) + tableview.show() + + app.exec_() diff --git a/silx/gui/widgets/ThreadPoolPushButton.py b/silx/gui/widgets/ThreadPoolPushButton.py new file mode 100644 index 0000000..29e831d --- /dev/null +++ b/silx/gui/widgets/ThreadPoolPushButton.py @@ -0,0 +1,233 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""ThreadPoolPushButton module +""" + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "13/10/2016" + +import logging +from .. import qt +from .WaitingPushButton import WaitingPushButton + + +_logger = logging.getLogger(__name__) + + +class _Wrapper(qt.QRunnable): + """Wrapper to allow to call a function into a `QThreadPool` and + sending signals during the life cycle of the object""" + + def __init__(self, signalHolder, function, args, kwargs): + """Constructor""" + super(_Wrapper, self).__init__() + self.__signalHolder = signalHolder + self.__callable = function + self.__args = args + self.__kwargs = kwargs + + def run(self): + holder = self.__signalHolder + holder.started.emit() + try: + result = self.__callable(*self.__args, **self.__kwargs) + holder.succeeded.emit(result) + except Exception as e: + module = self.__callable.__module__ + name = self.__callable.__name__ + _logger.error("Error while executing callable %s.%s.", module, name, exc_info=True) + holder.failed.emit(e) + finally: + holder.finished.emit() + holder._sigReleaseRunner.emit(self) + + def autoDelete(self): + """Returns true to ask the QThreadPool to manage the life cycle of + this QRunner.""" + return True + + +class ThreadPoolPushButton(WaitingPushButton): + """ + ThreadPoolPushButton provides a simple push button to execute + a threaded task with user feedback when the task is running. + + The task can be defined with the method `setCallable`. It takes a python + function and arguments as parameters. + + WARNING: This task is run in a separate thread. + + Everytime the button is pushed a new runner is created to execute the + function with defined arguments. An animated waiting icon is displayed + to show the activity. By default the button is disabled when an execution + is requested. This behaviour can be disabled by using + `setDisabledWhenWaiting`. + + When the button is clicked a `beforeExecuting` signal is sent from the + Qt main thread. Then the task is started in a thread pool and the following + signals are emitted from the thread pool. Right before calling the + registered callable, the widget emits a `started` signal. + When the task ends, its result is emitted by the `succeeded` signal, but + if it fails the signal `failed` is emitted with the resulting exception. + At the end, the `finished` signal is emitted. + + The task can be programatically executed by using `executeCallable`. + + >>> # Compute a value + >>> import math + >>> button = ThreadPoolPushButton(text="Compute 2^16") + >>> button.setCallable(math.pow, 2, 16) + >>> button.succeeded.connect(print) # python3 + + >>> # Compute a wrong value + >>> import math + >>> button = ThreadPoolPushButton(text="Compute sqrt(-1)") + >>> button.setCallable(math.sqrt, -1) + >>> button.failed.connect(print) # python3 + """ + + def __init__(self, parent=None, text=None, icon=None): + """Constructor + + :param str text: Text displayed on the button + :param qt.QIcon icon: Icon displayed on the button + :param qt.QWidget parent: Parent of the widget + """ + WaitingPushButton.__init__(self, parent=parent, text=text, icon=icon) + self.__callable = None + self.__args = None + self.__kwargs = None + self.__runnerCount = 0 + self.__runnerSet = set([]) + self.clicked.connect(self.executeCallable) + self.finished.connect(self.__runnerFinished) + self._sigReleaseRunner.connect(self.__releaseRunner) + + beforeExecuting = qt.Signal() + """Signal emitted just before execution of the callable by the main Qt + thread. In synchronous mode (direct mode), it can be used to define + dynamically `setCallable`, or to execute something in the Qt thread before + the execution, or both.""" + + started = qt.Signal() + """Signal emitted from the thread pool when the defined callable is + started. + + WARNING: This signal is emitted from the thread performing the task, and + might be received after the registered callable has been called. If you + want to perform some initialisation or set the callable to run, use the + `beforeExecuting` signal instead. + """ + + finished = qt.Signal() + """Signal emitted from the thread pool when the defined callable is + finished""" + + succeeded = qt.Signal(object) + """Signal emitted from the thread pool when the callable exit with a + success. + + The parameter of the signal is the result returned by the callable. + """ + + failed = qt.Signal(object) + """Signal emitted emitted from the thread pool when the callable raises an + exception. + + The parameter of the signal is the raised exception. + """ + + _sigReleaseRunner = qt.Signal(object) + """Callback to release runners""" + + def __runnerStarted(self): + """Called when a runner is started. + + Count the number of executed tasks to change the state of the widget. + """ + self.__runnerCount += 1 + if self.__runnerCount > 0: + self.wait() + + def __runnerFinished(self): + """Called when a runner is finished. + + Count the number of executed tasks to change the state of the widget. + """ + self.__runnerCount -= 1 + if self.__runnerCount <= 0: + self.stopWaiting() + + @qt.Slot() + def executeCallable(self): + """Execute the defined callable in QThreadPool. + + First emit a `beforeExecuting` signal. + If callable is not defined, nothing append. + If a callable is defined, it will be started + as a new thread using the `QThreadPool` system. At start of the thread + the `started` will be emitted. When the callable returns a result it + is emitted by the `succeeded` signal. If the callable fail, the signal + `failed` is emitted with the resulting exception. Then the `finished` + signal is emitted. + """ + self.beforeExecuting.emit() + if self.__callable is None: + return + self.__runnerStarted() + runner = self._createRunner(self.__callable, self.__args, self.__kwargs) + qt.QThreadPool.globalInstance().start(runner) + self.__runnerSet.add(runner) + + def __releaseRunner(self, runner): + self.__runnerSet.remove(runner) + + def _createRunner(self, function, args, kwargs): + """Create a QRunnable from a callable object. + + :param callable function: A callable Python object. + :param list args: List of arguments to call the function. + :param dict kwargs: Dictionary of arguments used to call the function. + :rtpye: qt.QRunnable + """ + runnable = _Wrapper(self, function, args, kwargs) + return runnable + + def setCallable(self, function, *args, **kwargs): + """Define a callable which will be executed on QThreadPool everytime + the button is clicked. + + To retrieve the results, connect to the `succeeded` signal. + + WARNING: The callable will be called in a separate thread. + + :param callable function: A callable Python object + :param list args: List of arguments to call the function. + :param dict kwargs: Dictionary of arguments used to call the function. + """ + self.__callable = function + self.__args = args + self.__kwargs = kwargs diff --git a/silx/gui/widgets/WaitingPushButton.py b/silx/gui/widgets/WaitingPushButton.py new file mode 100644 index 0000000..49ab9b9 --- /dev/null +++ b/silx/gui/widgets/WaitingPushButton.py @@ -0,0 +1,243 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""WaitingPushButton module +""" + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "26/04/2017" + +from .. import qt +from .. import icons + + +class WaitingPushButton(qt.QPushButton): + """Button which allows to display a waiting status when, for example, + something is still computing. + + The component is graphically disabled when it is in waiting. Then we + overwrite the enabled method to dissociate the 2 concepts: + graphically enabled/disabled, and enabled/disabled + """ + + def __init__(self, parent=None, text=None, icon=None): + """Constructor + + :param str text: Text displayed on the button + :param qt.QIcon icon: Icon displayed on the button + :param qt.QWidget parent: Parent of the widget + """ + if icon is not None: + qt.QPushButton.__init__(self, icon, text, parent) + elif text is not None: + qt.QPushButton.__init__(self, text, parent) + else: + qt.QPushButton.__init__(self, parent) + + self.__waiting = False + self.__enabled = True + self.__icon = icon + self.__disabled_when_waiting = True + self.__waitingIcon = icons.getWaitIcon() + + def sizeHint(self): + """Returns the recommended size for the widget. + + This implementation of the recommended size always consider there is an + icon. In this way it avoid to update the layout when the waiting icon + is displayed. + """ + self.ensurePolished() + + w = 0 + h = 0 + + opt = qt.QStyleOptionButton() + self.initStyleOption(opt) + + # Content with icon + # no condition, assume that there is an icon to avoid blinking + # when the widget switch to waiting state + ih = opt.iconSize.height() + iw = opt.iconSize.width() + 4 + w += iw + h = max(h, ih) + + # Content with text + text = self.text() + isEmpty = text == "" + if isEmpty: + text = "XXXX" + fm = self.fontMetrics() + textSize = fm.size(qt.Qt.TextShowMnemonic, text) + if not isEmpty or w == 0: + w += textSize.width() + if not isEmpty or h == 0: + h = max(h, textSize.height()) + + # Content with menu indicator + opt.rect.setSize(qt.QSize(w, h)) # PM_MenuButtonIndicator depends on the height + if self.menu() is not None: + w += self.style().pixelMetric(qt.QStyle.PM_MenuButtonIndicator, opt, self) + + contentSize = qt.QSize(w, h) + if qt.qVersion().startswith("4.8."): + # On PyQt4/PySide the method QCommonStyle sizeFromContents returns + # different size when the widget provides an icon or not. + # In Qt5 there is not this problem. + opt.icon = qt.QIcon() + sizeHint = self.style().sizeFromContents(qt.QStyle.CT_PushButton, opt, contentSize, self) + sizeHint = sizeHint.expandedTo(qt.QApplication.globalStrut()) + return sizeHint + + def setDisabledWhenWaiting(self, isDisabled): + """Enable or disable the auto disable behaviour when the button is waiting. + + :param bool isDisabled: Enable the auto-disable behaviour + """ + if self.__disabled_when_waiting == isDisabled: + return + self.__disabled_when_waiting = isDisabled + self.__updateVisibleEnabled() + + def isDisabledWhenWaiting(self): + """Returns true if the button is auto disabled when it is waiting. + + :rtype: bool + """ + return self.__disabled_when_waiting + + disabledWhenWaiting = qt.Property(bool, isDisabledWhenWaiting, setDisabledWhenWaiting) + """Property to enable/disable the auto disabled state when the button is waiting.""" + + def __setWaitingIcon(self, icon): + """Called when the waiting icon is updated. It is called every frames + of the animation. + + :param qt.QIcon icon: The new waiting icon + """ + qt.QPushButton.setIcon(self, icon) + + def setIcon(self, icon): + """Set the button icon. If the button is waiting, the icon is not + visible directly, but will be visible when the waiting state will be + removed. + + :param qt.QIcon icon: An icon + """ + self.__icon = icon + self.__updateVisibleIcon() + + def getIcon(self): + """Returns the icon set to the button. If the widget is waiting + it is not returning the visible icon, but the one requested by + the application (the one displayed when the widget is not in + waiting state). + + :rtype: qt.QIcon + """ + return self.__icon + + icon = qt.Property(qt.QIcon, getIcon, setIcon) + """Property providing access to the icon.""" + + def __updateVisibleIcon(self): + """Update the visible icon according to the state of the widget.""" + if not self.isWaiting(): + icon = self.__icon + else: + icon = self.__waitingIcon.currentIcon() + if icon is None: + icon = qt.QIcon() + qt.QPushButton.setIcon(self, icon) + + def setEnabled(self, enabled): + """Set the enabled state of the widget. + + :param bool enabled: The enabled state + """ + if self.__enabled == enabled: + return + self.__enabled = enabled + self.__updateVisibleEnabled() + + def isEnabled(self): + """Returns the enabled state of the widget. + + :rtype: bool + """ + return self.__enabled + + enabled = qt.Property(bool, isEnabled, setEnabled) + """Property providing access to the enabled state of the widget""" + + def __updateVisibleEnabled(self): + """Update the visible enabled state according to the state of the + widget.""" + if self.__disabled_when_waiting: + enabled = not self.isWaiting() and self.__enabled + else: + enabled = self.__enabled + qt.QPushButton.setEnabled(self, enabled) + + def setWaiting(self, waiting): + """Set the waiting state of the widget. + + :param bool waiting: Requested state""" + if self.__waiting == waiting: + return + self.__waiting = waiting + + if self.__waiting: + self.__waitingIcon.register(self) + self.__waitingIcon.iconChanged.connect(self.__setWaitingIcon) + else: + # unregister only if the object is registred + self.__waitingIcon.unregister(self) + self.__waitingIcon.iconChanged.disconnect(self.__setWaitingIcon) + + self.__updateVisibleEnabled() + self.__updateVisibleIcon() + + def isWaiting(self): + """Returns true if the widget is in waiting state. + + :rtype: bool""" + return self.__waiting + + @qt.Slot() + def wait(self): + """Enable the waiting state.""" + self.setWaiting(True) + + @qt.Slot() + def stopWaiting(self): + """Disable the waiting state.""" + self.setWaiting(False) + + @qt.Slot() + def swapWaiting(self): + """Swap the waiting state.""" + self.setWaiting(not self.isWaiting()) diff --git a/silx/gui/widgets/__init__.py b/silx/gui/widgets/__init__.py new file mode 100644 index 0000000..034f4d3 --- /dev/null +++ b/silx/gui/widgets/__init__.py @@ -0,0 +1,27 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This package provides a few simple Qt widgets that rely only on a Qt wrapper +for Python (PyQt5, PyQt4 or PySide). No other optional dependencies of *silx* +should be required.""" diff --git a/silx/gui/widgets/setup.py b/silx/gui/widgets/setup.py new file mode 100644 index 0000000..e96ac8d --- /dev/null +++ b/silx/gui/widgets/setup.py @@ -0,0 +1,41 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "11/10/2016" + + +from numpy.distutils.misc_util import Configuration + + +def configuration(parent_package='', top_path=None): + config = Configuration('widgets', parent_package, top_path) + config.add_subpackage('test') + return config + + +if __name__ == "__main__": + from numpy.distutils.core import setup + setup(configuration=configuration) diff --git a/silx/gui/widgets/test/__init__.py b/silx/gui/widgets/test/__init__.py new file mode 100644 index 0000000..afa0f78 --- /dev/null +++ b/silx/gui/widgets/test/__init__.py @@ -0,0 +1,45 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +import unittest + +from . import test_periodictable +from . import test_tablewidget +from . import test_threadpoolpushbutton +from . import test_hierarchicaltableview + +__authors__ = ["V. Valls", "P. Knobel"] +__license__ = "MIT" +__date__ = "07/04/2017" + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTests( + [test_threadpoolpushbutton.suite(), + test_tablewidget.suite(), + test_periodictable.suite(), + test_hierarchicaltableview.suite(), + ]) + return test_suite diff --git a/silx/gui/widgets/test/test_hierarchicaltableview.py b/silx/gui/widgets/test/test_hierarchicaltableview.py new file mode 100644 index 0000000..b3d37ed --- /dev/null +++ b/silx/gui/widgets/test/test_hierarchicaltableview.py @@ -0,0 +1,117 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "07/04/2017" + +import unittest + +from .. import HierarchicalTableView +from ...test.utils import TestCaseQt +from silx.gui import qt + + +class TableModel(HierarchicalTableView.HierarchicalTableModel): + + def __init__(self, parent): + HierarchicalTableView.HierarchicalTableModel.__init__(self, parent) + self.__content = {} + + def rowCount(self, parent=qt.QModelIndex()): + return 3 + + def columnCount(self, parent=qt.QModelIndex()): + return 3 + + def setData1(self): + if qt.qVersion() > "4.6": + self.beginResetModel() + else: + self.reset() + + content = {} + content[0, 0] = ("title", True, (1, 3)) + content[0, 1] = ("a", True, (2, 1)) + content[1, 1] = ("b", False, (1, 2)) + content[1, 2] = ("c", False, (1, 1)) + content[2, 2] = ("d", False, (1, 1)) + self.__content = content + if qt.qVersion() > "4.6": + self.endResetModel() + + def data(self, index, role=qt.Qt.DisplayRole): + if not index.isValid(): + return None + cell = self.__content.get((index.column(), index.row()), None) + if cell is None: + return None + + if role == self.SpanRole: + return cell[2] + elif role == self.IsHeaderRole: + return cell[1] + elif role == qt.Qt.DisplayRole: + return cell[0] + return None + + +class TestHierarchicalTableView(TestCaseQt): + """Test for HierarchicalTableView""" + + def testEmpty(self): + widget = HierarchicalTableView.HierarchicalTableView() + widget.show() + self.qWaitForWindowExposed(widget) + + def testModel(self): + widget = HierarchicalTableView.HierarchicalTableView() + model = TableModel(widget) + # set the data before using the model into the widget + model.setData1() + widget.setModel(model) + span = widget.rowSpan(0, 0), widget.columnSpan(0, 0) + self.assertEqual(span, (1, 3)) + widget.show() + self.qWaitForWindowExposed(widget) + + def testModelUpdate(self): + widget = HierarchicalTableView.HierarchicalTableView() + model = TableModel(widget) + widget.setModel(model) + # set the data after using the model into the widget + model.setData1() + span = widget.rowSpan(0, 0), widget.columnSpan(0, 0) + self.assertEqual(span, (1, 3)) + + +def suite(): + loader = unittest.defaultTestLoader.loadTestsFromTestCase + test_suite = unittest.TestSuite() + test_suite.addTest(loader(TestHierarchicalTableView)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/widgets/test/test_periodictable.py b/silx/gui/widgets/test/test_periodictable.py new file mode 100644 index 0000000..c6bed81 --- /dev/null +++ b/silx/gui/widgets/test/test_periodictable.py @@ -0,0 +1,163 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "05/12/2016" + +import unittest + +from .. import PeriodicTable +from ...test.utils import TestCaseQt +from silx.gui import qt + + +class TestPeriodicTable(TestCaseQt): + """Basic test for ArrayTableWidget with a numpy array""" + + def testShow(self): + """basic test (instantiation done in setUp)""" + pt = PeriodicTable.PeriodicTable() + pt.show() + self.qWaitForWindowExposed(pt) + + def testSelectable(self): + """basic test (instantiation done in setUp)""" + pt = PeriodicTable.PeriodicTable(selectable=True) + self.assertTrue(pt.selectable) + + def testCustomElements(self): + PTI = PeriodicTable.ColoredPeriodicTableItem + my_items = [ + PTI("Xx", 42, 43, 44, "xaxatorium", 1002.2, + bgcolor="#FF0000"), + PTI("Yy", 25, 22, 44, "yoyotrium", 8.8) + ] + + pt = PeriodicTable.PeriodicTable(elements=my_items) + + pt.setSelection(["He", "Xx"]) + selection = pt.getSelection() + self.assertEqual(len(selection), 1) # "He" not found + self.assertEqual(selection[0].symbol, "Xx") + self.assertEqual(selection[0].Z, 42) + self.assertEqual(selection[0].col, 43) + self.assertAlmostEqual(selection[0].mass, 1002.2) + self.assertEqual(qt.QColor(selection[0].bgcolor), + qt.QColor(qt.Qt.red)) + + self.assertTrue(pt.isElementSelected("Xx")) + self.assertFalse(pt.isElementSelected("Yy")) + self.assertRaises(KeyError, pt.isElementSelected, "Yx") + + def testVeryCustomElements(self): + class MyPTI(PeriodicTable.PeriodicTableItem): + def __init__(self, *args): + PeriodicTable.PeriodicTableItem.__init__(self, *args[:6]) + self.my_feature = args[6] + + my_items = [ + MyPTI("Xx", 42, 43, 44, "xaxatorium", 1002.2, "spam"), + MyPTI("Yy", 25, 22, 44, "yoyotrium", 8.8, "eggs") + ] + + pt = PeriodicTable.PeriodicTable(elements=my_items) + + pt.setSelection(["Xx", "Yy"]) + selection = pt.getSelection() + self.assertEqual(len(selection), 2) + self.assertEqual(selection[1].symbol, "Yy") + self.assertEqual(selection[1].Z, 25) + self.assertEqual(selection[1].col, 22) + self.assertEqual(selection[1].row, 44) + self.assertAlmostEqual(selection[0].mass, 1002.2) + self.assertAlmostEqual(selection[0].my_feature, "spam") + + +class TestPeriodicCombo(TestCaseQt): + """Basic test for ArrayTableWidget with a numpy array""" + def setUp(self): + super(TestPeriodicCombo, self).setUp() + self.pc = PeriodicTable.PeriodicCombo() + + def tearDown(self): + del self.pc + super(TestPeriodicCombo, self).tearDown() + + def testShow(self): + """basic test (instantiation done in setUp)""" + self.pc.show() + self.qWaitForWindowExposed(self.pc) + + def testSelect(self): + self.pc.setSelection("Sb") + selection = self.pc.getSelection() + self.assertIsInstance(selection, + PeriodicTable.PeriodicTableItem) + self.assertEqual(selection.symbol, "Sb") + self.assertEqual(selection.Z, 51) + self.assertEqual(selection.name, "antimony") + + +class TestPeriodicList(TestCaseQt): + """Basic test for ArrayTableWidget with a numpy array""" + def setUp(self): + super(TestPeriodicList, self).setUp() + self.pl = PeriodicTable.PeriodicList() + + def tearDown(self): + del self.pl + super(TestPeriodicList, self).tearDown() + + def testShow(self): + """basic test (instantiation done in setUp)""" + self.pl.show() + self.qWaitForWindowExposed(self.pl) + + def testSelect(self): + self.pl.setSelectedElements(["Li", "He", "Au"]) + sel_elmts = self.pl.getSelection() + + self.assertEqual(len(sel_elmts), 3, + "Wrong number of elements selected") + for e in sel_elmts: + self.assertIsInstance(e, PeriodicTable.PeriodicTableItem) + self.assertIn(e.symbol, ["Li", "He", "Au"]) + self.assertIn(e.Z, [2, 3, 79]) + self.assertIn(e.name, ["lithium", "helium", "gold"]) + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestPeriodicTable)) + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestPeriodicList)) + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestPeriodicCombo)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/widgets/test/test_tablewidget.py b/silx/gui/widgets/test/test_tablewidget.py new file mode 100644 index 0000000..5ad0a06 --- /dev/null +++ b/silx/gui/widgets/test/test_tablewidget.py @@ -0,0 +1,61 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Test TableWidget""" + +__authors__ = ["P. Knobel"] +__license__ = "MIT" +__date__ = "05/12/2016" + + +import unittest +from silx.gui.test.utils import TestCaseQt +from silx.gui.widgets.TableWidget import TableWidget + + +class TestTableWidget(TestCaseQt): + def setUp(self): + super(TestTableWidget, self).setUp() + self._result = [] + + def testShow(self): + table = TableWidget() + table.setColumnCount(10) + table.setRowCount(7) + table.enableCut() + table.enablePaste() + table.show() + table.hide() + self.qapp.processEvents() + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestTableWidget)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/widgets/test/test_threadpoolpushbutton.py b/silx/gui/widgets/test/test_threadpoolpushbutton.py new file mode 100644 index 0000000..126f8f3 --- /dev/null +++ b/silx/gui/widgets/test/test_threadpoolpushbutton.py @@ -0,0 +1,129 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Test for silx.gui.hdf5 module""" + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "15/12/2016" + + +import unittest +import time +from silx.gui import qt +from silx.gui.test.utils import TestCaseQt +from silx.gui.test.utils import SignalListener +from silx.gui.widgets.ThreadPoolPushButton import ThreadPoolPushButton +from silx.test.utils import TestLogging + + +class TestThreadPoolPushButton(TestCaseQt): + + def setUp(self): + super(TestThreadPoolPushButton, self).setUp() + self._result = [] + + def _trace(self, name, delay=0): + self._result.append(name) + if delay != 0: + time.sleep(delay / 1000.0) + + def _compute(self): + return "result" + + def _computeFail(self): + raise Exception("exception") + + def testExecute(self): + button = ThreadPoolPushButton() + button.setCallable(self._trace, "a", 0) + button.executeCallable() + time.sleep(0.1) + self.assertListEqual(self._result, ["a"]) + self.qapp.processEvents() + + def testMultiExecution(self): + button = ThreadPoolPushButton() + button.setCallable(self._trace, "a", 0) + number = qt.QThreadPool.globalInstance().maxThreadCount() * 2 + for _ in range(number): + button.executeCallable() + time.sleep(number * 0.01 + 0.1) + self.assertListEqual(self._result, ["a"] * number) + self.qapp.processEvents() + + def testSaturateThreadPool(self): + button = ThreadPoolPushButton() + button.setCallable(self._trace, "a", 100) + number = qt.QThreadPool.globalInstance().maxThreadCount() * 2 + for _ in range(number): + button.executeCallable() + time.sleep(number * 0.1 + 0.1) + self.assertListEqual(self._result, ["a"] * number) + self.qapp.processEvents() + + def testSuccess(self): + listener = SignalListener() + button = ThreadPoolPushButton() + button.setCallable(self._compute) + button.beforeExecuting.connect(listener.partial(test="be")) + button.started.connect(listener.partial(test="s")) + button.succeeded.connect(listener.partial(test="result")) + button.failed.connect(listener.partial(test="Unexpected exception")) + button.finished.connect(listener.partial(test="f")) + button.executeCallable() + self.qapp.processEvents() + time.sleep(0.1) + self.qapp.processEvents() + result = listener.karguments(argumentName="test") + self.assertListEqual(result, ["be", "s", "result", "f"]) + + def testFail(self): + listener = SignalListener() + button = ThreadPoolPushButton() + button.setCallable(self._computeFail) + button.beforeExecuting.connect(listener.partial(test="be")) + button.started.connect(listener.partial(test="s")) + button.succeeded.connect(listener.partial(test="Unexpected success")) + button.failed.connect(listener.partial(test="exception")) + button.finished.connect(listener.partial(test="f")) + with TestLogging('silx.gui.widgets.ThreadPoolPushButton', error=1): + button.executeCallable() + self.qapp.processEvents() + time.sleep(0.1) + self.qapp.processEvents() + result = listener.karguments(argumentName="test") + self.assertListEqual(result, ["be", "s", "exception", "f"]) + listener.clear() + + +def suite(): + test_suite = unittest.TestSuite() + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestThreadPoolPushButton)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') |