diff options
Diffstat (limited to 'silx/math/fft/fftw.py')
-rw-r--r-- | silx/math/fft/fftw.py | 72 |
1 files changed, 38 insertions, 34 deletions
diff --git a/silx/math/fft/fftw.py b/silx/math/fft/fftw.py index f1249f9..ff6966c 100644 --- a/silx/math/fft/fftw.py +++ b/silx/math/fft/fftw.py @@ -82,7 +82,10 @@ class FFTW(BaseFFT): self.set_fftw_flags() self.compute_forward_plan() self.compute_inverse_plan() - + self.refs = { + "data_in": self.data_in, + "data_out": self.data_out, + } def set_fftw_flags(self): self.fftw_flags = ('FFTW_MEASURE', ) # TODO @@ -98,19 +101,10 @@ class FFTW(BaseFFT): ) self.fftw_norm_mode = self.fftw_norm_modes[self.normalize] - def _allocate(self, shape, dtype): return pyfftw.zeros_aligned(shape, dtype=dtype) - def check_array(self, array, shape, dtype, copy=True): - """ - Check that a given array is compatible with the FFTW plans, - in terms of alignment and data type. - - If the provided array does not meet any of the checks, a new array - is returned. - """ if array.shape != shape: raise ValueError("Invalid data shape: expected %s, got %s" % (shape, array.shape) @@ -119,21 +113,38 @@ class FFTW(BaseFFT): raise ValueError("Invalid data type: expected %s, got %s" % (dtype, array.dtype) ) + + def set_data(self, self_array, array, shape, dtype, copy=True, name=None): + """ + :param self_array: array owned by the current instance + (either self.data_in or self.data_out). + :type: numpy.ndarray + :param self_array: data to set + :type: numpy.ndarray + :type tuple shape: shape of the array + :param dtype: type of the array + :type: numpy.dtype + :param bool copy: should we copy the array + :param str name: name of the array + + Copies are avoided when possible. + """ + self.check_array(array, shape, dtype) + if id(self.refs[name]) == id(array): + # nothing to do: fft is performed on self.data_in or self.data_out + arr_to_use = self.refs[name] if self.check_alignment and not(pyfftw.is_byte_aligned(array)): - array2 = pyfftw.zeros_aligned(self.shape, dtype=self.dtype_in) - np.copyto(array2, array) + # If the array is not properly aligned, + # create a temp. array copy it to self.data_in or self.data_out + self_array[:] = array[:] + arr_to_use = self_array else: + # If the array is properly aligned, use it directly if copy: - array2 = np.copy(array) + arr_to_use = np.copy(array) else: - array2 = array - return array2 - - - def set_data(self, dst, src, shape, dtype, copy=True, name=None): - dst = self.check_array(src, shape, dtype, copy=copy) - return dst - + arr_to_use = array + return arr_to_use def compute_forward_plan(self): self.plan_forward = pyfftw.FFTW( @@ -149,7 +160,6 @@ class FFTW(BaseFFT): normalise_idft=self.fftw_norm_mode["normalize"], ) - def compute_inverse_plan(self): self.plan_inverse = pyfftw.FFTW( self.data_out, @@ -164,7 +174,6 @@ class FFTW(BaseFFT): normalise_idft=self.fftw_norm_mode["normalize"], ) - def fft(self, array, output=None): """ Perform a (forward) Fast Fourier Transform. @@ -174,18 +183,16 @@ class FFTW(BaseFFT): :param numpy.ndarray output: Optional output data. """ - data_in = self.set_input_data(array, copy=True) + data_in = self.set_input_data(array, copy=False) data_out = self.set_output_data(output, copy=False) + self.plan_forward.update_arrays(data_in, data_out) # execute.__call__ does both update_arrays() and normalization self.plan_forward( - input_array=data_in, - output_array=data_out, ortho=self.fftw_norm_mode["ortho"], ) - assert id(self.plan_forward.output_array) == id(self.data_out) == id(data_out) # DEBUG + self.plan_forward.update_arrays(self.refs["data_in"], self.refs["data_out"]) return data_out - def ifft(self, array, output=None): """ Perform a (inverse) Fast Fourier Transform. @@ -195,16 +202,13 @@ class FFTW(BaseFFT): :param numpy.ndarray output: Optional output data. """ - data_in = self.set_output_data(array, copy=True) + data_in = self.set_output_data(array, copy=False) data_out = self.set_input_data(output, copy=False) + self.plan_inverse.update_arrays(data_in, data_out) # execute.__call__ does both update_arrays() and normalization self.plan_inverse( - input_array=data_in, - output_array=data_out, ortho=self.fftw_norm_mode["ortho"], normalise_idft=self.fftw_norm_mode["normalize"] ) - assert id(self.plan_inverse.output_array) == id(self.data_in) == id(data_out) # DEBUG + self.plan_inverse.update_arrays(self.refs["data_out"], self.refs["data_in"]) return data_out - - |