summaryrefslogtreecommitdiff
path: root/silx/math/fft/cufft.py
diff options
context:
space:
mode:
Diffstat (limited to 'silx/math/fft/cufft.py')
-rw-r--r--silx/math/fft/cufft.py253
1 files changed, 253 insertions, 0 deletions
diff --git a/silx/math/fft/cufft.py b/silx/math/fft/cufft.py
new file mode 100644
index 0000000..848f3e6
--- /dev/null
+++ b/silx/math/fft/cufft.py
@@ -0,0 +1,253 @@
+#!/usr/bin/env python
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+import numpy as np
+
+from .basefft import BaseFFT
+try:
+ import pycuda.gpuarray as gpuarray
+ from skcuda.fft import Plan
+ from skcuda.fft import fft as cu_fft
+ from skcuda.fft import ifft as cu_ifft
+ __have_cufft__ = True
+except ImportError:
+ __have_cufft__ = False
+
+
+class CUFFT(BaseFFT):
+ """Initialize a cufft plan
+
+ Please see FFT class for parameters help.
+
+ CUFFT-specific parameters
+ --------------------------
+
+ :param pycuda.driver.Stream stream:
+ Stream with which to associate the plan. If no stream is specified,
+ the default stream is used.
+ """
+ def __init__(
+ self,
+ shape=None,
+ dtype=None,
+ template=None,
+ shape_out=None,
+ axes=None,
+ normalize="rescale",
+ stream=None,
+ ):
+ if not(__have_cufft__) or not(__have_cufft__):
+ raise ImportError("Please install pycuda and scikit-cuda to use the CUDA back-end")
+
+ super(CUFFT, self).__init__(
+ shape=shape,
+ dtype=dtype,
+ template=template,
+ shape_out=shape_out,
+ axes=axes,
+ normalize=normalize,
+ )
+ self.cufft_stream = stream
+ self.backend = "cufft"
+
+ self.configure_batched_transform()
+ self.allocate_arrays()
+ self.real_transform = np.isrealobj(self.data_in)
+ self.compute_forward_plan()
+ self.compute_inverse_plan()
+ self.refs = {
+ "data_in": self.data_in,
+ "data_out": self.data_out,
+ }
+ self.configure_normalization()
+
+ def _allocate(self, shape, dtype):
+ return gpuarray.zeros(shape, dtype)
+
+ # TODO support batched transform where batch is other than dimension 0
+ def configure_batched_transform(self):
+ self.cufft_batch_size = 1
+ self.cufft_shape = self.shape
+ if (self.axes is not None) and (len(self.axes) < len(self.shape)):
+ # In the easiest case, the transform is computed along the fastest dimensions:
+ # - 1D transforms of lines of 2D data
+ # - 2D transforms of images of 3D data (stacked along slow dim)
+ # - 1D transforms of 3D data along fastest dim
+ # Otherwise, we have to configure cuda "advanced memory layout",
+ # which is not implemented yet.
+
+ data_ndims = len(self.shape)
+ supported_axes = {
+ 2: [(1,)],
+ 3: [(1, 2), (2, 1), (1,), (2,)],
+ }
+ if self.axes not in supported_axes[data_ndims]:
+ raise NotImplementedError("With the CUDA backend, batched transform is only supported along fastest dimensions")
+ self.cufft_batch_size = self.shape[0]
+ self.cufft_shape = self.shape[1:]
+ if data_ndims == 3 and len(self.axes) == 1:
+ # 1D transform on 3D data: here only supported along fast dim,
+ # so batch_size is Nx*Ny
+ self.cufft_batch_size = np.prod(self.shape[:2])
+ self.cufft_shape = (self.shape[-1],)
+ if len(self.cufft_shape) == 1:
+ self.cufft_shape = self.cufft_shape[0]
+
+ def configure_normalization(self):
+ # TODO
+ if self.normalize == "ortho":
+ raise NotImplementedError(
+ "Normalization mode 'ortho' is not implemented with CUDA backend yet."
+ )
+ self.cufft_scale_inverse = (self.normalize == "rescale")
+
+ def check_array(self, array, shape, dtype, copy=True):
+ if array.shape != shape:
+ raise ValueError("Invalid data shape: expected %s, got %s" %
+ (shape, array.shape))
+ if array.dtype != dtype:
+ raise ValueError("Invalid data type: expected %s, got %s" %
+ (dtype, array.dtype))
+
+ def set_data(self, dst, src, shape, dtype, copy=True, name=None):
+ """
+ dst is a device array owned by the current instance
+ (either self.data_in or self.data_out).
+
+ copy is ignored for device<-> arrays.
+ """
+ self.check_array(src, shape, dtype)
+ if isinstance(src, np.ndarray):
+ if name == "data_out":
+ # Makes little sense to provide output=numpy_array
+ return dst
+ if not(src.flags["C_CONTIGUOUS"]):
+ src = np.ascontiguousarray(src, dtype=dtype)
+ dst[:] = src[:]
+ elif isinstance(src, gpuarray.GPUArray):
+ # No copy, use the data as self.d_input or self.d_output
+ # (this prevents the use of in-place transforms, however).
+ # We have to keep their old references.
+ if name is None:
+ # This should not happen
+ raise ValueError("Please provide either copy=True or name != None")
+ assert id(self.refs[name]) == id(dst) # DEBUG
+ setattr(self, name, src)
+ return src
+ else:
+ raise ValueError(
+ "Invalid array type %s, expected numpy.ndarray or pycuda.gpuarray" %
+ type(src)
+ )
+ return dst
+
+ def recover_array_references(self):
+ self.data_in = self.refs["data_in"]
+ self.data_out = self.refs["data_out"]
+
+ def compute_forward_plan(self):
+ self.plan_forward = Plan(
+ self.cufft_shape,
+ self.dtype,
+ self.dtype_out,
+ batch=self.cufft_batch_size,
+ stream=self.cufft_stream,
+ # cufft extensible plan API is only supported after 0.5.1
+ # (commit 65288d28ca0b93e1234133f8d460dc6becb65121)
+ # but there is still no official 0.5.2
+ #~ auto_allocate=True # cufft extensible plan API
+ )
+
+ def compute_inverse_plan(self):
+ self.plan_inverse = Plan(
+ self.cufft_shape, # not shape_out
+ self.dtype_out,
+ self.dtype,
+ batch=self.cufft_batch_size,
+ stream=self.cufft_stream,
+ # cufft extensible plan API is only supported after 0.5.1
+ # (commit 65288d28ca0b93e1234133f8d460dc6becb65121)
+ # but there is still no official 0.5.2
+ #~ auto_allocate=True
+ )
+
+ def copy_output_if_numpy(self, dst, src):
+ if isinstance(dst, gpuarray.GPUArray):
+ return
+ dst[:] = src[:]
+
+ def fft(self, array, output=None):
+ """
+ Perform a (forward) Fast Fourier Transform.
+
+ :param Union[numpy.ndarray,pycuda.gpuarray] array:
+ Input data. Must be consistent with the current context.
+ :param Union[numpy.ndarray,pycuda.gpuarray] output:
+ Output data. By default, output is a numpy.ndarray.
+ """
+ data_in = self.set_input_data(array, copy=False)
+ data_out = self.set_output_data(output, copy=False)
+
+ cu_fft(
+ data_in,
+ data_out,
+ self.plan_forward,
+ scale=False
+ )
+
+ if output is not None:
+ self.copy_output_if_numpy(output, self.data_out)
+ res = output
+ else:
+ res = self.data_out.get()
+ self.recover_array_references()
+ return res
+
+ def ifft(self, array, output=None):
+ """
+ Perform a (inverse) Fast Fourier Transform.
+
+ :param Union[numpy.ndarray,pycuda.gpuarray] array:
+ Input data. Must be consistent with the current context.
+ :param Union[numpy.ndarray,pycuda.gpuarray] output:
+ Output data. By default, output is a numpy.ndarray.
+ """
+ data_in = self.set_output_data(array, copy=False)
+ data_out = self.set_input_data(output, copy=False)
+
+ cu_ifft(
+ data_in,
+ data_out,
+ self.plan_inverse,
+ scale=self.cufft_scale_inverse,
+ )
+
+ if output is not None:
+ self.copy_output_if_numpy(output, self.data_in)
+ res = output
+ else:
+ res = self.data_in.get()
+ self.recover_array_references()
+ return res