summaryrefslogtreecommitdiff
path: root/src/silx/math/fft/clfft.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/silx/math/fft/clfft.py')
-rw-r--r--src/silx/math/fft/clfft.py286
1 files changed, 286 insertions, 0 deletions
diff --git a/src/silx/math/fft/clfft.py b/src/silx/math/fft/clfft.py
new file mode 100644
index 0000000..dad8ec1
--- /dev/null
+++ b/src/silx/math/fft/clfft.py
@@ -0,0 +1,286 @@
+#!/usr/bin/env python
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+import numpy as np
+
+from .basefft import BaseFFT, check_version
+try:
+ import pyopencl as cl
+ import pyopencl.array as parray
+ import gpyfft
+ from gpyfft.fft import FFT as cl_fft
+ from ...opencl.common import ocl
+ __have_clfft__ = True
+except ImportError:
+ __have_clfft__ = False
+
+
+# Check gpyfft version
+__required_gpyfft_version__ = "0.3.0"
+if __have_clfft__:
+ __have_clfft__ = check_version(gpyfft, __required_gpyfft_version__)
+
+
+class CLFFT(BaseFFT):
+ """Initialize a clfft plan.
+
+ Please see FFT class for parameters help.
+
+ CLFFT-specific parameters
+ --------------------------
+
+ :param pyopencl.Context ctx:
+ If set to other than None, an existing pyopencl context is used.
+ :param bool fast_math:
+ If set to True, computations will be done with "fast math" mode,
+ i.e., more speed but less accuracy.
+ :param bool choose_best_device:
+ Whether to automatically choose the best available OpenCL device.
+ """
+ def __init__(
+ self,
+ shape=None,
+ dtype=None,
+ template=None,
+ shape_out=None,
+ axes=None,
+ normalize="rescale",
+ ctx=None,
+ fast_math=False,
+ choose_best_device=True,
+ ):
+ if not(__have_clfft__) or not(__have_clfft__):
+ raise ImportError("Please install pyopencl and gpyfft >= %s to use the OpenCL back-end" % __required_gpyfft_version__)
+
+ super(CLFFT, self).__init__(
+ shape=shape,
+ dtype=dtype,
+ template=template,
+ shape_out=shape_out,
+ axes=axes,
+ normalize=normalize,
+ )
+ self.ctx = ctx
+ self.choose_best_device = choose_best_device
+ self.fast_math = fast_math
+ self.backend = "clfft"
+
+ self.fix_axes()
+ self.init_context_queue()
+ self.allocate_arrays()
+ self.real_transform = np.isrealobj(self.data_in)
+ self.compute_forward_plan()
+ self.compute_inverse_plan()
+ self.refs = {
+ "data_in": self.data_in,
+ "data_out": self.data_out,
+ }
+ # TODO
+ # Either pyopencl ElementWiseKernel, or built-in clfft callbacks
+ if self.normalize != "rescale":
+ raise NotImplementedError(
+ "Normalization modes other than rescale are not implemented with OpenCL backend yet."
+ )
+
+ def fix_axes(self):
+ """
+ "Fix" axes.
+
+ clfft does not have the same convention as FFTW/cuda/numpy.
+ """
+ self.axes = self.axes[::-1]
+
+ def _allocate(self, shape, dtype):
+ ary = parray.empty(self.queue, shape, dtype=dtype)
+ ary.fill(0)
+ return ary
+
+
+ def check_array(self, array, shape, dtype, copy=True):
+ if array.shape != shape:
+ raise ValueError("Invalid data shape: expected %s, got %s" %
+ (shape, array.shape)
+ )
+ if array.dtype != dtype:
+ raise ValueError("Invalid data type: expected %s, got %s" %
+ (dtype, array.dtype)
+ )
+
+
+ def set_data(self, dst, src, shape, dtype, copy=True, name=None):
+ """
+ dst is a device array owned by the current instance
+ (either self.data_in or self.data_out).
+
+ copy is ignored for device<-> arrays.
+ """
+ self.check_array(src, shape, dtype)
+ if isinstance(src, np.ndarray):
+ if name == "data_out":
+ # Makes little sense to provide output=numpy_array
+ return dst
+ if not(src.flags["C_CONTIGUOUS"]):
+ src = np.ascontiguousarray(src, dtype=dtype)
+ # working on underlying buffer is notably faster
+ #~ dst[:] = src[:]
+ evt = cl.enqueue_copy(self.queue, dst.data, src)
+ evt.wait()
+ elif isinstance(src, parray.Array):
+ # No copy, use the data as self.d_input or self.d_output
+ # (this prevents the use of in-place transforms, however).
+ # We have to keep their old references.
+ if name is None:
+ # This should not happen
+ raise ValueError("Please provide either copy=True or name != None")
+ assert id(self.refs[name]) == id(dst) # DEBUG
+ setattr(self, name, src)
+ return src
+ else:
+ raise ValueError(
+ "Invalid array type %s, expected numpy.ndarray or pyopencl.array" %
+ type(src)
+ )
+ return dst
+
+
+ def recover_array_references(self):
+ self.data_in = self.refs["data_in"]
+ self.data_out = self.refs["data_out"]
+
+
+ def init_context_queue(self):
+ if self.ctx is None:
+ if self.choose_best_device:
+ self.ctx = ocl.create_context()
+ else:
+ self.ctx = cl.create_some_context()
+ self.queue = cl.CommandQueue(self.ctx)
+
+
+ def compute_forward_plan(self):
+ self.plan_forward = cl_fft(
+ self.ctx,
+ self.queue,
+ self.data_in,
+ out_array=self.data_out,
+ axes=self.axes,
+ fast_math=self.fast_math,
+ real=self.real_transform,
+ )
+
+
+ def compute_inverse_plan(self):
+ self.plan_inverse = cl_fft(
+ self.ctx,
+ self.queue,
+ self.data_out,
+ out_array=self.data_in,
+ axes=self.axes,
+ fast_math=self.fast_math,
+ real=self.real_transform,
+ )
+
+
+ def update_forward_plan_arrays(self):
+ self.plan_forward.data = self.data_in
+ self.plan_forward.result = self.data_out
+
+
+ def update_inverse_plan_arrays(self):
+ self.plan_inverse.data = self.data_out
+ self.plan_inverse.result = self.data_in
+
+
+ def copy_output_if_numpy(self, dst, src):
+ if isinstance(dst, parray.Array):
+ return
+ # working on underlying buffer is notably faster
+ #~ dst[:] = src[:]
+ evt = cl.enqueue_copy(self.queue, dst, src.data)
+ evt.wait()
+
+
+ def fft(self, array, output=None, do_async=False):
+ """
+ Perform a (forward) Fast Fourier Transform.
+
+ :param Union[numpy.ndarray,pyopencl.array] array:
+ Input data. Must be consistent with the current context.
+ :param Union[numpy.ndarray,pyopencl.array] output:
+ Output data. By default, output is a numpy.ndarray.
+ :param bool do_async:
+ Whether to perform operation in asynchronous mode.
+ Default is False, meaning that we wait for transform to complete.
+ """
+ self.set_input_data(array, copy=False)
+ self.set_output_data(output, copy=False)
+ self.update_forward_plan_arrays()
+ event, = self.plan_forward.enqueue()
+ if not(do_async):
+ event.wait()
+ if output is not None:
+ self.copy_output_if_numpy(output, self.data_out)
+ res = output
+ else:
+ res = self.data_out.get()
+ self.recover_array_references()
+ return res
+
+
+ def ifft(self, array, output=None, do_async=False):
+ """
+ Perform a (inverse) Fast Fourier Transform.
+
+ :param Union[numpy.ndarray,pyopencl.array] array:
+ Input data. Must be consistent with the current context.
+ :param Union[numpy.ndarray,pyopencl.array] output:
+ Output data. By default, output is a numpy.ndarray.
+ :param bool do_async:
+ Whether to perform operation in asynchronous mode.
+ Default is False, meaning that we wait for transform to complete.
+ """
+ self.set_output_data(array, copy=False)
+ self.set_input_data(output, copy=False)
+ self.update_inverse_plan_arrays()
+ event, = self.plan_inverse.enqueue(forward=False)
+ if not(do_async):
+ event.wait()
+ if output is not None:
+ self.copy_output_if_numpy(output, self.data_in)
+ res = output
+ else:
+ res = self.data_in.get()
+ self.recover_array_references()
+ return res
+
+
+ def __del__(self):
+ # It seems that gpyfft underlying clFFT destructors are not called.
+ # This results in the following warning:
+ # Warning: Program terminating, but clFFT resources not freed.
+ # Please consider explicitly calling clfftTeardown( )
+ del self.plan_forward
+ del self.plan_inverse
+