summaryrefslogtreecommitdiff
path: root/silx/math/fft/fftw.py
diff options
context:
space:
mode:
Diffstat (limited to 'silx/math/fft/fftw.py')
-rw-r--r--silx/math/fft/fftw.py72
1 files changed, 38 insertions, 34 deletions
diff --git a/silx/math/fft/fftw.py b/silx/math/fft/fftw.py
index f1249f9..ff6966c 100644
--- a/silx/math/fft/fftw.py
+++ b/silx/math/fft/fftw.py
@@ -82,7 +82,10 @@ class FFTW(BaseFFT):
self.set_fftw_flags()
self.compute_forward_plan()
self.compute_inverse_plan()
-
+ self.refs = {
+ "data_in": self.data_in,
+ "data_out": self.data_out,
+ }
def set_fftw_flags(self):
self.fftw_flags = ('FFTW_MEASURE', ) # TODO
@@ -98,19 +101,10 @@ class FFTW(BaseFFT):
)
self.fftw_norm_mode = self.fftw_norm_modes[self.normalize]
-
def _allocate(self, shape, dtype):
return pyfftw.zeros_aligned(shape, dtype=dtype)
-
def check_array(self, array, shape, dtype, copy=True):
- """
- Check that a given array is compatible with the FFTW plans,
- in terms of alignment and data type.
-
- If the provided array does not meet any of the checks, a new array
- is returned.
- """
if array.shape != shape:
raise ValueError("Invalid data shape: expected %s, got %s" %
(shape, array.shape)
@@ -119,21 +113,38 @@ class FFTW(BaseFFT):
raise ValueError("Invalid data type: expected %s, got %s" %
(dtype, array.dtype)
)
+
+ def set_data(self, self_array, array, shape, dtype, copy=True, name=None):
+ """
+ :param self_array: array owned by the current instance
+ (either self.data_in or self.data_out).
+ :type: numpy.ndarray
+ :param self_array: data to set
+ :type: numpy.ndarray
+ :type tuple shape: shape of the array
+ :param dtype: type of the array
+ :type: numpy.dtype
+ :param bool copy: should we copy the array
+ :param str name: name of the array
+
+ Copies are avoided when possible.
+ """
+ self.check_array(array, shape, dtype)
+ if id(self.refs[name]) == id(array):
+ # nothing to do: fft is performed on self.data_in or self.data_out
+ arr_to_use = self.refs[name]
if self.check_alignment and not(pyfftw.is_byte_aligned(array)):
- array2 = pyfftw.zeros_aligned(self.shape, dtype=self.dtype_in)
- np.copyto(array2, array)
+ # If the array is not properly aligned,
+ # create a temp. array copy it to self.data_in or self.data_out
+ self_array[:] = array[:]
+ arr_to_use = self_array
else:
+ # If the array is properly aligned, use it directly
if copy:
- array2 = np.copy(array)
+ arr_to_use = np.copy(array)
else:
- array2 = array
- return array2
-
-
- def set_data(self, dst, src, shape, dtype, copy=True, name=None):
- dst = self.check_array(src, shape, dtype, copy=copy)
- return dst
-
+ arr_to_use = array
+ return arr_to_use
def compute_forward_plan(self):
self.plan_forward = pyfftw.FFTW(
@@ -149,7 +160,6 @@ class FFTW(BaseFFT):
normalise_idft=self.fftw_norm_mode["normalize"],
)
-
def compute_inverse_plan(self):
self.plan_inverse = pyfftw.FFTW(
self.data_out,
@@ -164,7 +174,6 @@ class FFTW(BaseFFT):
normalise_idft=self.fftw_norm_mode["normalize"],
)
-
def fft(self, array, output=None):
"""
Perform a (forward) Fast Fourier Transform.
@@ -174,18 +183,16 @@ class FFTW(BaseFFT):
:param numpy.ndarray output:
Optional output data.
"""
- data_in = self.set_input_data(array, copy=True)
+ data_in = self.set_input_data(array, copy=False)
data_out = self.set_output_data(output, copy=False)
+ self.plan_forward.update_arrays(data_in, data_out)
# execute.__call__ does both update_arrays() and normalization
self.plan_forward(
- input_array=data_in,
- output_array=data_out,
ortho=self.fftw_norm_mode["ortho"],
)
- assert id(self.plan_forward.output_array) == id(self.data_out) == id(data_out) # DEBUG
+ self.plan_forward.update_arrays(self.refs["data_in"], self.refs["data_out"])
return data_out
-
def ifft(self, array, output=None):
"""
Perform a (inverse) Fast Fourier Transform.
@@ -195,16 +202,13 @@ class FFTW(BaseFFT):
:param numpy.ndarray output:
Optional output data.
"""
- data_in = self.set_output_data(array, copy=True)
+ data_in = self.set_output_data(array, copy=False)
data_out = self.set_input_data(output, copy=False)
+ self.plan_inverse.update_arrays(data_in, data_out)
# execute.__call__ does both update_arrays() and normalization
self.plan_inverse(
- input_array=data_in,
- output_array=data_out,
ortho=self.fftw_norm_mode["ortho"],
normalise_idft=self.fftw_norm_mode["normalize"]
)
- assert id(self.plan_inverse.output_array) == id(self.data_in) == id(data_out) # DEBUG
+ self.plan_inverse.update_arrays(self.refs["data_out"], self.refs["data_in"])
return data_out
-
-