summary | refs | log | tree | commit | diff
path: root/silx/opencl/sparse.py
diff options
context:
space:
mode:
Diffstat (limited to 'silx/opencl/sparse.py')
-rw-r--r-- silx/opencl/sparse.py 82
1 files changed, 64 insertions, 18 deletions
diff --git a/silx/opencl/sparse.py b/silx/opencl/sparse.py
index 8bfaea8..514589a 100644
--- a/silx/opencl/sparse.py
+++ b/silx/opencl/sparse.py
@@ -35,6 +35,7 @@ import numpy
import pyopencl.array as parray
from collections import namedtuple
from pyopencl.scan import GenericScanKernel
+from pyopencl.tools import dtype_to_ctype
from .common import pyopencl as cl
from .processing import OpenclProcessing, EventDescription, BufferDescription
mf = cl.mem_flags
@@ -52,19 +53,20 @@ def tuple_to_csrdata(arrs):
-# only float32 arrays are supported for now
class CSR(OpenclProcessing):
kernel_files = ["sparse.cl"]
- def __init__(self, shape, max_nnz=None, ctx=None, devicetype="all",
- platformid=None, deviceid=None, block_size=None, memory=None,
- profile=False):
+ def __init__(self, shape, dtype="f", max_nnz=None, idx_dtype=numpy.int32,
+ ctx=None, devicetype="all", platformid=None, deviceid=None,
+ block_size=None, memory=None, profile=False):
"""
Compute Compressed Sparse Row format of an image (2D matrix).
It is designed to be compatible with scipy.sparse.csr_matrix.
:param shape: tuple
Matrix shape.
+ :param dtype: str or numpy.dtype, optional
+ Numeric data type. By default, sparse matrix data will be float32.
:param max_nnz: int, optional
Maximum number of non-zero elements. By default, the arrays "data"
and "indices" are allocated with prod(shape) elements, but
@@ -80,8 +82,9 @@ class CSR(OpenclProcessing):
OpenclProcessing.__init__(self, ctx=ctx, devicetype=devicetype,
platformid=platformid, deviceid=deviceid,
+ block_size=block_size, memory=memory,
profile=profile)
- self._set_parameters(shape, max_nnz)
+ self._set_parameters(shape, dtype, max_nnz, idx_dtype)
self._allocate_memory()
self._setup_kernels()
@@ -89,22 +92,47 @@ class CSR(OpenclProcessing):
# -------------------------- Initialization --------------------------------
# --------------------------------------------------------------------------
- def _set_parameters(self, shape, max_nnz):
+ def _set_parameters(self, shape, dtype, max_nnz, idx_dtype):
+ # Store the matrix geometry (shape, size, max_nnz) on the instance and
+ # delegate validation of the index and value types to _set_idx_dtype()
+ # and _set_dtype().
+ # Call order matters: _set_idx_dtype() has to run before _set_dtype(),
+ # because _set_dtype() reads self.indice_dtype to build idx_c_dtype.
self.shape = shape
self.size = numpy.prod(shape)
- self.indice_dtype = numpy.int32 #
+ self._set_idx_dtype(idx_dtype)
+ # Only 2D shapes are accepted.
+ # NOTE(review): this check runs after shape has already been used above,
+ # and it is an assert, so it is skipped when Python runs with -O.
assert len(shape) == 2 #
+ # Without an explicit max_nnz, size for the case where every element of
+ # the matrix is non-zero.
+ # NOTE(review): this branch stores a numpy scalar while the other stores
+ # a Python int - confirm that downstream code accepts both.
if max_nnz is None:
self.max_nnz = numpy.prod(shape) # worst case
else:
self.max_nnz = int(max_nnz)
+ self._set_dtype(dtype)
+
+
+ def _set_idx_dtype(self, idx_dtype):
+ # Validate the data type used for the "indices" and "indptr" arrays and
+ # store it, normalized to a numpy.dtype, in self.indice_dtype.
+ # Raises ValueError when idx_dtype is not a signed or unsigned integer
+ # type, or when its item size is not a multiple of 4 bytes (so 8-bit and
+ # 16-bit integers are rejected).
+ idx_dtype = numpy.dtype(idx_dtype)
+ if idx_dtype.kind not in ["i", "u"]:
+ raise ValueError("Not an integer type: %s" % idx_dtype)
+ # scan value type must have size divisible by 4 bytes
+ # (the same dtype is passed to GenericScanKernel as the scan type)
+ if idx_dtype.itemsize % 4 != 0:
+ raise ValueError("Due to an internal pyopencl limitation, idx_dtype type must have size divisible by 4 bytes")
+ self.indice_dtype = idx_dtype #
+
+
+ def _set_dtype(self, dtype):
+ # Record the data type of the matrix values in self.dtype and derive the
+ # strings used when generating OpenCL code:
+ #   self._c_zero_str : C literal for zero, used in the non-zero test of
+ #                      the compaction kernel
+ #   self.c_dtype     : C type name of the values
+ #   self.idx_c_dtype : C type name of the indices
+ # Raises ValueError for complex dtypes.
+ # Must be called after _set_idx_dtype(): self.indice_dtype is read below.
+ self.dtype = numpy.dtype(dtype)
+ if self.dtype.kind == "c":
+ raise ValueError("Complex data is not supported")
+ if self.dtype == numpy.dtype(numpy.float32):
+ self._c_zero_str = "0.0f"
+ elif self.dtype == numpy.dtype(numpy.float64):
+ self._c_zero_str = "0.0"
+ # NOTE(review): every other non-complex dtype lands in the branch below,
+ # including float16 and non-numeric kinds such as bool. Kinds other than
+ # "f", "i" and "u" are only rejected later, in _setup_compaction_kernel(),
+ # with "Unknown data type".
+ else: # assuming integer
+ self._c_zero_str = "0"
+ self.c_dtype = dtype_to_ctype(self.dtype)
+ self.idx_c_dtype = dtype_to_ctype(self.indice_dtype)
def _allocate_memory(self):
self.is_cpu = (self.device.type == "CPU") # move to OpenclProcessing ?
self.buffers = [
- BufferDescription("array", (self.size,), numpy.float32, mf.READ_ONLY),
- BufferDescription("data", (self.max_nnz,), numpy.float32, mf.READ_WRITE),
+ BufferDescription("array", (self.size,), self.dtype, mf.READ_ONLY),
+ BufferDescription("data", (self.max_nnz,), self.dtype, mf.READ_WRITE),
BufferDescription("indices", (self.max_nnz,), self.indice_dtype, mf.READ_WRITE),
BufferDescription("indptr", (self.shape[0]+1,), self.indice_dtype, mf.READ_WRITE),
]
@@ -124,10 +152,24 @@ class CSR(OpenclProcessing):
def _setup_compaction_kernel(self):
+ kernel_signature = str(
+ "__global %s *data, \
+ __global %s *data_compacted, \
+ __global %s *indices, \
+ __global %s* indptr \
+ """ % (self.c_dtype, self.c_dtype, self.idx_c_dtype, self.idx_c_dtype)
+ )
+ if self.dtype.kind == "f":
+ map_nonzero_expr = "(fabs(data[i]) > %s) ? 1 : 0" % self._c_zero_str
+ elif self.dtype.kind in ["u", "i"]:
+ map_nonzero_expr = "(data[i] != %s) ? 1 : 0" % self._c_zero_str
+ else:
+ raise ValueError("Unknown data type")
+
self.scan_kernel = GenericScanKernel(
self.ctx, self.indice_dtype,
- arguments="__global float* data, __global float *data_compacted, __global int *indices, __global int* indptr",
- input_expr="(fabs(data[i]) > 0.0f) ? 1 : 0",
+ arguments=kernel_signature,
+ input_expr=map_nonzero_expr,
scan_expr="a+b", neutral="0",
output_statement="""
// item is the running sum of input_expr(i), i.e the cumsum of "nonzero"
@@ -140,7 +182,7 @@ class CSR(OpenclProcessing):
indptr[(i/IMAGE_WIDTH)+1] = item;
}
""",
- options="-DIMAGE_WIDTH=%d" % self.shape[1],
+ options=["-DIMAGE_WIDTH=%d" % self.shape[1]],
preamble="#define GET_INDEX(i) (i % IMAGE_WIDTH)",
)
@@ -149,7 +191,11 @@ class CSR(OpenclProcessing):
OpenclProcessing.compile_kernels(
self,
self.kernel_files,
- compile_options=["-DIMAGE_WIDTH=%d" % self.shape[1]]
+ compile_options=[
+ "-DIMAGE_WIDTH=%d" % self.shape[1],
+ "-DDTYPE=%s" % self.c_dtype,
+ "-DIDX_DTYPE=%s" % self.idx_c_dtype,
+ ]
)
device = self.ctx.devices[0]
wg_x = min(
@@ -174,7 +220,7 @@ class CSR(OpenclProcessing):
2D array in dense format.
"""
assert arr.size == self.size
- assert arr.dtype == numpy.float32
+ assert arr.dtype == self.dtype
# TODO handle pyopencl Buffer
@@ -189,10 +235,10 @@ class CSR(OpenclProcessing):
assert isinstance(csr_data, CSRData)
for arr in [csr_data.data, csr_data.indices, csr_data.indptr]:
assert arr.ndim == 1
- assert csr_data.data.size == self.max_nnz
- assert csr_data.indices.size == self.max_nnz
+ assert csr_data.data.size <= self.max_nnz
+ assert csr_data.indices.size <= self.max_nnz
assert csr_data.indptr.size == self.shape[0]+1
- assert csr_data.data.dtype == numpy.float32
+ assert csr_data.data.dtype == self.dtype
assert csr_data.indices.dtype == self.indice_dtype
assert csr_data.indptr.dtype == self.indice_dtype
@@ -228,7 +274,7 @@ class CSR(OpenclProcessing):
setattr(self, name, arr)
# The current array is a numpy.ndarray: copy H2D
elif isinstance(arr, numpy.ndarray):
- getattr(self, name)[:] = arr[:]
+ getattr(self, name)[:arr.size] = arr[:]
else:
raise ValueError("Unsupported array type: %s" % type(arr))