diff options
Diffstat (limited to 'silx/math/test/test_colormap.py')
-rw-r--r-- | silx/math/test/test_colormap.py | 88 |
1 files changed, 80 insertions, 8 deletions
diff --git a/silx/math/test/test_colormap.py b/silx/math/test/test_colormap.py index cafe537..4e80710 100644 --- a/silx/math/test/test_colormap.py +++ b/silx/math/test/test_colormap.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2018-2019 European Synchrotron Radiation Facility +# Copyright (c) 2018-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -34,7 +34,6 @@ __date__ = "16/05/2018" import logging import sys import unittest -import warnings import numpy @@ -45,10 +44,77 @@ from silx.math import colormap _logger = logging.getLogger(__name__) -class TestColormap(ParametricTestCase): - """Test silx.image.colormap.cmap""" +class TestNormalization(ParametricTestCase): + """Test silx.math.colormap.Normalization sub classes""" + + def _testCodec(self, normalization, rtol=1e-5): + """Test apply/revert for normalizations""" + test_data = (numpy.arange(1, 10, dtype=numpy.int32), + numpy.linspace(1., 100., 1000, dtype=numpy.float32), + numpy.linspace(-1., 1., 100, dtype=numpy.float32), + 1., + 1) + + for index in range(len(test_data)): + with self.subTest(normalization=normalization, data_index=index): + data = test_data[index] + normalized = normalization.apply(data, 1., 100.) + result = normalization.revert(normalized, 1., 100.) + + self.assertTrue(numpy.array_equal( + numpy.isnan(normalized), numpy.isnan(result))) + + if isinstance(data, numpy.ndarray): + notNaN = numpy.logical_not(numpy.isnan(result)) + data = data[notNaN] + result = result[notNaN] + self.assertTrue(numpy.allclose(data, result, rtol=rtol)) + + def testLinearNormalization(self): + """Test for LinearNormalization""" + normalization = colormap.LinearNormalization() + self._testCodec(normalization) + + def testLogarithmicNormalization(self): + """Test for LogarithmicNormalization""" + normalization = colormap.LogarithmicNormalization() + # relative tolerance is higher because of the log approximation + self._testCodec(normalization, rtol=1e-3) + + # Specific extra tests + self.assertTrue(numpy.isnan(normalization.apply(-1., 1., 100.))) + self.assertTrue(numpy.isnan(normalization.apply(numpy.nan, 1., 100.))) + self.assertEqual(normalization.apply(numpy.inf, 1., 100.), numpy.inf) + self.assertEqual(normalization.apply(0, 1., 100.), - numpy.inf) + + def testArcsinhNormalization(self): + """Test for ArcsinhNormalization""" + self._testCodec(colormap.ArcsinhNormalization()) + + def testSqrtNormalization(self): + """Test for SqrtNormalization""" + normalization = colormap.SqrtNormalization() + self._testCodec(normalization) + + # Specific extra tests + self.assertTrue(numpy.isnan(normalization.apply(-1., 0., 100.))) + self.assertTrue(numpy.isnan(normalization.apply(numpy.nan, 0., 100.))) + self.assertEqual(normalization.apply(numpy.inf, 0., 100.), numpy.inf) + self.assertEqual(normalization.apply(0, 0., 100.), 0.) + - NORMALIZATIONS = 'linear', 'log', 'arcsinh', 'sqrt' +class TestColormap(ParametricTestCase): + """Test silx.math.colormap.cmap""" + + NORMALIZATIONS = ( + 'linear', + 'log', + 'arcsinh', + 'sqrt', + colormap.LinearNormalization(), + colormap.LogarithmicNormalization(), + colormap.PowerNormalization(2.), + colormap.PowerNormalization(0.5)) @staticmethod def ref_colormap(data, colors, vmin, vmax, normalization, nan_color): @@ -66,9 +132,13 @@ class TestColormap(ParametricTestCase): 'arcsinh': numpy.arcsinh, 'sqrt': numpy.sqrt} - norm_function = norm_functions[normalization] - with warnings.catch_warnings(): - warnings.simplefilter('ignore', category=RuntimeWarning) + if isinstance(normalization, str): + norm_function = norm_functions[normalization] + else: + def norm_function(value): + return normalization.apply(value, vmin, vmax) + + with numpy.errstate(divide='ignore', invalid='ignore'): # Ignore divide by zero and invalid value encountered in log10, sqrt norm_data, vmin, vmax = map(norm_function, (data, vmin, vmax)) @@ -187,6 +257,8 @@ def suite(): test_suite = unittest.TestSuite() test_suite.addTest( unittest.defaultTestLoader.loadTestsFromTestCase(TestColormap)) + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestNormalization)) return test_suite |