summaryrefslogtreecommitdiff
path: root/src/silx/opencl/test/test_sparse.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/silx/opencl/test/test_sparse.py')
-rw-r--r--src/silx/opencl/test/test_sparse.py188
1 files changed, 188 insertions, 0 deletions
diff --git a/src/silx/opencl/test/test_sparse.py b/src/silx/opencl/test/test_sparse.py
new file mode 100644
index 0000000..1d26b36
--- /dev/null
+++ b/src/silx/opencl/test/test_sparse.py
@@ -0,0 +1,188 @@
+#!/usr/bin/env python
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test of the sparse module"""
+
+import numpy as np
+import unittest
+import logging
+from itertools import product
+from ..common import ocl
+if ocl:
+ import pyopencl.array as parray
+ from silx.opencl.sparse import CSR
+try:
+ import scipy.sparse as sp
+except ImportError:
+ sp = None
+logger = logging.getLogger(__name__)
+
+
+
+def generate_sparse_random_data(
+ shape=(1000,),
+ data_min=0, data_max=100,
+ density=0.1,
+ use_only_integers=True,
+ dtype="f"):
+ """
+ Generate random sparse data where.
+
+ Parameters
+ ------------
+ shape: tuple
+ Output data shape.
+ data_min: int or float
+ Minimum value of data
+ data_max: int or float
+ Maximum value of data
+ density: float
+ Density of non-zero elements in the output data.
+ Low value of density mean low number of non-zero elements.
+ use_only_integers: bool
+ If set to True, the output data items will be primarily integers,
+ possibly casted to float if dtype is a floating-point type.
+ This can be used for ease of debugging.
+ dtype: str or numpy.dtype
+ Output data type
+ """
+ mask = np.random.binomial(1, density, size=shape)
+ if use_only_integers:
+ d = np.random.randint(data_min, high=data_max, size=shape)
+ else:
+ d = data_min + (data_max - data_min) * np.random.rand(*shape)
+ return (d * mask).astype(dtype)
+
+
+
+@unittest.skipUnless(ocl and sp, "PyOpenCl/scipy is missing")
+class TestCSR(unittest.TestCase):
+ """Test CSR format"""
+
+ def setUp(self):
+ # Test possible configurations
+ input_on_device = [False, True]
+ output_on_device = [False, True]
+ dtypes = [np.float32, np.int32, np.uint16]
+ self._test_configs = list(product(input_on_device, output_on_device, dtypes))
+
+
+ def compute_ref_sparsification(self, array):
+ ref_sparse = sp.csr_matrix(array)
+ return ref_sparse
+
+
+ def test_sparsification(self):
+ for input_on_device, output_on_device, dtype in self._test_configs:
+ self._test_sparsification(input_on_device, output_on_device, dtype)
+
+
+ def _test_sparsification(self, input_on_device, output_on_device, dtype):
+ current_config = "input on device: %s, output on device: %s, dtype: %s" % (
+ str(input_on_device), str(output_on_device), str(dtype)
+ )
+ logger.debug("CSR: %s" % current_config)
+ # Generate data and reference CSR
+ array = generate_sparse_random_data(shape=(512, 511), dtype=dtype)
+ ref_sparse = self.compute_ref_sparsification(array)
+ # Sparsify on device
+ csr = CSR(array.shape, dtype=dtype)
+ if input_on_device:
+ # The array has to be flattened
+ arr = parray.to_device(csr.queue, array.ravel())
+ else:
+ arr = array
+ if output_on_device:
+ d_data = parray.empty_like(csr.data)
+ d_indices = parray.empty_like(csr.indices)
+ d_indptr = parray.empty_like(csr.indptr)
+ d_data.fill(0)
+ d_indices.fill(0)
+ d_indptr.fill(0)
+ output = (d_data, d_indices, d_indptr)
+ else:
+ output = None
+ data, indices, indptr = csr.sparsify(arr, output=output)
+ if output_on_device:
+ data = data.get()
+ indices = indices.get()
+ indptr = indptr.get()
+ # Compare
+ nnz = ref_sparse.nnz
+ self.assertTrue(
+ np.allclose(data[:nnz], ref_sparse.data),
+ "something wrong with sparsified data (%s)"
+ % current_config
+ )
+ self.assertTrue(
+ np.allclose(indices[:nnz], ref_sparse.indices),
+ "something wrong with sparsified indices (%s)"
+ % current_config
+ )
+ self.assertTrue(
+ np.allclose(indptr, ref_sparse.indptr),
+ "something wrong with sparsified indices pointers (indptr) (%s)"
+ % current_config
+ )
+
+
+ def test_desparsification(self):
+ for input_on_device, output_on_device, dtype in self._test_configs:
+ self._test_desparsification(input_on_device, output_on_device, dtype)
+
+
+ def _test_desparsification(self, input_on_device, output_on_device, dtype):
+ current_config = "input on device: %s, output on device: %s, dtype: %s" % (
+ str(input_on_device), str(output_on_device), str(dtype)
+ )
+ logger.debug("CSR: %s" % current_config)
+ # Generate data and reference CSR
+ array = generate_sparse_random_data(shape=(512, 511), dtype=dtype)
+ ref_sparse = self.compute_ref_sparsification(array)
+ # De-sparsify on device
+ csr = CSR(array.shape, dtype=dtype, max_nnz=ref_sparse.nnz)
+ if input_on_device:
+ data = parray.to_device(csr.queue, ref_sparse.data)
+ indices = parray.to_device(csr.queue, ref_sparse.indices)
+ indptr = parray.to_device(csr.queue, ref_sparse.indptr)
+ else:
+ data = ref_sparse.data
+ indices = ref_sparse.indices
+ indptr = ref_sparse.indptr
+ if output_on_device:
+ d_arr = parray.empty_like(csr.array)
+ d_arr.fill(0)
+ output = d_arr
+ else:
+ output = None
+ arr = csr.densify(data, indices, indptr, output=output)
+ if output_on_device:
+ arr = arr.get()
+ # Compare
+ self.assertTrue(
+ np.allclose(arr.reshape(array.shape), array),
+ "something wrong with densified data (%s)"
+ % current_config
+ )