diff options
Diffstat (limited to 'silx/opencl/sparse.py')
-rw-r--r-- | silx/opencl/sparse.py | 82 |
1 file changed, 64 insertions, 18 deletions
diff --git a/silx/opencl/sparse.py b/silx/opencl/sparse.py index 8bfaea8..514589a 100644 --- a/silx/opencl/sparse.py +++ b/silx/opencl/sparse.py @@ -35,6 +35,7 @@ import numpy import pyopencl.array as parray from collections import namedtuple from pyopencl.scan import GenericScanKernel +from pyopencl.tools import dtype_to_ctype from .common import pyopencl as cl from .processing import OpenclProcessing, EventDescription, BufferDescription mf = cl.mem_flags @@ -52,19 +53,20 @@ def tuple_to_csrdata(arrs): -# only float32 arrays are supported for now class CSR(OpenclProcessing): kernel_files = ["sparse.cl"] - def __init__(self, shape, max_nnz=None, ctx=None, devicetype="all", - platformid=None, deviceid=None, block_size=None, memory=None, - profile=False): + def __init__(self, shape, dtype="f", max_nnz=None, idx_dtype=numpy.int32, + ctx=None, devicetype="all", platformid=None, deviceid=None, + block_size=None, memory=None, profile=False): """ Compute Compressed Sparse Row format of an image (2D matrix). It is designed to be compatible with scipy.sparse.csr_matrix. :param shape: tuple Matrix shape. + :param dtype: str or numpy.dtype, optional + Numeric data type. By default, sparse matrix data will be float32. :param max_nnz: int, optional Maximum number of non-zero elements. 
By default, the arrays "data" and "indices" are allocated with prod(shape) elements, but @@ -80,8 +82,9 @@ class CSR(OpenclProcessing): OpenclProcessing.__init__(self, ctx=ctx, devicetype=devicetype, platformid=platformid, deviceid=deviceid, + block_size=block_size, memory=memory, profile=profile) - self._set_parameters(shape, max_nnz) + self._set_parameters(shape, dtype, max_nnz, idx_dtype) self._allocate_memory() self._setup_kernels() @@ -89,22 +92,47 @@ class CSR(OpenclProcessing): # -------------------------- Initialization -------------------------------- # -------------------------------------------------------------------------- - def _set_parameters(self, shape, max_nnz): + def _set_parameters(self, shape, dtype, max_nnz, idx_dtype): self.shape = shape self.size = numpy.prod(shape) - self.indice_dtype = numpy.int32 # + self._set_idx_dtype(idx_dtype) assert len(shape) == 2 # if max_nnz is None: self.max_nnz = numpy.prod(shape) # worst case else: self.max_nnz = int(max_nnz) + self._set_dtype(dtype) + + + def _set_idx_dtype(self, idx_dtype): + idx_dtype = numpy.dtype(idx_dtype) + if idx_dtype.kind not in ["i", "u"]: + raise ValueError("Not an integer type: %s" % idx_dtype) + # scan value type must have size divisible by 4 bytes + if idx_dtype.itemsize % 4 != 0: + raise ValueError("Due to an internal pyopencl limitation, idx_dtype type must have size divisible by 4 bytes") + self.indice_dtype = idx_dtype # + + + def _set_dtype(self, dtype): + self.dtype = numpy.dtype(dtype) + if self.dtype.kind == "c": + raise ValueError("Complex data is not supported") + if self.dtype == numpy.dtype(numpy.float32): + self._c_zero_str = "0.0f" + elif self.dtype == numpy.dtype(numpy.float64): + self._c_zero_str = "0.0" + else: # assuming integer + self._c_zero_str = "0" + self.c_dtype = dtype_to_ctype(self.dtype) + self.idx_c_dtype = dtype_to_ctype(self.indice_dtype) def _allocate_memory(self): self.is_cpu = (self.device.type == "CPU") # move to OpenclProcessing ? 
self.buffers = [ - BufferDescription("array", (self.size,), numpy.float32, mf.READ_ONLY), - BufferDescription("data", (self.max_nnz,), numpy.float32, mf.READ_WRITE), + BufferDescription("array", (self.size,), self.dtype, mf.READ_ONLY), + BufferDescription("data", (self.max_nnz,), self.dtype, mf.READ_WRITE), BufferDescription("indices", (self.max_nnz,), self.indice_dtype, mf.READ_WRITE), BufferDescription("indptr", (self.shape[0]+1,), self.indice_dtype, mf.READ_WRITE), ] @@ -124,10 +152,24 @@ class CSR(OpenclProcessing): def _setup_compaction_kernel(self): + kernel_signature = str( + "__global %s *data, \ + __global %s *data_compacted, \ + __global %s *indices, \ + __global %s* indptr \ + """ % (self.c_dtype, self.c_dtype, self.idx_c_dtype, self.idx_c_dtype) + ) + if self.dtype.kind == "f": + map_nonzero_expr = "(fabs(data[i]) > %s) ? 1 : 0" % self._c_zero_str + elif self.dtype.kind in ["u", "i"]: + map_nonzero_expr = "(data[i] != %s) ? 1 : 0" % self._c_zero_str + else: + raise ValueError("Unknown data type") + self.scan_kernel = GenericScanKernel( self.ctx, self.indice_dtype, - arguments="__global float* data, __global float *data_compacted, __global int *indices, __global int* indptr", - input_expr="(fabs(data[i]) > 0.0f) ? 
1 : 0", + arguments=kernel_signature, + input_expr=map_nonzero_expr, scan_expr="a+b", neutral="0", output_statement=""" // item is the running sum of input_expr(i), i.e the cumsum of "nonzero" @@ -140,7 +182,7 @@ class CSR(OpenclProcessing): indptr[(i/IMAGE_WIDTH)+1] = item; } """, - options="-DIMAGE_WIDTH=%d" % self.shape[1], + options=["-DIMAGE_WIDTH=%d" % self.shape[1]], preamble="#define GET_INDEX(i) (i % IMAGE_WIDTH)", ) @@ -149,7 +191,11 @@ class CSR(OpenclProcessing): OpenclProcessing.compile_kernels( self, self.kernel_files, - compile_options=["-DIMAGE_WIDTH=%d" % self.shape[1]] + compile_options=[ + "-DIMAGE_WIDTH=%d" % self.shape[1], + "-DDTYPE=%s" % self.c_dtype, + "-DIDX_DTYPE=%s" % self.idx_c_dtype, + ] ) device = self.ctx.devices[0] wg_x = min( @@ -174,7 +220,7 @@ class CSR(OpenclProcessing): 2D array in dense format. """ assert arr.size == self.size - assert arr.dtype == numpy.float32 + assert arr.dtype == self.dtype # TODO handle pyopencl Buffer @@ -189,10 +235,10 @@ class CSR(OpenclProcessing): assert isinstance(csr_data, CSRData) for arr in [csr_data.data, csr_data.indices, csr_data.indptr]: assert arr.ndim == 1 - assert csr_data.data.size == self.max_nnz - assert csr_data.indices.size == self.max_nnz + assert csr_data.data.size <= self.max_nnz + assert csr_data.indices.size <= self.max_nnz assert csr_data.indptr.size == self.shape[0]+1 - assert csr_data.data.dtype == numpy.float32 + assert csr_data.data.dtype == self.dtype assert csr_data.indices.dtype == self.indice_dtype assert csr_data.indptr.dtype == self.indice_dtype @@ -228,7 +274,7 @@ class CSR(OpenclProcessing): setattr(self, name, arr) # The current array is a numpy.ndarray: copy H2D elif isinstance(arr, numpy.ndarray): - getattr(self, name)[:] = arr[:] + getattr(self, name)[:arr.size] = arr[:] else: raise ValueError("Unsupported array type: %s" % type(arr)) |