summaryrefslogtreecommitdiff
path: root/silx/math/test/test_colormap.py
diff options
context:
space:
mode:
Diffstat (limited to 'silx/math/test/test_colormap.py')
-rw-r--r--silx/math/test/test_colormap.py88
1 files changed, 80 insertions, 8 deletions
diff --git a/silx/math/test/test_colormap.py b/silx/math/test/test_colormap.py
index cafe537..4e80710 100644
--- a/silx/math/test/test_colormap.py
+++ b/silx/math/test/test_colormap.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -34,7 +34,6 @@ __date__ = "16/05/2018"
import logging
import sys
import unittest
-import warnings
import numpy
@@ -45,10 +44,77 @@ from silx.math import colormap
_logger = logging.getLogger(__name__)
-class TestColormap(ParametricTestCase):
- """Test silx.image.colormap.cmap"""
+class TestNormalization(ParametricTestCase):
+ """Test silx.math.colormap.Normalization sub classes"""
+
+ def _testCodec(self, normalization, rtol=1e-5):
+ """Test apply/revert for normalizations"""
+ test_data = (numpy.arange(1, 10, dtype=numpy.int32),
+ numpy.linspace(1., 100., 1000, dtype=numpy.float32),
+ numpy.linspace(-1., 1., 100, dtype=numpy.float32),
+ 1.,
+ 1)
+
+ for index in range(len(test_data)):
+ with self.subTest(normalization=normalization, data_index=index):
+ data = test_data[index]
+ normalized = normalization.apply(data, 1., 100.)
+ result = normalization.revert(normalized, 1., 100.)
+
+ self.assertTrue(numpy.array_equal(
+ numpy.isnan(normalized), numpy.isnan(result)))
+
+ if isinstance(data, numpy.ndarray):
+ notNaN = numpy.logical_not(numpy.isnan(result))
+ data = data[notNaN]
+ result = result[notNaN]
+ self.assertTrue(numpy.allclose(data, result, rtol=rtol))
+
+ def testLinearNormalization(self):
+ """Test for LinearNormalization"""
+ normalization = colormap.LinearNormalization()
+ self._testCodec(normalization)
+
+ def testLogarithmicNormalization(self):
+ """Test for LogarithmicNormalization"""
+ normalization = colormap.LogarithmicNormalization()
+ # relative tolerance is higher because of the log approximation
+ self._testCodec(normalization, rtol=1e-3)
+
+ # Specific extra tests
+ self.assertTrue(numpy.isnan(normalization.apply(-1., 1., 100.)))
+ self.assertTrue(numpy.isnan(normalization.apply(numpy.nan, 1., 100.)))
+ self.assertEqual(normalization.apply(numpy.inf, 1., 100.), numpy.inf)
+ self.assertEqual(normalization.apply(0, 1., 100.), - numpy.inf)
+
+ def testArcsinhNormalization(self):
+ """Test for ArcsinhNormalization"""
+ self._testCodec(colormap.ArcsinhNormalization())
+
+ def testSqrtNormalization(self):
+ """Test for SqrtNormalization"""
+ normalization = colormap.SqrtNormalization()
+ self._testCodec(normalization)
+
+ # Specific extra tests
+ self.assertTrue(numpy.isnan(normalization.apply(-1., 0., 100.)))
+ self.assertTrue(numpy.isnan(normalization.apply(numpy.nan, 0., 100.)))
+ self.assertEqual(normalization.apply(numpy.inf, 0., 100.), numpy.inf)
+ self.assertEqual(normalization.apply(0, 0., 100.), 0.)
+
- NORMALIZATIONS = 'linear', 'log', 'arcsinh', 'sqrt'
+class TestColormap(ParametricTestCase):
+ """Test silx.math.colormap.cmap"""
+
+ NORMALIZATIONS = (
+ 'linear',
+ 'log',
+ 'arcsinh',
+ 'sqrt',
+ colormap.LinearNormalization(),
+ colormap.LogarithmicNormalization(),
+ colormap.PowerNormalization(2.),
+ colormap.PowerNormalization(0.5))
@staticmethod
def ref_colormap(data, colors, vmin, vmax, normalization, nan_color):
@@ -66,9 +132,13 @@ class TestColormap(ParametricTestCase):
'arcsinh': numpy.arcsinh,
'sqrt': numpy.sqrt}
- norm_function = norm_functions[normalization]
- with warnings.catch_warnings():
- warnings.simplefilter('ignore', category=RuntimeWarning)
+ if isinstance(normalization, str):
+ norm_function = norm_functions[normalization]
+ else:
+ def norm_function(value):
+ return normalization.apply(value, vmin, vmax)
+
+ with numpy.errstate(divide='ignore', invalid='ignore'):
# Ignore divide by zero and invalid value encountered in log10, sqrt
norm_data, vmin, vmax = map(norm_function, (data, vmin, vmax))
@@ -187,6 +257,8 @@ def suite():
test_suite = unittest.TestSuite()
test_suite.addTest(
unittest.defaultTestLoader.loadTestsFromTestCase(TestColormap))
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestNormalization))
return test_suite