diff options
Diffstat (limited to 'silx/math/fft/fftw.py')
-rw-r--r-- | silx/math/fft/fftw.py | 210 |
1 files changed, 210 insertions, 0 deletions
diff --git a/silx/math/fft/fftw.py b/silx/math/fft/fftw.py new file mode 100644 index 0000000..f1249f9 --- /dev/null +++ b/silx/math/fft/fftw.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018-2019 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +import numpy as np + +from .basefft import BaseFFT, check_version +try: + import pyfftw + __have_fftw__ = True +except ImportError: + __have_fftw__ = False + + +# Check pyfftw version +__required_pyfftw_version__ = "0.10.0" +if __have_fftw__: + __have_fftw__ = check_version(pyfftw, __required_pyfftw_version__) + + +class FFTW(BaseFFT): + """Initialize a FFTW plan. + + Please see FFT class for parameters help. + + FFTW-specific parameters + ------------------------- + + :param bool check_alignment: + If set to True and "data" is provided, this will enforce the input data + to be "byte aligned", which might imply extra memory usage. + :param int num_threads: + Number of threads for computing FFT. + """ + def __init__( + self, + shape=None, + dtype=None, + template=None, + shape_out=None, + axes=None, + normalize="rescale", + check_alignment=False, + num_threads=1, + ): + if not(__have_fftw__): + raise ImportError("Please install pyfftw >= %s to use the FFTW back-end" % __required_pyfftw_version__) + super(FFTW, self).__init__( + shape=shape, + dtype=dtype, + template=template, + shape_out=shape_out, + axes=axes, + normalize=normalize, + ) + self.check_alignment = check_alignment + self.num_threads = num_threads + self.backend = "fftw" + + self.allocate_arrays() + self.set_fftw_flags() + self.compute_forward_plan() + self.compute_inverse_plan() + + + def set_fftw_flags(self): + self.fftw_flags = ('FFTW_MEASURE', ) # TODO + self.fftw_planning_timelimit = None # TODO + self.fftw_norm_modes = { + "rescale": {"ortho": False, "normalize": True}, + "ortho": {"ortho": True, "normalize": False}, + "none": {"ortho": False, "normalize": False}, + } + if self.normalize not in self.fftw_norm_modes: + raise ValueError("Unknown normalization mode %s. Possible values are %s" % + (self.normalize, self.fftw_norm_modes.keys()) + ) + self.fftw_norm_mode = self.fftw_norm_modes[self.normalize] + + + def _allocate(self, shape, dtype): + return pyfftw.zeros_aligned(shape, dtype=dtype) + + + def check_array(self, array, shape, dtype, copy=True): + """ + Check that a given array is compatible with the FFTW plans, + in terms of alignment and data type. + + If the provided array does not meet any of the checks, a new array + is returned. + """ + if array.shape != shape: + raise ValueError("Invalid data shape: expected %s, got %s" % + (shape, array.shape) + ) + if array.dtype != dtype: + raise ValueError("Invalid data type: expected %s, got %s" % + (dtype, array.dtype) + ) + if self.check_alignment and not(pyfftw.is_byte_aligned(array)): + array2 = pyfftw.zeros_aligned(self.shape, dtype=self.dtype_in) + np.copyto(array2, array) + else: + if copy: + array2 = np.copy(array) + else: + array2 = array + return array2 + + + def set_data(self, dst, src, shape, dtype, copy=True, name=None): + dst = self.check_array(src, shape, dtype, copy=copy) + return dst + + + def compute_forward_plan(self): + self.plan_forward = pyfftw.FFTW( + self.data_in, + self.data_out, + axes=self.axes, + direction='FFTW_FORWARD', + flags=self.fftw_flags, + threads=self.num_threads, + planning_timelimit=self.fftw_planning_timelimit, + # the following seems to be taken into account only when using __call__ + ortho=self.fftw_norm_mode["ortho"], + normalise_idft=self.fftw_norm_mode["normalize"], + ) + + + def compute_inverse_plan(self): + self.plan_inverse = pyfftw.FFTW( + self.data_out, + self.data_in, + axes=self.axes, + direction='FFTW_BACKWARD', + flags=self.fftw_flags, + threads=self.num_threads, + planning_timelimit=self.fftw_planning_timelimit, + # the following seem to be taken into account only when using __call__ + ortho=self.fftw_norm_mode["ortho"], + normalise_idft=self.fftw_norm_mode["normalize"], + ) + + + def fft(self, array, output=None): + """ + Perform a (forward) Fast Fourier Transform. + + :param numpy.ndarray array: + Input data. Must be consistent with the current context. + :param numpy.ndarray output: + Optional output data. + """ + data_in = self.set_input_data(array, copy=True) + data_out = self.set_output_data(output, copy=False) + # execute.__call__ does both update_arrays() and normalization + self.plan_forward( + input_array=data_in, + output_array=data_out, + ortho=self.fftw_norm_mode["ortho"], + ) + assert id(self.plan_forward.output_array) == id(self.data_out) == id(data_out) # DEBUG + return data_out + + + def ifft(self, array, output=None): + """ + Perform a (inverse) Fast Fourier Transform. + + :param numpy.ndarray array: + Input data. Must be consistent with the current context. + :param numpy.ndarray output: + Optional output data. + """ + data_in = self.set_output_data(array, copy=True) + data_out = self.set_input_data(output, copy=False) + # execute.__call__ does both update_arrays() and normalization + self.plan_inverse( + input_array=data_in, + output_array=data_out, + ortho=self.fftw_norm_mode["ortho"], + normalise_idft=self.fftw_norm_mode["normalize"] + ) + assert id(self.plan_inverse.output_array) == id(self.data_in) == id(data_out) # DEBUG + return data_out + + |