summaryrefslogtreecommitdiff
path: root/src/silx/math/fft/clfft.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/silx/math/fft/clfft.py')
-rw-r--r--src/silx/math/fft/clfft.py51
1 files changed, 22 insertions, 29 deletions
diff --git a/src/silx/math/fft/clfft.py b/src/silx/math/fft/clfft.py
index 2e41e47..488102a 100644
--- a/src/silx/math/fft/clfft.py
+++ b/src/silx/math/fft/clfft.py
@@ -25,12 +25,14 @@
import numpy as np
from .basefft import BaseFFT, check_version
+
try:
import pyopencl as cl
import pyopencl.array as parray
import gpyfft
from gpyfft.fft import FFT as cl_fft
from ...opencl.common import ocl
+
__have_clfft__ = True
except ImportError:
__have_clfft__ = False
@@ -58,6 +60,7 @@ class CLFFT(BaseFFT):
:param bool choose_best_device:
Whether to automatically choose the best available OpenCL device.
"""
+
def __init__(
self,
shape=None,
@@ -70,8 +73,11 @@ class CLFFT(BaseFFT):
fast_math=False,
choose_best_device=True,
):
- if not(__have_clfft__) or not(__have_clfft__):
- raise ImportError("Please install pyopencl and gpyfft >= %s to use the OpenCL back-end" % __required_gpyfft_version__)
+ if not (__have_clfft__) or not (__have_clfft__):
+ raise ImportError(
+ "Please install pyopencl and gpyfft >= %s to use the OpenCL back-end"
+ % __required_gpyfft_version__
+ )
super().__init__(
shape=shape,
@@ -116,18 +122,16 @@ class CLFFT(BaseFFT):
ary.fill(0)
return ary
-
def check_array(self, array, shape, dtype, copy=True):
if array.shape != shape:
- raise ValueError("Invalid data shape: expected %s, got %s" %
- (shape, array.shape)
+ raise ValueError(
+ "Invalid data shape: expected %s, got %s" % (shape, array.shape)
)
if array.dtype != dtype:
- raise ValueError("Invalid data type: expected %s, got %s" %
- (dtype, array.dtype)
+ raise ValueError(
+ "Invalid data type: expected %s, got %s" % (dtype, array.dtype)
)
-
def set_data(self, dst, src, shape, dtype, copy=True, name=None):
"""
dst is a device array owned by the current instance
@@ -140,10 +144,10 @@ class CLFFT(BaseFFT):
if name == "data_out":
# Makes little sense to provide output=numpy_array
return dst
- if not(src.flags["C_CONTIGUOUS"]):
+ if not (src.flags["C_CONTIGUOUS"]):
src = np.ascontiguousarray(src, dtype=dtype)
# working on underlying buffer is notably faster
- #~ dst[:] = src[:]
+ # ~ dst[:] = src[:]
evt = cl.enqueue_copy(self.queue, dst.data, src)
evt.wait()
elif isinstance(src, parray.Array):
@@ -153,22 +157,20 @@ class CLFFT(BaseFFT):
if name is None:
# This should not happen
raise ValueError("Please provide either copy=True or name != None")
- assert id(self.refs[name]) == id(dst) # DEBUG
+ assert id(self.refs[name]) == id(dst) # DEBUG
setattr(self, name, src)
return src
else:
raise ValueError(
- "Invalid array type %s, expected numpy.ndarray or pyopencl.array" %
- type(src)
+ "Invalid array type %s, expected numpy.ndarray or pyopencl.array"
+ % type(src)
)
return dst
-
def recover_array_references(self):
self.data_in = self.refs["data_in"]
self.data_out = self.refs["data_out"]
-
def init_context_queue(self):
if self.ctx is None:
if self.choose_best_device:
@@ -177,7 +179,6 @@ class CLFFT(BaseFFT):
self.ctx = cl.create_some_context()
self.queue = cl.CommandQueue(self.ctx)
-
def compute_forward_plan(self):
self.plan_forward = cl_fft(
self.ctx,
@@ -189,7 +190,6 @@ class CLFFT(BaseFFT):
real=self.real_transform,
)
-
def compute_inverse_plan(self):
self.plan_inverse = cl_fft(
self.ctx,
@@ -201,26 +201,22 @@ class CLFFT(BaseFFT):
real=self.real_transform,
)
-
def update_forward_plan_arrays(self):
self.plan_forward.data = self.data_in
self.plan_forward.result = self.data_out
-
def update_inverse_plan_arrays(self):
self.plan_inverse.data = self.data_out
self.plan_inverse.result = self.data_in
-
def copy_output_if_numpy(self, dst, src):
if isinstance(dst, parray.Array):
return
# working on underlying buffer is notably faster
- #~ dst[:] = src[:]
+ # ~ dst[:] = src[:]
evt = cl.enqueue_copy(self.queue, dst, src.data)
evt.wait()
-
def fft(self, array, output=None, do_async=False):
"""
Perform a (forward) Fast Fourier Transform.
@@ -236,8 +232,8 @@ class CLFFT(BaseFFT):
self.set_input_data(array, copy=False)
self.set_output_data(output, copy=False)
self.update_forward_plan_arrays()
- event, = self.plan_forward.enqueue()
- if not(do_async):
+ (event,) = self.plan_forward.enqueue()
+ if not (do_async):
event.wait()
if output is not None:
self.copy_output_if_numpy(output, self.data_out)
@@ -247,7 +243,6 @@ class CLFFT(BaseFFT):
self.recover_array_references()
return res
-
def ifft(self, array, output=None, do_async=False):
"""
Perform a (inverse) Fast Fourier Transform.
@@ -263,8 +258,8 @@ class CLFFT(BaseFFT):
self.set_output_data(array, copy=False)
self.set_input_data(output, copy=False)
self.update_inverse_plan_arrays()
- event, = self.plan_inverse.enqueue(forward=False)
- if not(do_async):
+ (event,) = self.plan_inverse.enqueue(forward=False)
+ if not (do_async):
event.wait()
if output is not None:
self.copy_output_if_numpy(output, self.data_in)
@@ -274,7 +269,6 @@ class CLFFT(BaseFFT):
self.recover_array_references()
return res
-
def __del__(self):
# It seems that gpyfft underlying clFFT destructors are not called.
# This results in the following warning:
@@ -282,4 +276,3 @@ class CLFFT(BaseFFT):
# Please consider explicitly calling clfftTeardown( )
del self.plan_forward
del self.plan_inverse
-