diff options
Diffstat (limited to 'src/silx/math/test/test_interpolate.py')
-rw-r--r-- | src/silx/math/test/test_interpolate.py | 125 |
1 files changed, 125 insertions, 0 deletions
diff --git a/src/silx/math/test/test_interpolate.py b/src/silx/math/test/test_interpolate.py new file mode 100644 index 0000000..146449d --- /dev/null +++ b/src/silx/math/test/test_interpolate.py @@ -0,0 +1,125 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2019 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +"""Test for interpolate module""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "11/07/2019" + + +import unittest + +import numpy +try: + from scipy.interpolate import interpn +except ImportError: + interpn = None + +from silx.utils.testutils import ParametricTestCase +from silx.math import interpolate + + +@unittest.skipUnless(interpn is not None, "scipy missing") +class TestInterp3d(ParametricTestCase): + """Test silx.math.interpolate.interp3d""" + + @staticmethod + def ref_interp3d(data, points): + """Reference implementation of interp3d based on scipy + + :param numpy.ndarray data: 3D floating dataset + :param numpy.ndarray points: Array of points of shape (N, 3) + """ + return interpn( + [numpy.arange(dim, dtype=data.dtype) for dim in data.shape], + data, + points, + method='linear') + + def test_random_data(self): + """Test interp3d with random data""" + size = 32 + npoints = 10 + + ref_data = numpy.random.random((size, size, size)) + ref_points = numpy.random.random(npoints*3).reshape(npoints, 3) * (size -1) + + for dtype in (numpy.float32, numpy.float64): + data = ref_data.astype(dtype) + points = ref_points.astype(dtype) + ref_result = self.ref_interp3d(data, points) + + for method in (u'linear', u'linear_omp'): + with self.subTest(method=method): + result = interpolate.interp3d(data, points, method=method) + self.assertTrue(numpy.allclose(ref_result, result)) + + def test_notfinite_data(self): + """Test interp3d with NaN and inf""" + data = numpy.ones((3, 3, 3), dtype=numpy.float64) + data[0, 0, 0] = numpy.nan + data[2, 2, 2] = numpy.inf + points = numpy.array([(0.5, 0.5, 0.5), + (1.5, 1.5, 1.5)]) + + for method in (u'linear', u'linear_omp'): + with self.subTest(method=method): + result = interpolate.interp3d( + data, points, method=method) + self.assertTrue(numpy.isnan(result[0])) + self.assertTrue(result[1] == numpy.inf) + + def test_points_outside(self): + """Test interp3d with points outside the volume""" + data = numpy.ones((4, 4, 4), dtype=numpy.float64) + points = numpy.array([(-0.1, -0.1, -0.1), + (3.1, 3.1, 3.1), + (-0.1, 1., 1.), + (1., 1., 3.1)]) + + for method in (u'linear', u'linear_omp'): + for fill_value in (numpy.nan, 0., -1.): + with self.subTest(method=method): + result = interpolate.interp3d( + data, points, method=method, fill_value=fill_value) + if numpy.isnan(fill_value): + self.assertTrue(numpy.all(numpy.isnan(result))) + else: + self.assertTrue(numpy.all(numpy.equal(result, fill_value))) + + def test_integer_points(self): + """Test interp3d with integer points coord""" + data = numpy.arange(4**3, dtype=numpy.float64).reshape(4, 4, 4) + points = numpy.array([(0., 0., 0.), + (0., 0., 1.), + (2., 3., 0.), + (3., 3., 3.)]) + + ref_result = data[tuple(points.T.astype(numpy.int32))] + + for method in (u'linear', u'linear_omp'): + with self.subTest(method=method): + result = interpolate.interp3d(data, points, method=method) + self.assertTrue(numpy.allclose(ref_result, result)) |