diff options
Diffstat (limited to 'src/silx/math/fft/clfft.py')
-rw-r--r-- | src/silx/math/fft/clfft.py | 51 |
1 files changed, 22 insertions, 29 deletions
diff --git a/src/silx/math/fft/clfft.py b/src/silx/math/fft/clfft.py index 2e41e47..488102a 100644 --- a/src/silx/math/fft/clfft.py +++ b/src/silx/math/fft/clfft.py @@ -25,12 +25,14 @@ import numpy as np from .basefft import BaseFFT, check_version + try: import pyopencl as cl import pyopencl.array as parray import gpyfft from gpyfft.fft import FFT as cl_fft from ...opencl.common import ocl + __have_clfft__ = True except ImportError: __have_clfft__ = False @@ -58,6 +60,7 @@ class CLFFT(BaseFFT): :param bool choose_best_device: Whether to automatically choose the best available OpenCL device. """ + def __init__( self, shape=None, @@ -70,8 +73,11 @@ class CLFFT(BaseFFT): fast_math=False, choose_best_device=True, ): - if not(__have_clfft__) or not(__have_clfft__): - raise ImportError("Please install pyopencl and gpyfft >= %s to use the OpenCL back-end" % __required_gpyfft_version__) + if not (__have_clfft__) or not (__have_clfft__): + raise ImportError( + "Please install pyopencl and gpyfft >= %s to use the OpenCL back-end" + % __required_gpyfft_version__ + ) super().__init__( shape=shape, @@ -116,18 +122,16 @@ class CLFFT(BaseFFT): ary.fill(0) return ary - def check_array(self, array, shape, dtype, copy=True): if array.shape != shape: - raise ValueError("Invalid data shape: expected %s, got %s" % - (shape, array.shape) + raise ValueError( + "Invalid data shape: expected %s, got %s" % (shape, array.shape) ) if array.dtype != dtype: - raise ValueError("Invalid data type: expected %s, got %s" % - (dtype, array.dtype) + raise ValueError( + "Invalid data type: expected %s, got %s" % (dtype, array.dtype) ) - def set_data(self, dst, src, shape, dtype, copy=True, name=None): """ dst is a device array owned by the current instance @@ -140,10 +144,10 @@ class CLFFT(BaseFFT): if name == "data_out": # Makes little sense to provide output=numpy_array return dst - if not(src.flags["C_CONTIGUOUS"]): + if not (src.flags["C_CONTIGUOUS"]): src = np.ascontiguousarray(src, dtype=dtype) # working on underlying buffer is notably faster - #~ dst[:] = src[:] + # ~ dst[:] = src[:] evt = cl.enqueue_copy(self.queue, dst.data, src) evt.wait() elif isinstance(src, parray.Array): @@ -153,22 +157,20 @@ class CLFFT(BaseFFT): if name is None: # This should not happen raise ValueError("Please provide either copy=True or name != None") - assert id(self.refs[name]) == id(dst) # DEBUG + assert id(self.refs[name]) == id(dst) # DEBUG setattr(self, name, src) return src else: raise ValueError( - "Invalid array type %s, expected numpy.ndarray or pyopencl.array" % - type(src) + "Invalid array type %s, expected numpy.ndarray or pyopencl.array" + % type(src) ) return dst - def recover_array_references(self): self.data_in = self.refs["data_in"] self.data_out = self.refs["data_out"] - def init_context_queue(self): if self.ctx is None: if self.choose_best_device: @@ -177,7 +179,6 @@ class CLFFT(BaseFFT): self.ctx = cl.create_some_context() self.queue = cl.CommandQueue(self.ctx) - def compute_forward_plan(self): self.plan_forward = cl_fft( self.ctx, @@ -189,7 +190,6 @@ class CLFFT(BaseFFT): real=self.real_transform, ) - def compute_inverse_plan(self): self.plan_inverse = cl_fft( self.ctx, @@ -201,26 +201,22 @@ class CLFFT(BaseFFT): real=self.real_transform, ) - def update_forward_plan_arrays(self): self.plan_forward.data = self.data_in self.plan_forward.result = self.data_out - def update_inverse_plan_arrays(self): self.plan_inverse.data = self.data_out self.plan_inverse.result = self.data_in - def copy_output_if_numpy(self, dst, src): if isinstance(dst, parray.Array): return # working on underlying buffer is notably faster - #~ dst[:] = src[:] + # ~ dst[:] = src[:] evt = cl.enqueue_copy(self.queue, dst, src.data) evt.wait() - def fft(self, array, output=None, do_async=False): """ Perform a (forward) Fast Fourier Transform. @@ -236,8 +232,8 @@ class CLFFT(BaseFFT): self.set_input_data(array, copy=False) self.set_output_data(output, copy=False) self.update_forward_plan_arrays() - event, = self.plan_forward.enqueue() - if not(do_async): + (event,) = self.plan_forward.enqueue() + if not (do_async): event.wait() if output is not None: self.copy_output_if_numpy(output, self.data_out) @@ -247,7 +243,6 @@ class CLFFT(BaseFFT): self.recover_array_references() return res - def ifft(self, array, output=None, do_async=False): """ Perform a (inverse) Fast Fourier Transform. @@ -263,8 +258,8 @@ class CLFFT(BaseFFT): self.set_output_data(array, copy=False) self.set_input_data(output, copy=False) self.update_inverse_plan_arrays() - event, = self.plan_inverse.enqueue(forward=False) - if not(do_async): + (event,) = self.plan_inverse.enqueue(forward=False) + if not (do_async): event.wait() if output is not None: self.copy_output_if_numpy(output, self.data_in) @@ -274,7 +269,6 @@ class CLFFT(BaseFFT): self.recover_array_references() return res - def __del__(self): # It seems that gpyfft underlying clFFT destructors are not called. # This results in the following warning: @@ -282,4 +276,3 @@ class CLFFT(BaseFFT): # Please consider explicitly calling clfftTeardown( ) del self.plan_forward del self.plan_inverse - |