diff options
Diffstat (limited to 'silx/opencl/test/test_sparse.py')
-rw-r--r-- | silx/opencl/test/test_sparse.py | 85 |
1 files changed, 48 insertions, 37 deletions
diff --git a/silx/opencl/test/test_sparse.py b/silx/opencl/test/test_sparse.py index 56f1ba4..76a6a0a 100644 --- a/silx/opencl/test/test_sparse.py +++ b/silx/opencl/test/test_sparse.py @@ -82,39 +82,45 @@ class TestCSR(unittest.TestCase): """Test CSR format""" def setUp(self): - self.array = generate_sparse_random_data(shape=(512, 511)) - # Compute reference sparsification - a_s = sp.csr_matrix(self.array) - self.ref_data = a_s.data - self.ref_indices = a_s.indices - self.ref_indptr = a_s.indptr - self.ref_nnz = a_s.nnz # Test possible configurations input_on_device = [False, True] output_on_device = [False, True] - self._test_configs = list(product(input_on_device, output_on_device)) + dtypes = [np.float32, np.int32, np.uint16] + self._test_configs = list(product(input_on_device, output_on_device, dtypes)) + + + def compute_ref_sparsification(self, array): + ref_sparse = sp.csr_matrix(array) + return ref_sparse def test_sparsification(self): - for input_on_device, output_on_device in self._test_configs: - self._test_sparsification(input_on_device, output_on_device) + for input_on_device, output_on_device, dtype in self._test_configs: + self._test_sparsification(input_on_device, output_on_device, dtype) - def _test_sparsification(self, input_on_device, output_on_device): - current_config = "input on device: %s, output on device: %s" % ( - str(input_on_device), str(output_on_device) + def _test_sparsification(self, input_on_device, output_on_device, dtype): + current_config = "input on device: %s, output on device: %s, dtype: %s" % ( + str(input_on_device), str(output_on_device), str(dtype) ) + logger.debug("CSR: %s" % current_config) + # Generate data and reference CSR + array = generate_sparse_random_data(shape=(512, 511), dtype=dtype) + ref_sparse = self.compute_ref_sparsification(array) # Sparsify on device - csr = CSR(self.array.shape) + csr = CSR(array.shape, dtype=dtype) if input_on_device: # The array has to be flattened - arr = parray.to_device(csr.queue, self.array.ravel()) + arr = parray.to_device(csr.queue, array.ravel()) else: - arr = self.array + arr = array if output_on_device: - d_data = parray.zeros_like(csr.data) - d_indices = parray.zeros_like(csr.indices) - d_indptr = parray.zeros_like(csr.indptr) + d_data = parray.empty_like(csr.data) + d_indices = parray.empty_like(csr.indices) + d_indptr = parray.empty_like(csr.indptr) + d_data.fill(0) + d_indices.fill(0) + d_indptr.fill(0) output = (d_data, d_indices, d_indptr) else: output = None @@ -124,45 +130,50 @@ class TestCSR(unittest.TestCase): indices = indices.get() indptr = indptr.get() # Compare - nnz = self.ref_nnz + nnz = ref_sparse.nnz self.assertTrue( - np.allclose(data[:nnz], self.ref_data), + np.allclose(data[:nnz], ref_sparse.data), "something wrong with sparsified data (%s)" % current_config ) self.assertTrue( - np.allclose(indices[:nnz], self.ref_indices), + np.allclose(indices[:nnz], ref_sparse.indices), "something wrong with sparsified indices (%s)" % current_config ) self.assertTrue( - np.allclose(indptr, self.ref_indptr), + np.allclose(indptr, ref_sparse.indptr), "something wrong with sparsified indices pointers (indptr) (%s)" % current_config ) def test_desparsification(self): - for input_on_device, output_on_device in self._test_configs: - self._test_desparsification(input_on_device, output_on_device) + for input_on_device, output_on_device, dtype in self._test_configs: + self._test_desparsification(input_on_device, output_on_device, dtype) - def _test_desparsification(self, input_on_device, output_on_device): - current_config = "input on device: %s, output on device: %s" % ( - str(input_on_device), str(output_on_device) + def _test_desparsification(self, input_on_device, output_on_device, dtype): + current_config = "input on device: %s, output on device: %s, dtype: %s" % ( + str(input_on_device), str(output_on_device), str(dtype) ) + logger.debug("CSR: %s" % current_config) + # Generate data and reference CSR + array = generate_sparse_random_data(shape=(512, 511), dtype=dtype) + ref_sparse = self.compute_ref_sparsification(array) # De-sparsify on device - csr = CSR(self.array.shape, max_nnz=self.ref_nnz) + csr = CSR(array.shape, dtype=dtype, max_nnz=ref_sparse.nnz) if input_on_device: - data = parray.to_device(csr.queue, self.ref_data) - indices = parray.to_device(csr.queue, self.ref_indices) - indptr = parray.to_device(csr.queue, self.ref_indptr) + data = parray.to_device(csr.queue, ref_sparse.data) + indices = parray.to_device(csr.queue, ref_sparse.indices) + indptr = parray.to_device(csr.queue, ref_sparse.indptr) else: - data = self.ref_data - indices = self.ref_indices - indptr = self.ref_indptr + data = ref_sparse.data + indices = ref_sparse.indices + indptr = ref_sparse.indptr if output_on_device: - d_arr = parray.zeros_like(csr.array) + d_arr = parray.empty_like(csr.array) + d_arr.fill(0) output = d_arr else: output = None @@ -171,7 +182,7 @@ class TestCSR(unittest.TestCase): arr = arr.get() # Compare self.assertTrue( - np.allclose(arr.reshape(self.array.shape), self.array), + np.allclose(arr.reshape(array.shape), array), "something wrong with densified data (%s)" % current_config ) |