summaryrefslogtreecommitdiff
path: root/silx/opencl/test/test_sparse.py
diff options
context:
space:
mode:
Diffstat (limited to 'silx/opencl/test/test_sparse.py')
-rw-r--r--silx/opencl/test/test_sparse.py85
1 files changed, 48 insertions, 37 deletions
diff --git a/silx/opencl/test/test_sparse.py b/silx/opencl/test/test_sparse.py
index 56f1ba4..76a6a0a 100644
--- a/silx/opencl/test/test_sparse.py
+++ b/silx/opencl/test/test_sparse.py
@@ -82,39 +82,45 @@ class TestCSR(unittest.TestCase):
"""Test CSR format"""
def setUp(self):
- self.array = generate_sparse_random_data(shape=(512, 511))
- # Compute reference sparsification
- a_s = sp.csr_matrix(self.array)
- self.ref_data = a_s.data
- self.ref_indices = a_s.indices
- self.ref_indptr = a_s.indptr
- self.ref_nnz = a_s.nnz
# Test possible configurations
input_on_device = [False, True]
output_on_device = [False, True]
- self._test_configs = list(product(input_on_device, output_on_device))
+ dtypes = [np.float32, np.int32, np.uint16]
+ self._test_configs = list(product(input_on_device, output_on_device, dtypes))
+
+
+ def compute_ref_sparsification(self, array):
+ ref_sparse = sp.csr_matrix(array)
+ return ref_sparse
def test_sparsification(self):
- for input_on_device, output_on_device in self._test_configs:
- self._test_sparsification(input_on_device, output_on_device)
+ for input_on_device, output_on_device, dtype in self._test_configs:
+ self._test_sparsification(input_on_device, output_on_device, dtype)
- def _test_sparsification(self, input_on_device, output_on_device):
- current_config = "input on device: %s, output on device: %s" % (
- str(input_on_device), str(output_on_device)
+ def _test_sparsification(self, input_on_device, output_on_device, dtype):
+ current_config = "input on device: %s, output on device: %s, dtype: %s" % (
+ str(input_on_device), str(output_on_device), str(dtype)
)
+ logger.debug("CSR: %s" % current_config)
+ # Generate data and reference CSR
+ array = generate_sparse_random_data(shape=(512, 511), dtype=dtype)
+ ref_sparse = self.compute_ref_sparsification(array)
# Sparsify on device
- csr = CSR(self.array.shape)
+ csr = CSR(array.shape, dtype=dtype)
if input_on_device:
# The array has to be flattened
- arr = parray.to_device(csr.queue, self.array.ravel())
+ arr = parray.to_device(csr.queue, array.ravel())
else:
- arr = self.array
+ arr = array
if output_on_device:
- d_data = parray.zeros_like(csr.data)
- d_indices = parray.zeros_like(csr.indices)
- d_indptr = parray.zeros_like(csr.indptr)
+ d_data = parray.empty_like(csr.data)
+ d_indices = parray.empty_like(csr.indices)
+ d_indptr = parray.empty_like(csr.indptr)
+ d_data.fill(0)
+ d_indices.fill(0)
+ d_indptr.fill(0)
output = (d_data, d_indices, d_indptr)
else:
output = None
@@ -124,45 +130,50 @@ class TestCSR(unittest.TestCase):
indices = indices.get()
indptr = indptr.get()
# Compare
- nnz = self.ref_nnz
+ nnz = ref_sparse.nnz
self.assertTrue(
- np.allclose(data[:nnz], self.ref_data),
+ np.allclose(data[:nnz], ref_sparse.data),
"something wrong with sparsified data (%s)"
% current_config
)
self.assertTrue(
- np.allclose(indices[:nnz], self.ref_indices),
+ np.allclose(indices[:nnz], ref_sparse.indices),
"something wrong with sparsified indices (%s)"
% current_config
)
self.assertTrue(
- np.allclose(indptr, self.ref_indptr),
+ np.allclose(indptr, ref_sparse.indptr),
"something wrong with sparsified indices pointers (indptr) (%s)"
% current_config
)
def test_desparsification(self):
- for input_on_device, output_on_device in self._test_configs:
- self._test_desparsification(input_on_device, output_on_device)
+ for input_on_device, output_on_device, dtype in self._test_configs:
+ self._test_desparsification(input_on_device, output_on_device, dtype)
- def _test_desparsification(self, input_on_device, output_on_device):
- current_config = "input on device: %s, output on device: %s" % (
- str(input_on_device), str(output_on_device)
+ def _test_desparsification(self, input_on_device, output_on_device, dtype):
+ current_config = "input on device: %s, output on device: %s, dtype: %s" % (
+ str(input_on_device), str(output_on_device), str(dtype)
)
+ logger.debug("CSR: %s" % current_config)
+ # Generate data and reference CSR
+ array = generate_sparse_random_data(shape=(512, 511), dtype=dtype)
+ ref_sparse = self.compute_ref_sparsification(array)
# De-sparsify on device
- csr = CSR(self.array.shape, max_nnz=self.ref_nnz)
+ csr = CSR(array.shape, dtype=dtype, max_nnz=ref_sparse.nnz)
if input_on_device:
- data = parray.to_device(csr.queue, self.ref_data)
- indices = parray.to_device(csr.queue, self.ref_indices)
- indptr = parray.to_device(csr.queue, self.ref_indptr)
+ data = parray.to_device(csr.queue, ref_sparse.data)
+ indices = parray.to_device(csr.queue, ref_sparse.indices)
+ indptr = parray.to_device(csr.queue, ref_sparse.indptr)
else:
- data = self.ref_data
- indices = self.ref_indices
- indptr = self.ref_indptr
+ data = ref_sparse.data
+ indices = ref_sparse.indices
+ indptr = ref_sparse.indptr
if output_on_device:
- d_arr = parray.zeros_like(csr.array)
+ d_arr = parray.empty_like(csr.array)
+ d_arr.fill(0)
output = d_arr
else:
output = None
@@ -171,7 +182,7 @@ class TestCSR(unittest.TestCase):
arr = arr.get()
# Compare
self.assertTrue(
- np.allclose(arr.reshape(self.array.shape), self.array),
+ np.allclose(arr.reshape(array.shape), array),
"something wrong with densified data (%s)"
% current_config
)