diff options
Diffstat (limited to 'silx/math/fft')
-rw-r--r-- | silx/math/fft/__init__.py | 8 | ||||
-rw-r--r-- | silx/math/fft/basefft.py | 149 | ||||
-rw-r--r-- | silx/math/fft/clfft.py | 284 | ||||
-rw-r--r-- | silx/math/fft/cufft.py | 253 | ||||
-rw-r--r-- | silx/math/fft/fft.py | 96 | ||||
-rw-r--r-- | silx/math/fft/fftw.py | 210 | ||||
-rw-r--r-- | silx/math/fft/npfft.py | 124 | ||||
-rw-r--r-- | silx/math/fft/setup.py | 41 | ||||
-rw-r--r-- | silx/math/fft/test/__init__.py | 28 | ||||
-rw-r--r-- | silx/math/fft/test/test_fft.py | 338 |
10 files changed, 1531 insertions, 0 deletions
diff --git a/silx/math/fft/__init__.py b/silx/math/fft/__init__.py new file mode 100644 index 0000000..ea12cd6 --- /dev/null +++ b/silx/math/fft/__init__.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python +# coding: utf-8 + +__authors__ = ["P. Paleo"] +__license__ = "MIT" +__date__ = "12/12/2018" + +from .fft import FFT diff --git a/silx/math/fft/basefft.py b/silx/math/fft/basefft.py new file mode 100644 index 0000000..d0718e6 --- /dev/null +++ b/silx/math/fft/basefft.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +import numpy as np +from pkg_resources import parse_version + + +def check_version(package, required_version): + """ + Check whether a given package version is superior or equal to required_version. + """ + try: + ver = getattr(package, "__version__") + except AttributeError: + try: + ver = getattr(package, "version") + except Exception: + return False + req_v = parse_version(required_version) + ver_v = parse_version(ver) + return ver_v >= req_v + + +class BaseFFT(object): + """ + Base class for all FFT backends. + """ + def __init__(self, **kwargs): + self.__get_args(**kwargs) + + if self.shape is None and self.dtype is None and self.template is None: + raise ValueError("Please provide either (shape and dtype) or template") + if self.template is not None: + self.shape = self.template.shape + self.dtype = self.template.dtype + self.user_data = self.template + self.data_allocated = False + self.__calc_axes() + self.__set_dtypes() + self.__calc_shape() + + + def __get_args(self, **kwargs): + expected_args = { + "shape": None, + "dtype": None, + "template": None, + "shape_out": None, + "axes": None, + "normalize": "rescale", + } + for arg_name, default_val in expected_args.items(): + if arg_name not in kwargs: + # Base class was not instantiated properly + raise ValueError("Please provide argument %s" % arg_name) + setattr(self, arg_name, default_val) + for arg_name, arg_val in kwargs.items(): + setattr(self, arg_name, arg_val) + + def __set_dtypes(self): + dtypes_mapping = { + np.dtype("float32"): np.complex64, + np.dtype("float64"): np.complex128, + np.dtype("complex64"): np.complex64, + np.dtype("complex128"): np.complex128 + } + dp = { + np.dtype("float32"): np.float64, + np.dtype("complex64"): np.complex128 + } + self.dtype_in = np.dtype(self.dtype) + if self.dtype_in not in dtypes_mapping: + raise ValueError("Invalid input data type: got %s" % + self.dtype_in + ) + self.dtype_out = dtypes_mapping[self.dtype_in] + + + def __calc_shape(self): + # TODO allow for C2C even for real input data (?) + if self.dtype_in in [np.float32, np.float64]: + last_dim = self.shape[-1]//2 + 1 + # FFTW convention + self.shape_out = self.shape[:-1] + (self.shape[-1]//2 + 1,) + else: + self.shape_out = self.shape + + def __calc_axes(self): + default_axes = tuple(range(len(self.shape))) + if self.axes is None: + self.axes = default_axes + self.user_axes = None + else: + self.user_axes = self.axes + # Handle possibly negative axes + self.axes = tuple(np.array(default_axes)[np.array(self.user_axes)]) + + def _allocate(self, shape, dtype): + raise ValueError("This should be implemented by back-end FFT") + + def set_data(self, dst, src, shape, dtype, copy=True): + raise ValueError("This should be implemented by back-end FFT") + + def allocate_arrays(self): + if not(self.data_allocated): + self.data_in = self._allocate(self.shape, self.dtype_in) + self.data_out = self._allocate(self.shape_out, self.dtype_out) + self.data_allocated = True + + def set_input_data(self, data, copy=True): + if data is None: + return self.data_in + else: + return self.set_data(self.data_in, data, self.shape, self.dtype_in, copy=copy, name="data_in") + + def set_output_data(self, data, copy=True): + if data is None: + return self.data_out + else: + return self.set_data(self.data_out, data, self.shape_out, self.dtype_out, copy=copy, name="data_out") + + def fft(self, array, **kwargs): + raise ValueError("This should be implemented by back-end FFT") + + def ifft(self, array, **kwargs): + raise ValueError("This should be implemented by back-end FFT") + diff --git a/silx/math/fft/clfft.py b/silx/math/fft/clfft.py new file mode 100644 index 0000000..af75aa6 --- /dev/null +++ b/silx/math/fft/clfft.py @@ -0,0 +1,284 @@ +#!/usr/bin/env python +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018-2019 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +import numpy as np + +from .basefft import BaseFFT, check_version +try: + import pyopencl as cl + import pyopencl.array as parray + import gpyfft + from gpyfft.fft import FFT as cl_fft + from ...opencl.common import ocl + __have_clfft__ = True +except ImportError: + __have_clfft__ = False + + +# Check gpyfft version +__required_gpyfft_version__ = "0.3.0" +if __have_clfft__: + __have_clfft__ = check_version(gpyfft, __required_gpyfft_version__) + + +class CLFFT(BaseFFT): + """Initialize a clfft plan. + + Please see FFT class for parameters help. + + CLFFT-specific parameters + -------------------------- + + :param pyopencl.Context ctx: + If set to other than None, an existing pyopencl context is used. + :param bool fast_math: + If set to True, computations will be done with "fast math" mode, + i.e., more speed but less accuracy. + :param bool choose_best_device: + Whether to automatically choose the best available OpenCL device. + """ + def __init__( + self, + shape=None, + dtype=None, + template=None, + shape_out=None, + axes=None, + normalize="rescale", + ctx=None, + fast_math=False, + choose_best_device=True, + ): + if not(__have_clfft__) or not(__have_clfft__): + raise ImportError("Please install pyopencl and gpyfft >= %s to use the OpenCL back-end" % __required_gpyfft_version__) + + super(CLFFT, self).__init__( + shape=shape, + dtype=dtype, + template=template, + shape_out=shape_out, + axes=axes, + normalize=normalize, + ) + self.ctx = ctx + self.choose_best_device = choose_best_device + self.fast_math = fast_math + self.backend = "clfft" + + self.fix_axes() + self.init_context_queue() + self.allocate_arrays() + self.real_transform = np.isrealobj(self.data_in) + self.compute_forward_plan() + self.compute_inverse_plan() + self.refs = { + "data_in": self.data_in, + "data_out": self.data_out, + } + # TODO + # Either pyopencl ElementWiseKernel, or built-in clfft callbacks + if self.normalize != "rescale": + raise NotImplementedError( + "Normalization modes other than rescale are not implemented with OpenCL backend yet." + ) + + def fix_axes(self): + """ + "Fix" axes. + + clfft does not have the same convention as FFTW/cuda/numpy. + """ + self.axes = self.axes[::-1] + + def _allocate(self, shape, dtype): + return parray.zeros(self.queue, shape, dtype=dtype) + + + def check_array(self, array, shape, dtype, copy=True): + if array.shape != shape: + raise ValueError("Invalid data shape: expected %s, got %s" % + (shape, array.shape) + ) + if array.dtype != dtype: + raise ValueError("Invalid data type: expected %s, got %s" % + (dtype, array.dtype) + ) + + + def set_data(self, dst, src, shape, dtype, copy=True, name=None): + """ + dst is a device array owned by the current instance + (either self.data_in or self.data_out). + + copy is ignored for device<-> arrays. + """ + self.check_array(src, shape, dtype) + if isinstance(src, np.ndarray): + if name == "data_out": + # Makes little sense to provide output=numpy_array + return dst + if not(src.flags["C_CONTIGUOUS"]): + src = np.ascontiguousarray(src, dtype=dtype) + # working on underlying buffer is notably faster + #~ dst[:] = src[:] + evt = cl.enqueue_copy(self.queue, dst.data, src) + evt.wait() + elif isinstance(src, parray.Array): + # No copy, use the data as self.d_input or self.d_output + # (this prevents the use of in-place transforms, however). + # We have to keep their old references. + if name is None: + # This should not happen + raise ValueError("Please provide either copy=True or name != None") + assert id(self.refs[name]) == id(dst) # DEBUG + setattr(self, name, src) + return src + else: + raise ValueError( + "Invalid array type %s, expected numpy.ndarray or pyopencl.array" % + type(src) + ) + return dst + + + def recover_array_references(self): + self.data_in = self.refs["data_in"] + self.data_out = self.refs["data_out"] + + + def init_context_queue(self): + if self.ctx is None: + if self.choose_best_device: + self.ctx = ocl.create_context() + else: + self.ctx = cl.create_some_context() + self.queue = cl.CommandQueue(self.ctx) + + + def compute_forward_plan(self): + self.plan_forward = cl_fft( + self.ctx, + self.queue, + self.data_in, + out_array=self.data_out, + axes=self.axes, + fast_math=self.fast_math, + real=self.real_transform, + ) + + + def compute_inverse_plan(self): + self.plan_inverse = cl_fft( + self.ctx, + self.queue, + self.data_out, + out_array=self.data_in, + axes=self.axes, + fast_math=self.fast_math, + real=self.real_transform, + ) + + + def update_forward_plan_arrays(self): + self.plan_forward.data = self.data_in + self.plan_forward.result = self.data_out + + + def update_inverse_plan_arrays(self): + self.plan_inverse.data = self.data_out + self.plan_inverse.result = self.data_in + + + def copy_output_if_numpy(self, dst, src): + if isinstance(dst, parray.Array): + return + # working on underlying buffer is notably faster + #~ dst[:] = src[:] + evt = cl.enqueue_copy(self.queue, dst, src.data) + evt.wait() + + + def fft(self, array, output=None, do_async=False): + """ + Perform a (forward) Fast Fourier Transform. + + :param Union[numpy.ndarray,pyopencl.array] array: + Input data. Must be consistent with the current context. + :param Union[numpy.ndarray,pyopencl.array] output: + Output data. By default, output is a numpy.ndarray. + :param bool do_async: + Whether to perform operation in asynchronous mode. + Default is False, meaning that we wait for transform to complete. + """ + data_in = self.set_input_data(array, copy=False) + data_out = self.set_output_data(output, copy=False) + self.update_forward_plan_arrays() + event, = self.plan_forward.enqueue() + if not(do_async): + event.wait() + if output is not None: + self.copy_output_if_numpy(output, self.data_out) + res = output + else: + res = self.data_out.get() + self.recover_array_references() + return res + + + def ifft(self, array, output=None, do_async=False): + """ + Perform a (inverse) Fast Fourier Transform. + + :param Union[numpy.ndarray,pyopencl.array] array: + Input data. Must be consistent with the current context. + :param Union[numpy.ndarray,pyopencl.array] output: + Output data. By default, output is a numpy.ndarray. + :param bool do_async: + Whether to perform operation in asynchronous mode. + Default is False, meaning that we wait for transform to complete. + """ + data_in = self.set_output_data(array, copy=False) + data_out = self.set_input_data(output, copy=False) + self.update_inverse_plan_arrays() + event, = self.plan_inverse.enqueue(forward=False) + if not(do_async): + event.wait() + if output is not None: + self.copy_output_if_numpy(output, self.data_in) + res = output + else: + res = self.data_in.get() + self.recover_array_references() + return res + + + def __del__(self): + # It seems that gpyfft underlying clFFT destructors are not called. + # This results in the following warning: + # Warning: Program terminating, but clFFT resources not freed. + # Please consider explicitly calling clfftTeardown( ) + del self.plan_forward + del self.plan_inverse + diff --git a/silx/math/fft/cufft.py b/silx/math/fft/cufft.py new file mode 100644 index 0000000..848f3e6 --- /dev/null +++ b/silx/math/fft/cufft.py @@ -0,0 +1,253 @@ +#!/usr/bin/env python +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018-2019 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +import numpy as np + +from .basefft import BaseFFT +try: + import pycuda.gpuarray as gpuarray + from skcuda.fft import Plan + from skcuda.fft import fft as cu_fft + from skcuda.fft import ifft as cu_ifft + __have_cufft__ = True +except ImportError: + __have_cufft__ = False + + +class CUFFT(BaseFFT): + """Initialize a cufft plan + + Please see FFT class for parameters help. + + CUFFT-specific parameters + -------------------------- + + :param pycuda.driver.Stream stream: + Stream with which to associate the plan. If no stream is specified, + the default stream is used. + """ + def __init__( + self, + shape=None, + dtype=None, + template=None, + shape_out=None, + axes=None, + normalize="rescale", + stream=None, + ): + if not(__have_cufft__) or not(__have_cufft__): + raise ImportError("Please install pycuda and scikit-cuda to use the CUDA back-end") + + super(CUFFT, self).__init__( + shape=shape, + dtype=dtype, + template=template, + shape_out=shape_out, + axes=axes, + normalize=normalize, + ) + self.cufft_stream = stream + self.backend = "cufft" + + self.configure_batched_transform() + self.allocate_arrays() + self.real_transform = np.isrealobj(self.data_in) + self.compute_forward_plan() + self.compute_inverse_plan() + self.refs = { + "data_in": self.data_in, + "data_out": self.data_out, + } + self.configure_normalization() + + def _allocate(self, shape, dtype): + return gpuarray.zeros(shape, dtype) + + # TODO support batched transform where batch is other than dimension 0 + def configure_batched_transform(self): + self.cufft_batch_size = 1 + self.cufft_shape = self.shape + if (self.axes is not None) and (len(self.axes) < len(self.shape)): + # In the easiest case, the transform is computed along the fastest dimensions: + # - 1D transforms of lines of 2D data + # - 2D transforms of images of 3D data (stacked along slow dim) + # - 1D transforms of 3D data along fastest dim + # Otherwise, we have to configure cuda "advanced memory layout", + # which is not implemented yet. + + data_ndims = len(self.shape) + supported_axes = { + 2: [(1,)], + 3: [(1, 2), (2, 1), (1,), (2,)], + } + if self.axes not in supported_axes[data_ndims]: + raise NotImplementedError("With the CUDA backend, batched transform is only supported along fastest dimensions") + self.cufft_batch_size = self.shape[0] + self.cufft_shape = self.shape[1:] + if data_ndims == 3 and len(self.axes) == 1: + # 1D transform on 3D data: here only supported along fast dim, + # so batch_size is Nx*Ny + self.cufft_batch_size = np.prod(self.shape[:2]) + self.cufft_shape = (self.shape[-1],) + if len(self.cufft_shape) == 1: + self.cufft_shape = self.cufft_shape[0] + + def configure_normalization(self): + # TODO + if self.normalize == "ortho": + raise NotImplementedError( + "Normalization mode 'ortho' is not implemented with CUDA backend yet." + ) + self.cufft_scale_inverse = (self.normalize == "rescale") + + def check_array(self, array, shape, dtype, copy=True): + if array.shape != shape: + raise ValueError("Invalid data shape: expected %s, got %s" % + (shape, array.shape)) + if array.dtype != dtype: + raise ValueError("Invalid data type: expected %s, got %s" % + (dtype, array.dtype)) + + def set_data(self, dst, src, shape, dtype, copy=True, name=None): + """ + dst is a device array owned by the current instance + (either self.data_in or self.data_out). + + copy is ignored for device<-> arrays. + """ + self.check_array(src, shape, dtype) + if isinstance(src, np.ndarray): + if name == "data_out": + # Makes little sense to provide output=numpy_array + return dst + if not(src.flags["C_CONTIGUOUS"]): + src = np.ascontiguousarray(src, dtype=dtype) + dst[:] = src[:] + elif isinstance(src, gpuarray.GPUArray): + # No copy, use the data as self.d_input or self.d_output + # (this prevents the use of in-place transforms, however). + # We have to keep their old references. + if name is None: + # This should not happen + raise ValueError("Please provide either copy=True or name != None") + assert id(self.refs[name]) == id(dst) # DEBUG + setattr(self, name, src) + return src + else: + raise ValueError( + "Invalid array type %s, expected numpy.ndarray or pycuda.gpuarray" % + type(src) + ) + return dst + + def recover_array_references(self): + self.data_in = self.refs["data_in"] + self.data_out = self.refs["data_out"] + + def compute_forward_plan(self): + self.plan_forward = Plan( + self.cufft_shape, + self.dtype, + self.dtype_out, + batch=self.cufft_batch_size, + stream=self.cufft_stream, + # cufft extensible plan API is only supported after 0.5.1 + # (commit 65288d28ca0b93e1234133f8d460dc6becb65121) + # but there is still no official 0.5.2 + #~ auto_allocate=True # cufft extensible plan API + ) + + def compute_inverse_plan(self): + self.plan_inverse = Plan( + self.cufft_shape, # not shape_out + self.dtype_out, + self.dtype, + batch=self.cufft_batch_size, + stream=self.cufft_stream, + # cufft extensible plan API is only supported after 0.5.1 + # (commit 65288d28ca0b93e1234133f8d460dc6becb65121) + # but there is still no official 0.5.2 + #~ auto_allocate=True + ) + + def copy_output_if_numpy(self, dst, src): + if isinstance(dst, gpuarray.GPUArray): + return + dst[:] = src[:] + + def fft(self, array, output=None): + """ + Perform a (forward) Fast Fourier Transform. + + :param Union[numpy.ndarray,pycuda.gpuarray] array: + Input data. Must be consistent with the current context. + :param Union[numpy.ndarray,pycuda.gpuarray] output: + Output data. By default, output is a numpy.ndarray. + """ + data_in = self.set_input_data(array, copy=False) + data_out = self.set_output_data(output, copy=False) + + cu_fft( + data_in, + data_out, + self.plan_forward, + scale=False + ) + + if output is not None: + self.copy_output_if_numpy(output, self.data_out) + res = output + else: + res = self.data_out.get() + self.recover_array_references() + return res + + def ifft(self, array, output=None): + """ + Perform a (inverse) Fast Fourier Transform. + + :param Union[numpy.ndarray,pycuda.gpuarray] array: + Input data. Must be consistent with the current context. + :param Union[numpy.ndarray,pycuda.gpuarray] output: + Output data. By default, output is a numpy.ndarray. + """ + data_in = self.set_output_data(array, copy=False) + data_out = self.set_input_data(output, copy=False) + + cu_ifft( + data_in, + data_out, + self.plan_inverse, + scale=self.cufft_scale_inverse, + ) + + if output is not None: + self.copy_output_if_numpy(output, self.data_in) + res = output + else: + res = self.data_in.get() + self.recover_array_references() + return res diff --git a/silx/math/fft/fft.py b/silx/math/fft/fft.py new file mode 100644 index 0000000..eb0d73b --- /dev/null +++ b/silx/math/fft/fft.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018-2019 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +from .fftw import FFTW +from .clfft import CLFFT +from .npfft import NPFFT +from .cufft import CUFFT + + +def FFT( + shape=None, + dtype=None, + template=None, + shape_out=None, + axes=None, + normalize="rescale", + backend="numpy", + **kwargs +): + """ + Initialize a FFT plan. + + :param List[int] shape: + Shape of the input data. + :param numpy.dtype dtype: + Data type of the input data. + :param numpy.ndarray template: + Optional data, replacement for "shape" and "dtype". + If provided, the arguments "shape" and "dtype" are ignored, + and are instead inferred from it. + :param List[int] shape_out: + Optional shape of output data. + By default, the data has the same shape as the input + data (in case of C2C transform), or a shape with the last dimension halved + (in case of R2C transform). If shape_out is provided, it must be greater + or equal than the shape of input data. In this case, FFT is performed + with zero-padding. + :param List[int] axes: + Axes along which FFT is computed. + * For 2D transform: axes=(1,0) + * For batched 1D transform of 2D image: axes=(0,) + :param str normalize: + Whether to normalize FFT and IFFT. Possible values are: + * "rescale": in this case, Fourier data is divided by "N" + before IFFT, so that (FFT(data)) = data + * "ortho": in this case, FFT and IFFT are adjoint of eachother, + the transform is unitary. Both FFT and IFFT are scaled with 1/sqrt(N). + * "none": no normalizatio is done : IFFT(FFT(data)) = data*N + :param str backend: + FFT Backend to use. Value can be "numpy", "fftw", "opencl", "cuda". + """ + backends = { + "numpy": NPFFT, + "np": NPFFT, + "fftw": FFTW, + "opencl": CLFFT, + "clfft": CLFFT, + "cuda": CUFFT, + "cufft": CUFFT, + } + + backend = backend.lower() + if backend not in backends: + raise ValueError("Unknown backend %s, available are %s" % (backend, backends)) + F = backends[backend]( + shape=shape, + dtype=dtype, + template=template, + shape_out=shape_out, + axes=axes, + normalize=normalize, + **kwargs + ) + return F diff --git a/silx/math/fft/fftw.py b/silx/math/fft/fftw.py new file mode 100644 index 0000000..f1249f9 --- /dev/null +++ b/silx/math/fft/fftw.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018-2019 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +import numpy as np + +from .basefft import BaseFFT, check_version +try: + import pyfftw + __have_fftw__ = True +except ImportError: + __have_fftw__ = False + + +# Check pyfftw version +__required_pyfftw_version__ = "0.10.0" +if __have_fftw__: + __have_fftw__ = check_version(pyfftw, __required_pyfftw_version__) + + +class FFTW(BaseFFT): + """Initialize a FFTW plan. + + Please see FFT class for parameters help. + + FFTW-specific parameters + ------------------------- + + :param bool check_alignment: + If set to True and "data" is provided, this will enforce the input data + to be "byte aligned", which might imply extra memory usage. + :param int num_threads: + Number of threads for computing FFT. + """ + def __init__( + self, + shape=None, + dtype=None, + template=None, + shape_out=None, + axes=None, + normalize="rescale", + check_alignment=False, + num_threads=1, + ): + if not(__have_fftw__): + raise ImportError("Please install pyfftw >= %s to use the FFTW back-end" % __required_pyfftw_version__) + super(FFTW, self).__init__( + shape=shape, + dtype=dtype, + template=template, + shape_out=shape_out, + axes=axes, + normalize=normalize, + ) + self.check_alignment = check_alignment + self.num_threads = num_threads + self.backend = "fftw" + + self.allocate_arrays() + self.set_fftw_flags() + self.compute_forward_plan() + self.compute_inverse_plan() + + + def set_fftw_flags(self): + self.fftw_flags = ('FFTW_MEASURE', ) # TODO + self.fftw_planning_timelimit = None # TODO + self.fftw_norm_modes = { + "rescale": {"ortho": False, "normalize": True}, + "ortho": {"ortho": True, "normalize": False}, + "none": {"ortho": False, "normalize": False}, + } + if self.normalize not in self.fftw_norm_modes: + raise ValueError("Unknown normalization mode %s. Possible values are %s" % + (self.normalize, self.fftw_norm_modes.keys()) + ) + self.fftw_norm_mode = self.fftw_norm_modes[self.normalize] + + + def _allocate(self, shape, dtype): + return pyfftw.zeros_aligned(shape, dtype=dtype) + + + def check_array(self, array, shape, dtype, copy=True): + """ + Check that a given array is compatible with the FFTW plans, + in terms of alignment and data type. + + If the provided array does not meet any of the checks, a new array + is returned. + """ + if array.shape != shape: + raise ValueError("Invalid data shape: expected %s, got %s" % + (shape, array.shape) + ) + if array.dtype != dtype: + raise ValueError("Invalid data type: expected %s, got %s" % + (dtype, array.dtype) + ) + if self.check_alignment and not(pyfftw.is_byte_aligned(array)): + array2 = pyfftw.zeros_aligned(self.shape, dtype=self.dtype_in) + np.copyto(array2, array) + else: + if copy: + array2 = np.copy(array) + else: + array2 = array + return array2 + + + def set_data(self, dst, src, shape, dtype, copy=True, name=None): + dst = self.check_array(src, shape, dtype, copy=copy) + return dst + + + def compute_forward_plan(self): + self.plan_forward = pyfftw.FFTW( + self.data_in, + self.data_out, + axes=self.axes, + direction='FFTW_FORWARD', + flags=self.fftw_flags, + threads=self.num_threads, + planning_timelimit=self.fftw_planning_timelimit, + # the following seems to be taken into account only when using __call__ + ortho=self.fftw_norm_mode["ortho"], + normalise_idft=self.fftw_norm_mode["normalize"], + ) + + + def compute_inverse_plan(self): + self.plan_inverse = pyfftw.FFTW( + self.data_out, + self.data_in, + axes=self.axes, + direction='FFTW_BACKWARD', + flags=self.fftw_flags, + threads=self.num_threads, + planning_timelimit=self.fftw_planning_timelimit, + # the following seem to be taken into account only when using __call__ + ortho=self.fftw_norm_mode["ortho"], + normalise_idft=self.fftw_norm_mode["normalize"], + ) + + + def fft(self, array, output=None): + """ + Perform a (forward) Fast Fourier Transform. + + :param numpy.ndarray array: + Input data. Must be consistent with the current context. + :param numpy.ndarray output: + Optional output data. + """ + data_in = self.set_input_data(array, copy=True) + data_out = self.set_output_data(output, copy=False) + # execute.__call__ does both update_arrays() and normalization + self.plan_forward( + input_array=data_in, + output_array=data_out, + ortho=self.fftw_norm_mode["ortho"], + ) + assert id(self.plan_forward.output_array) == id(self.data_out) == id(data_out) # DEBUG + return data_out + + + def ifft(self, array, output=None): + """ + Perform a (inverse) Fast Fourier Transform. + + :param numpy.ndarray array: + Input data. Must be consistent with the current context. + :param numpy.ndarray output: + Optional output data. + """ + data_in = self.set_output_data(array, copy=True) + data_out = self.set_input_data(output, copy=False) + # execute.__call__ does both update_arrays() and normalization + self.plan_inverse( + input_array=data_in, + output_array=data_out, + ortho=self.fftw_norm_mode["ortho"], + normalise_idft=self.fftw_norm_mode["normalize"] + ) + assert id(self.plan_inverse.output_array) == id(self.data_in) == id(data_out) # DEBUG + return data_out + + diff --git a/silx/math/fft/npfft.py b/silx/math/fft/npfft.py new file mode 100644 index 0000000..20351de --- /dev/null +++ b/silx/math/fft/npfft.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018-2019 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +import numpy as np + +from .basefft import BaseFFT + + +class NPFFT(BaseFFT): + """Initialize a numpy plan. + + Please see FFT class for parameters help. + """ + def __init__( + self, + shape=None, + dtype=None, + template=None, + shape_out=None, + axes=None, + normalize="rescale", + ): + super(NPFFT, self).__init__( + shape=shape, + dtype=dtype, + template=template, + shape_out=shape_out, + axes=axes, + normalize=normalize, + ) + self.backend = "numpy" + self.real_transform = False + if template is not None and np.isrealobj(template): + self.real_transform = True + # For numpy functions. + # TODO Issue warning if user wants ifft(fft(data)) = N*data ? + if normalize != "ortho": + self.normalize = None + self.set_fft_functions() + #~ self.allocate_arrays() # not needed for this backend + self.compute_plans() + + + def set_fft_functions(self): + # (fwd, inv) = _fft_functions[is_real][ndim] + self._fft_functions = { + True: { + 1: (np.fft.rfft, np.fft.irfft), + 2: (np.fft.rfft2, np.fft.irfft2), + 3: (np.fft.rfftn, np.fft.irfftn), + }, + False: { + 1: (np.fft.fft, np.fft.ifft), + 2: (np.fft.fft2, np.fft.ifft2), + 3: (np.fft.fftn, np.fft.ifftn), + } + } + + + def _allocate(self, shape, dtype): + return np.zeros(self.queue, shape, dtype=dtype) + + + def compute_plans(self): + ndim = len(self.shape) + funcs = self._fft_functions[self.real_transform][np.minimum(ndim, 3)] + if np.version.version[:4] in ["1.8.", "1.9."]: + # norm keyword was introduced in 1.10 and we support numpy >= 1.8 + self.numpy_args = {} + else: + self.numpy_args = {"norm": self.normalize} + # Batched transform + if (self.user_axes is not None) and len(self.user_axes) < ndim: + funcs = self._fft_functions[self.real_transform][np.minimum(ndim-1, 3)] + self.numpy_args["axes"] = self.user_axes + # Special case of batched 1D transform on 2D data + if ndim == 2: + assert len(self.user_axes) == 1 + self.numpy_args["axis"] = self.user_axes[0] + self.numpy_args.pop("axes") + self.numpy_funcs = funcs + + + def fft(self, array): + """ + Perform a (forward) Fast Fourier Transform. + + :param numpy.ndarray array: + Input data. Must be consistent with the current context. + """ + return self.numpy_funcs[0](array, **self.numpy_args) + + + def ifft(self, array): + """ + Perform a (inverse) Fast Fourier Transform. + + :param numpy.ndarray array: + Input data. Must be consistent with the current context. + """ + return self.numpy_funcs[1](array, **self.numpy_args) + diff --git a/silx/math/fft/setup.py b/silx/math/fft/setup.py new file mode 100644 index 0000000..76bb864 --- /dev/null +++ b/silx/math/fft/setup.py @@ -0,0 +1,41 @@ +# coding: utf-8 +# /*########################################################################## +# Copyright (C) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ + +__authors__ = ["P. Naudet"] +__license__ = "MIT" +__date__ = "12/12/2018" + +import numpy +from numpy.distutils.misc_util import Configuration + + +def configuration(parent_package='', top_path=None): + config = Configuration('fft', parent_package, top_path) + config.add_subpackage('test') + return config + + +if __name__ == "__main__": + from numpy.distutils.core import setup + setup(configuration=configuration) diff --git a/silx/math/fft/test/__init__.py b/silx/math/fft/test/__init__.py new file mode 100644 index 0000000..8019fba --- /dev/null +++ b/silx/math/fft/test/__init__.py @@ -0,0 +1,28 @@ +# coding: utf-8 +# /*########################################################################## +# Copyright (C) 2016-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ + +from .test_fft import test_all + +def suite(): + return test_all() diff --git a/silx/math/fft/test/test_fft.py b/silx/math/fft/test/test_fft.py new file mode 100644 index 0000000..b0e595b --- /dev/null +++ b/silx/math/fft/test/test_fft.py @@ -0,0 +1,338 @@ +#!/usr/bin/env python +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Test of the MFFT module""" + +import numpy as np +import unittest +import logging +from scipy.misc import ascent +from silx.utils.testutils import parameterize +from silx.math.fft.fft import FFT +from silx.math.fft.clfft import __have_clfft__ +from silx.math.fft.cufft import __have_cufft__ +from silx.math.fft.fftw import __have_fftw__ + +from silx.test.utils import test_options +logger = logging.getLogger(__name__) + +class TransformInfos(object): + def __init__(self): + self.dimensions = [ + "1D", + "batched_1D", + "2D", + "batched_2D", + "3D", + ] + self.modes = { + "R2C": np.float32, + "R2C_double": np.float64, + "C2C": np.complex64, + "C2C_double": np.complex128, + } + self.sizes = { + "1D": [(512,), (511,)], + "2D": [(512, 512), (512, 511), (511, 512), (511, 511)], + "3D": [(128, 128, 128), (128, 128, 127), (128, 127, 128), (127, 128, 128), + (128, 127, 127), (127, 128, 127), (127, 127, 128), (127, 127, 127)] + } + self.axes = { + "1D": None, + "batched_1D": (-1,), + "2D": None, + "batched_2D": (-2, -1), + "3D": None, + } + self.sizes["batched_1D"] = self.sizes["2D"] + self.sizes["batched_2D"] = self.sizes["3D"] + + +class TestData(object): + def __init__(self): + self.data = ascent().astype("float32") + self.data1d = self.data[:, 0] # non-contiguous data + self.data3d = np.tile(self.data[:128, :128], (128, 1, 1)) + self.data_refs = { + 1: self.data1d, + 2: self.data, + 3: self.data3d, + } + + + +class TestFFT(unittest.TestCase): + + @classmethod + def setUpClass(cls): + super(TestFFT, cls).setUpClass() + cls.Ctx = None + if __have_clfft__: + from silx.opencl.common import ocl + if ocl is not None: + cls.Ctx = ocl.create_context() + + @classmethod + def tearDownClass(cls): + super(TestFFT, cls).tearDownClass() + if cls.Ctx is not None: + del cls.Ctx + + + def __init__(self, methodName='runTest', param=None): + unittest.TestCase.__init__(self, methodName) + self.param = param + + + def setUp(self): + self.tol = { + np.dtype("float32"): 1e-3, + np.dtype("float64"): 1e-9, + np.dtype("complex64"): 1e-3, + np.dtype("complex128"): 1e-9, + } + self.backend = self.param["backend"] + self.trdim = self.param["trdim"] + self.mode = self.param["mode"] + self.size = self.param["size"] + self.transform_infos = self.param["transform_infos"] + self.test_data = self.param["test_data"] + self.configure_backends() + self.configure_extra_args() + if self.backend == "opencl" and self.Ctx is None: + self.skipTest("PyopenCL is missing") + + + def tearDown(self): + pass + + + def configure_backends(self): + self.__have_clfft__ = __have_clfft__ + self.__have_cufft__ = __have_cufft__ + self.__have_fftw__ = __have_fftw__ + + if self.backend in ["cuda", "cufft"] and __have_cufft__: + import pycuda.autoinit + # Error is higher when using cuda. fast_math mode ? + self.tol[np.dtype("float32")] *= 2 + + + def configure_extra_args(self): + self.extra_args = {} + if __have_clfft__ and self.backend in ["opencl", "clfft"]: + self.extra_args["ctx"] = self.Ctx + + + def check_current_backend(self): + if self.backend in ["cuda", "cufft"] and not(self.__have_cufft__): + return "cuda back-end requires pycuda and scikit-cuda" + if self.backend in ["opencl", "clfft"] and not(self.__have_clfft__): + return "opencl back-end requires pyopencl and gpyfft" + if self.backend == "fftw" and not(self.__have_fftw__): + return "fftw back-end requires pyfftw" + return None + + + @staticmethod + def calc_mae(arr1, arr2): + """ + Compute the Max Absolute Error between two arrays + """ + return np.max(np.abs(arr1 - arr2)) + + + def test_fft(self): + err = self.check_current_backend() + if err is not None: + self.skipTest(err) + if self.size == "3D" and test_options.TEST_LOW_MEM: + self.skipTest("low mem") + + ndim = len(self.size) + input_data = self.test_data.data_refs[ndim].astype(self.transform_infos.modes[self.mode]) + tol = self.tol[np.dtype(input_data.dtype)] + if self.trdim == "3D": + tol *= 10 # Error is relatively high in high dimensions + + # Python < 3.5 does not want to mix **extra_args with existing kwargs + fft_args = { + "template": input_data, + "axes": self.transform_infos.axes[self.trdim], + "backend": self.backend, + } + fft_args.update(self.extra_args) + F = FFT( + **fft_args + ) + F_np = FFT( + template=input_data, + axes=self.transform_infos.axes[self.trdim], + backend="numpy" + ) + + # Forward FFT + res = F.fft(input_data) + res_np = F_np.fft(input_data) + mae = self.calc_mae(res, res_np) + self.assertTrue( + mae < np.abs(input_data.max()) * tol, + "FFT %s:%s, MAE(%s, numpy) = %f" % (self.mode, self.trdim, self.backend, mae) + ) + + # Inverse FFT + res2 = F.ifft(res) + mae = self.calc_mae(res2, input_data) + self.assertTrue( + mae < tol, + "IFFT %s:%s, MAE(%s, numpy) = %f" % (self.mode, self.trdim, self.backend, mae) + ) + + +class TestNumpyFFT(unittest.TestCase): + """ + Test the Numpy backend individually. + """ + + def __init__(self, methodName='runTest', param=None): + unittest.TestCase.__init__(self, methodName) + self.param = param + + def setUp(self): + transforms = { + "1D": { + True: (np.fft.rfft, np.fft.irfft), + False: (np.fft.fft, np.fft.ifft), + }, + "2D": { + True: (np.fft.rfft2, np.fft.irfft2), + False: (np.fft.fft2, np.fft.ifft2), + }, + "3D": { + True: (np.fft.rfftn, np.fft.irfftn), + False: (np.fft.fftn, np.fft.ifftn), + }, + } + transforms["batched_1D"] = transforms["1D"] + transforms["batched_2D"] = transforms["2D"] + self.transforms = transforms + + + def test_numpy_fft(self): + """ + Test the numpy backend against native fft. + Results should be exactly the same. + """ + trinfos = self.param["transform_infos"] + trdim = self.param["trdim"] + ndim = len(self.param["size"]) + input_data = self.param["test_data"].data_refs[ndim].astype(trinfos.modes[self.param["mode"]]) + np_fft, np_ifft = self.transforms[trdim][np.isrealobj(input_data)] + + F = FFT( + template=input_data, + axes=trinfos.axes[trdim], + backend="numpy" + ) + # Test FFT + res = F.fft(input_data) + ref = np_fft(input_data) + self.assertTrue(np.allclose(res, ref)) + + # Test IFFT + res2 = F.ifft(res) + ref2 = np_ifft(ref) + self.assertTrue(np.allclose(res2, ref2)) + + +def test_numpy_backend(dimensions=None): + testSuite = unittest.TestSuite() + transform_infos = TransformInfos() + test_data = TestData() + dimensions = dimensions or transform_infos.dimensions + + for trdim in dimensions: + logger.debug(" testing %s" % trdim) + for mode in transform_infos.modes: + logger.debug(" testing %s:%s" % (trdim, mode)) + for size in transform_infos.sizes[trdim]: + logger.debug(" size: %s" % str(size)) + testcase = parameterize( + TestNumpyFFT, + param={ + "transform_infos": transform_infos, + "test_data": test_data, + "trdim": trdim, + "mode": mode, + "size": size, + } + ) + testSuite.addTest(testcase) + return testSuite + + +def test_fft(backend, dimensions=None): + testSuite = unittest.TestSuite() + transform_infos = TransformInfos() + test_data = TestData() + dimensions = dimensions or transform_infos.dimensions + + logger.info("Testing backend: %s" % backend) + for trdim in dimensions: + logger.debug(" testing %s" % trdim) + for mode in transform_infos.modes: + logger.debug(" testing %s:%s" % (trdim, mode)) + for size in transform_infos.sizes[trdim]: + logger.debug(" size: %s" % str(size)) + testcase = parameterize( + TestFFT, + param={ + "transform_infos": transform_infos, + "test_data": test_data, + "backend": backend, + "trdim": trdim, + "mode": mode, + "size": size, + } + ) + testSuite.addTest(testcase) + return testSuite + + +def test_all(): + suite = unittest.TestSuite() + + suite.addTest(test_numpy_backend()) + + suite.addTest(test_fft("fftw")) + suite.addTest(test_fft("opencl")) + suite.addTest(test_fft("cuda")) + return suite + + +if __name__ == '__main__': + unittest.main(defaultTest="test_all") + + |