Diffstat (limited to 'src/silx/math/fft')
 src/silx/math/fft/__init__.py      |   7 +
 src/silx/math/fft/basefft.py       | 150 +
 src/silx/math/fft/clfft.py         | 278 +
 src/silx/math/fft/cufft.py         | 256 +
 src/silx/math/fft/fft.py           |  99 +
 src/silx/math/fft/fftw.py          | 376 +
 src/silx/math/fft/npfft.py         | 144 +
 src/silx/math/fft/test/__init__.py |  22 +
 src/silx/math/fft/test/test_fft.py | 423 +
 9 files changed, 1755 insertions(+), 0 deletions(-)
diff --git a/src/silx/math/fft/__init__.py b/src/silx/math/fft/__init__.py
new file mode 100644
index 0000000..6966a60
--- /dev/null
+++ b/src/silx/math/fft/__init__.py
@@ -0,0 +1,7 @@
+#!/usr/bin/env python
+
+__authors__ = ["P. Paleo"]
+__license__ = "MIT"
+__date__ = "12/12/2018"
+
+from .fft import FFT
diff --git a/src/silx/math/fft/basefft.py b/src/silx/math/fft/basefft.py
new file mode 100644
index 0000000..6e9fac8
--- /dev/null
+++ b/src/silx/math/fft/basefft.py
@@ -0,0 +1,150 @@
+#!/usr/bin/env python
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+import numpy as np
+from packaging.version import Version
+
+
+def check_version(package, required_version):
+ """
+    Check whether a given package version is greater than or equal to required_version.
+ """
+ try:
+ ver = getattr(package, "__version__")
+ except AttributeError:
+ try:
+ ver = getattr(package, "version")
+ except Exception:
+ return False
+ req_v = Version(required_version)
+ ver_v = Version(ver)
+ return ver_v >= req_v
+
+
+class BaseFFT(object):
+ """
+ Base class for all FFT backends.
+ """
+
+ def __init__(self, **kwargs):
+ self.__get_args(**kwargs)
+
+ if self.shape is None and self.dtype is None and self.template is None:
+ raise ValueError("Please provide either (shape and dtype) or template")
+ if self.template is not None:
+ self.shape = self.template.shape
+ self.dtype = self.template.dtype
+ self.user_data = self.template
+ self.data_allocated = False
+ self.__calc_axes()
+ self.__set_dtypes()
+ self.__calc_shape()
+
+ def __get_args(self, **kwargs):
+ expected_args = {
+ "shape": None,
+ "dtype": None,
+ "template": None,
+ "shape_out": None,
+ "axes": None,
+ "normalize": "rescale",
+ }
+ for arg_name, default_val in expected_args.items():
+ if arg_name not in kwargs:
+ # Base class was not instantiated properly
+ raise ValueError("Please provide argument %s" % arg_name)
+ setattr(self, arg_name, default_val)
+ for arg_name, arg_val in kwargs.items():
+ setattr(self, arg_name, arg_val)
+
+ def __set_dtypes(self):
+ dtypes_mapping = {
+ np.dtype("float32"): np.complex64,
+ np.dtype("float64"): np.complex128,
+ np.dtype("complex64"): np.complex64,
+ np.dtype("complex128"): np.complex128,
+ }
+ dp = {np.dtype("float32"): np.float64, np.dtype("complex64"): np.complex128}
+ self.dtype_in = np.dtype(self.dtype)
+ if self.dtype_in not in dtypes_mapping:
+ raise ValueError("Invalid input data type: got %s" % self.dtype_in)
+ self.dtype_out = dtypes_mapping[self.dtype_in]
+
+ def __calc_shape(self):
+ # TODO allow for C2C even for real input data (?)
+ if self.dtype_in in [np.float32, np.float64]:
+            # FFTW convention: the last axis is halved for R2C transforms
+            last_dim = self.shape[-1] // 2 + 1
+            self.shape_out = self.shape[:-1] + (last_dim,)
+ else:
+ self.shape_out = self.shape
+
+ def __calc_axes(self):
+ default_axes = tuple(range(len(self.shape)))
+ if self.axes is None:
+ self.axes = default_axes
+ self.user_axes = None
+ else:
+ self.user_axes = self.axes
+ # Handle possibly negative axes
+ self.axes = tuple(np.array(default_axes)[np.array(self.user_axes)])
+
+ def _allocate(self, shape, dtype):
+ raise ValueError("This should be implemented by back-end FFT")
+
+ def set_data(self, dst, src, shape, dtype, copy=True):
+ raise ValueError("This should be implemented by back-end FFT")
+
+ def allocate_arrays(self):
+ if not (self.data_allocated):
+ self.data_in = self._allocate(self.shape, self.dtype_in)
+ self.data_out = self._allocate(self.shape_out, self.dtype_out)
+ self.data_allocated = True
+
+ def set_input_data(self, data, copy=True):
+ if data is None:
+ return self.data_in
+ else:
+ return self.set_data(
+ self.data_in, data, self.shape, self.dtype_in, copy=copy, name="data_in"
+ )
+
+ def set_output_data(self, data, copy=True):
+ if data is None:
+ return self.data_out
+ else:
+ return self.set_data(
+ self.data_out,
+ data,
+ self.shape_out,
+ self.dtype_out,
+ copy=copy,
+ name="data_out",
+ )
+
+ def fft(self, array, **kwargs):
+ raise ValueError("This should be implemented by back-end FFT")
+
+ def ifft(self, array, **kwargs):
+ raise ValueError("This should be implemented by back-end FFT")
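
The R2C output-shape convention applied in __calc_shape above (FFTW convention: the last transformed axis is halved) matches what numpy produces for real input. A minimal sketch, with an arbitrary example shape:

import numpy as np

shape = (16, 12)                          # arbitrary example shape
data = np.zeros(shape, dtype=np.float32)

# R2C transform: the last axis becomes n//2 + 1 (FFTW convention),
# which is exactly what BaseFFT stores in shape_out.
shape_out = shape[:-1] + (shape[-1] // 2 + 1,)
assert np.fft.rfftn(data).shape == shape_out   # (16, 7)
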
diff --git a/src/silx/math/fft/clfft.py b/src/silx/math/fft/clfft.py
new file mode 100644
index 0000000..488102a
--- /dev/null
+++ b/src/silx/math/fft/clfft.py
@@ -0,0 +1,278 @@
+#!/usr/bin/env python
+# /*##########################################################################
+#
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+import numpy as np
+
+from .basefft import BaseFFT, check_version
+
+try:
+ import pyopencl as cl
+ import pyopencl.array as parray
+ import gpyfft
+ from gpyfft.fft import FFT as cl_fft
+ from ...opencl.common import ocl
+
+ __have_clfft__ = True
+except ImportError:
+ __have_clfft__ = False
+
+
+# Check gpyfft version
+__required_gpyfft_version__ = "0.3.0"
+if __have_clfft__:
+ __have_clfft__ = check_version(gpyfft, __required_gpyfft_version__)
+
+
+class CLFFT(BaseFFT):
+ """Initialize a clfft plan.
+
+ Please see FFT class for parameters help.
+
+ CLFFT-specific parameters
+ --------------------------
+
+ :param pyopencl.Context ctx:
+ If set to other than None, an existing pyopencl context is used.
+ :param bool fast_math:
+ If set to True, computations will be done with "fast math" mode,
+ i.e., more speed but less accuracy.
+ :param bool choose_best_device:
+ Whether to automatically choose the best available OpenCL device.
+ """
+
+ def __init__(
+ self,
+ shape=None,
+ dtype=None,
+ template=None,
+ shape_out=None,
+ axes=None,
+ normalize="rescale",
+ ctx=None,
+ fast_math=False,
+ choose_best_device=True,
+ ):
+        if not __have_clfft__:
+ raise ImportError(
+ "Please install pyopencl and gpyfft >= %s to use the OpenCL back-end"
+ % __required_gpyfft_version__
+ )
+
+ super().__init__(
+ shape=shape,
+ dtype=dtype,
+ template=template,
+ shape_out=shape_out,
+ axes=axes,
+ normalize=normalize,
+ )
+ self.ctx = ctx
+ self.choose_best_device = choose_best_device
+ self.fast_math = fast_math
+ self.backend = "clfft"
+
+ self.fix_axes()
+ self.init_context_queue()
+ self.allocate_arrays()
+ self.real_transform = np.isrealobj(self.data_in)
+ self.compute_forward_plan()
+ self.compute_inverse_plan()
+ self.refs = {
+ "data_in": self.data_in,
+ "data_out": self.data_out,
+ }
+ # TODO
+ # Either pyopencl ElementWiseKernel, or built-in clfft callbacks
+ if self.normalize != "rescale":
+ raise NotImplementedError(
+ "Normalization modes other than rescale are not implemented with OpenCL backend yet."
+ )
+
+ def fix_axes(self):
+ """
+ "Fix" axes.
+
+ clfft does not have the same convention as FFTW/cuda/numpy.
+ """
+ self.axes = self.axes[::-1]
+
+ def _allocate(self, shape, dtype):
+ ary = parray.empty(self.queue, shape, dtype=dtype)
+ ary.fill(0)
+ return ary
+
+ def check_array(self, array, shape, dtype, copy=True):
+ if array.shape != shape:
+ raise ValueError(
+ "Invalid data shape: expected %s, got %s" % (shape, array.shape)
+ )
+ if array.dtype != dtype:
+ raise ValueError(
+ "Invalid data type: expected %s, got %s" % (dtype, array.dtype)
+ )
+
+ def set_data(self, dst, src, shape, dtype, copy=True, name=None):
+ """
+ dst is a device array owned by the current instance
+ (either self.data_in or self.data_out).
+
+        copy is ignored for device <-> device arrays.
+ """
+ self.check_array(src, shape, dtype)
+ if isinstance(src, np.ndarray):
+ if name == "data_out":
+ # Makes little sense to provide output=numpy_array
+ return dst
+ if not (src.flags["C_CONTIGUOUS"]):
+ src = np.ascontiguousarray(src, dtype=dtype)
+ # working on underlying buffer is notably faster
+ # ~ dst[:] = src[:]
+ evt = cl.enqueue_copy(self.queue, dst.data, src)
+ evt.wait()
+ elif isinstance(src, parray.Array):
+ # No copy, use the data as self.d_input or self.d_output
+ # (this prevents the use of in-place transforms, however).
+ # We have to keep their old references.
+ if name is None:
+ # This should not happen
+ raise ValueError("Please provide either copy=True or name != None")
+ assert id(self.refs[name]) == id(dst) # DEBUG
+ setattr(self, name, src)
+ return src
+ else:
+ raise ValueError(
+ "Invalid array type %s, expected numpy.ndarray or pyopencl.array"
+ % type(src)
+ )
+ return dst
+
+ def recover_array_references(self):
+ self.data_in = self.refs["data_in"]
+ self.data_out = self.refs["data_out"]
+
+ def init_context_queue(self):
+ if self.ctx is None:
+ if self.choose_best_device:
+ self.ctx = ocl.create_context()
+ else:
+ self.ctx = cl.create_some_context()
+ self.queue = cl.CommandQueue(self.ctx)
+
+ def compute_forward_plan(self):
+ self.plan_forward = cl_fft(
+ self.ctx,
+ self.queue,
+ self.data_in,
+ out_array=self.data_out,
+ axes=self.axes,
+ fast_math=self.fast_math,
+ real=self.real_transform,
+ )
+
+ def compute_inverse_plan(self):
+ self.plan_inverse = cl_fft(
+ self.ctx,
+ self.queue,
+ self.data_out,
+ out_array=self.data_in,
+ axes=self.axes,
+ fast_math=self.fast_math,
+ real=self.real_transform,
+ )
+
+ def update_forward_plan_arrays(self):
+ self.plan_forward.data = self.data_in
+ self.plan_forward.result = self.data_out
+
+ def update_inverse_plan_arrays(self):
+ self.plan_inverse.data = self.data_out
+ self.plan_inverse.result = self.data_in
+
+ def copy_output_if_numpy(self, dst, src):
+ if isinstance(dst, parray.Array):
+ return
+ # working on underlying buffer is notably faster
+ # ~ dst[:] = src[:]
+ evt = cl.enqueue_copy(self.queue, dst, src.data)
+ evt.wait()
+
+ def fft(self, array, output=None, do_async=False):
+ """
+ Perform a (forward) Fast Fourier Transform.
+
+ :param Union[numpy.ndarray,pyopencl.array] array:
+ Input data. Must be consistent with the current context.
+ :param Union[numpy.ndarray,pyopencl.array] output:
+ Output data. By default, output is a numpy.ndarray.
+ :param bool do_async:
+ Whether to perform operation in asynchronous mode.
+ Default is False, meaning that we wait for transform to complete.
+ """
+ self.set_input_data(array, copy=False)
+ self.set_output_data(output, copy=False)
+ self.update_forward_plan_arrays()
+ (event,) = self.plan_forward.enqueue()
+ if not (do_async):
+ event.wait()
+ if output is not None:
+ self.copy_output_if_numpy(output, self.data_out)
+ res = output
+ else:
+ res = self.data_out.get()
+ self.recover_array_references()
+ return res
+
+ def ifft(self, array, output=None, do_async=False):
+ """
+        Perform an inverse Fast Fourier Transform.
+
+ :param Union[numpy.ndarray,pyopencl.array] array:
+ Input data. Must be consistent with the current context.
+ :param Union[numpy.ndarray,pyopencl.array] output:
+ Output data. By default, output is a numpy.ndarray.
+ :param bool do_async:
+ Whether to perform operation in asynchronous mode.
+ Default is False, meaning that we wait for transform to complete.
+ """
+ self.set_output_data(array, copy=False)
+ self.set_input_data(output, copy=False)
+ self.update_inverse_plan_arrays()
+ (event,) = self.plan_inverse.enqueue(forward=False)
+ if not (do_async):
+ event.wait()
+ if output is not None:
+ self.copy_output_if_numpy(output, self.data_in)
+ res = output
+ else:
+ res = self.data_in.get()
+ self.recover_array_references()
+ return res
+
+ def __del__(self):
+ # It seems that gpyfft underlying clFFT destructors are not called.
+ # This results in the following warning:
+ # Warning: Program terminating, but clFFT resources not freed.
+ # Please consider explicitly calling clfftTeardown( )
+ del self.plan_forward
+ del self.plan_inverse
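
A hedged usage sketch of the OpenCL backend (requires pyopencl, gpyfft and an available OpenCL device; the image below is placeholder data). Device arrays can be passed directly, following the no-copy path in set_data above:

import numpy as np
import pyopencl.array as parray
from silx.math.fft import FFT

img = np.random.rand(256, 256).astype(np.float32)  # placeholder image

# An OpenCL context is created automatically (choose_best_device=True)
fft_cl = FFT(template=img, backend="opencl")
spectrum = fft_cl.fft(img)            # result is returned as a numpy array

# A pyopencl array can also be given as input to avoid a host -> device copy
d_img = parray.to_device(fft_cl.queue, img)
spectrum2 = fft_cl.fft(d_img)
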
diff --git a/src/silx/math/fft/cufft.py b/src/silx/math/fft/cufft.py
new file mode 100644
index 0000000..c609439
--- /dev/null
+++ b/src/silx/math/fft/cufft.py
@@ -0,0 +1,256 @@
+#!/usr/bin/env python
+# /*##########################################################################
+#
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+import numpy as np
+
+from .basefft import BaseFFT
+
+try:
+ import pycuda.gpuarray as gpuarray
+ from skcuda.fft import Plan
+ from skcuda.fft import fft as cu_fft
+ from skcuda.fft import ifft as cu_ifft
+
+ __have_cufft__ = True
+except ImportError:
+ __have_cufft__ = False
+
+
+class CUFFT(BaseFFT):
+ """Initialize a cufft plan
+
+ Please see FFT class for parameters help.
+
+ CUFFT-specific parameters
+ --------------------------
+
+ :param pycuda.driver.Stream stream:
+ Stream with which to associate the plan. If no stream is specified,
+ the default stream is used.
+ """
+
+ def __init__(
+ self,
+ shape=None,
+ dtype=None,
+ template=None,
+ shape_out=None,
+ axes=None,
+ normalize="rescale",
+ stream=None,
+ ):
+        if not __have_cufft__:
+ raise ImportError(
+ "Please install pycuda and scikit-cuda to use the CUDA back-end"
+ )
+
+ super().__init__(
+ shape=shape,
+ dtype=dtype,
+ template=template,
+ shape_out=shape_out,
+ axes=axes,
+ normalize=normalize,
+ )
+ self.cufft_stream = stream
+ self.backend = "cufft"
+
+ self.configure_batched_transform()
+ self.allocate_arrays()
+ self.real_transform = np.isrealobj(self.data_in)
+ self.compute_forward_plan()
+ self.compute_inverse_plan()
+ self.refs = {
+ "data_in": self.data_in,
+ "data_out": self.data_out,
+ }
+ self.configure_normalization()
+
+ def _allocate(self, shape, dtype):
+ return gpuarray.zeros(shape, dtype)
+
+ # TODO support batched transform where batch is other than dimension 0
+ def configure_batched_transform(self):
+ self.cufft_batch_size = 1
+ self.cufft_shape = self.shape
+ if (self.axes is not None) and (len(self.axes) < len(self.shape)):
+ # In the easiest case, the transform is computed along the fastest dimensions:
+ # - 1D transforms of lines of 2D data
+ # - 2D transforms of images of 3D data (stacked along slow dim)
+ # - 1D transforms of 3D data along fastest dim
+ # Otherwise, we have to configure cuda "advanced memory layout",
+ # which is not implemented yet.
+
+ data_ndims = len(self.shape)
+ supported_axes = {
+ 2: [(1,)],
+ 3: [(1, 2), (2, 1), (1,), (2,)],
+ }
+ if self.axes not in supported_axes[data_ndims]:
+ raise NotImplementedError(
+ "With the CUDA backend, batched transform is only supported along fastest dimensions"
+ )
+ self.cufft_batch_size = self.shape[0]
+ self.cufft_shape = self.shape[1:]
+ if data_ndims == 3 and len(self.axes) == 1:
+ # 1D transform on 3D data: here only supported along fast dim,
+ # so batch_size is Nx*Ny
+ self.cufft_batch_size = np.prod(self.shape[:2])
+ self.cufft_shape = (self.shape[-1],)
+ if len(self.cufft_shape) == 1:
+ self.cufft_shape = self.cufft_shape[0]
+
+ def configure_normalization(self):
+ # TODO
+ if self.normalize == "ortho":
+ raise NotImplementedError(
+ "Normalization mode 'ortho' is not implemented with CUDA backend yet."
+ )
+ self.cufft_scale_inverse = self.normalize == "rescale"
+
+ def check_array(self, array, shape, dtype, copy=True):
+ if array.shape != shape:
+ raise ValueError(
+ "Invalid data shape: expected %s, got %s" % (shape, array.shape)
+ )
+ if array.dtype != dtype:
+ raise ValueError(
+ "Invalid data type: expected %s, got %s" % (dtype, array.dtype)
+ )
+
+ def set_data(self, dst, src, shape, dtype, copy=True, name=None):
+ """
+ dst is a device array owned by the current instance
+ (either self.data_in or self.data_out).
+
+        copy is ignored for device <-> device arrays.
+ """
+ self.check_array(src, shape, dtype)
+ if isinstance(src, np.ndarray):
+ if name == "data_out":
+ # Makes little sense to provide output=numpy_array
+ return dst
+ if not (src.flags["C_CONTIGUOUS"]):
+ src = np.ascontiguousarray(src, dtype=dtype)
+ dst[:] = src[:]
+ elif isinstance(src, gpuarray.GPUArray):
+ # No copy, use the data as self.d_input or self.d_output
+ # (this prevents the use of in-place transforms, however).
+ # We have to keep their old references.
+ if name is None:
+ # This should not happen
+ raise ValueError("Please provide either copy=True or name != None")
+ assert id(self.refs[name]) == id(dst) # DEBUG
+ setattr(self, name, src)
+ return src
+ else:
+ raise ValueError(
+ "Invalid array type %s, expected numpy.ndarray or pycuda.gpuarray"
+ % type(src)
+ )
+ return dst
+
+ def recover_array_references(self):
+ self.data_in = self.refs["data_in"]
+ self.data_out = self.refs["data_out"]
+
+ def compute_forward_plan(self):
+ self.plan_forward = Plan(
+ self.cufft_shape,
+ self.dtype,
+ self.dtype_out,
+ batch=self.cufft_batch_size,
+ stream=self.cufft_stream,
+ # cufft extensible plan API is only supported after 0.5.1
+ # (commit 65288d28ca0b93e1234133f8d460dc6becb65121)
+ # but there is still no official 0.5.2
+ # ~ auto_allocate=True # cufft extensible plan API
+ )
+
+ def compute_inverse_plan(self):
+ self.plan_inverse = Plan(
+ self.cufft_shape, # not shape_out
+ self.dtype_out,
+ self.dtype,
+ batch=self.cufft_batch_size,
+ stream=self.cufft_stream,
+ # cufft extensible plan API is only supported after 0.5.1
+ # (commit 65288d28ca0b93e1234133f8d460dc6becb65121)
+ # but there is still no official 0.5.2
+ # ~ auto_allocate=True
+ )
+
+ def copy_output_if_numpy(self, dst, src):
+ if isinstance(dst, gpuarray.GPUArray):
+ return
+ dst[:] = src[:]
+
+ def fft(self, array, output=None):
+ """
+ Perform a (forward) Fast Fourier Transform.
+
+ :param Union[numpy.ndarray,pycuda.gpuarray] array:
+ Input data. Must be consistent with the current context.
+ :param Union[numpy.ndarray,pycuda.gpuarray] output:
+ Output data. By default, output is a numpy.ndarray.
+ """
+ data_in = self.set_input_data(array, copy=False)
+ data_out = self.set_output_data(output, copy=False)
+
+ cu_fft(data_in, data_out, self.plan_forward, scale=False)
+
+ if output is not None:
+ self.copy_output_if_numpy(output, self.data_out)
+ res = output
+ else:
+ res = self.data_out.get()
+ self.recover_array_references()
+ return res
+
+ def ifft(self, array, output=None):
+ """
+        Perform an inverse Fast Fourier Transform.
+
+ :param Union[numpy.ndarray,pycuda.gpuarray] array:
+ Input data. Must be consistent with the current context.
+ :param Union[numpy.ndarray,pycuda.gpuarray] output:
+ Output data. By default, output is a numpy.ndarray.
+ """
+ data_in = self.set_output_data(array, copy=False)
+ data_out = self.set_input_data(output, copy=False)
+
+ cu_ifft(
+ data_in,
+ data_out,
+ self.plan_inverse,
+ scale=self.cufft_scale_inverse,
+ )
+
+ if output is not None:
+ self.copy_output_if_numpy(output, self.data_in)
+ res = output
+ else:
+ res = self.data_in.get()
+ self.recover_array_references()
+ return res
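
A hedged usage sketch of the batched-transform restriction handled in configure_batched_transform above (requires pycuda, scikit-cuda and an initialized CUDA context, e.g. via pycuda.autoinit or a helper like get_cuda_context() in the tests; the stack is placeholder data):

import numpy as np
from silx.math.fft import FFT

stack = np.random.rand(8, 64, 64).astype(np.float32)  # placeholder 3D stack

# Batched 2D transform of each 64x64 image: axes=(1, 2) is one of the
# fastest-dimension layouts accepted by the CUDA backend.
fft_cuda = FFT(template=stack, axes=(1, 2), backend="cuda")
spectrum = fft_cuda.fft(stack)      # numpy array, shape (8, 64, 33) for R2C
restored = fft_cuda.ifft(spectrum)
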
diff --git a/src/silx/math/fft/fft.py b/src/silx/math/fft/fft.py
new file mode 100644
index 0000000..7daf17b
--- /dev/null
+++ b/src/silx/math/fft/fft.py
@@ -0,0 +1,99 @@
+#!/usr/bin/env python
+# /*##########################################################################
+#
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+from .fftw import FFTW
+from .npfft import NPFFT
+
+
+def FFT(
+ shape=None,
+ dtype=None,
+ template=None,
+ shape_out=None,
+ axes=None,
+ normalize="rescale",
+ backend="numpy",
+ **kwargs,
+):
+ """
+ Initialize a FFT plan.
+
+ :param List[int] shape:
+ Shape of the input data.
+ :param numpy.dtype dtype:
+ Data type of the input data.
+ :param numpy.ndarray template:
+ Optional data, replacement for "shape" and "dtype".
+ If provided, the arguments "shape" and "dtype" are ignored,
+ and are instead inferred from it.
+ :param List[int] shape_out:
+ Optional shape of output data.
+ By default, the data has the same shape as the input
+ data (in case of C2C transform), or a shape with the last dimension halved
+ (in case of R2C transform). If shape_out is provided, it must be greater
+ or equal than the shape of input data. In this case, FFT is performed
+ with zero-padding.
+ :param List[int] axes:
+ Axes along which FFT is computed.
+ * For 2D transform: axes=(1,0)
+ * For batched 1D transform of 2D image: axes=(0,)
+ :param str normalize:
+ Whether to normalize FFT and IFFT. Possible values are:
+ * "rescale": in this case, Fourier data is divided by "N"
+ before IFFT, so that IFFT(FFT(data)) = data.
+          This corresponds to numpy norm=None, i.e. norm="backward".
+        * "ortho": in this case, FFT and IFFT are adjoints of each other, and
+          the transform is unitary. Both FFT and IFFT are scaled by 1/sqrt(N).
+        * "none": no normalization is done: IFFT(FFT(data)) = data*N
+ :param str backend:
+ FFT Backend to use. Value can be "numpy", "fftw", "opencl", "cuda".
+ """
+ backends = ["numpy", "fftw", "opencl", "cuda"]
+ backend = backend.lower()
+ if backend in ["numpy", "np"]:
+ fft_cls = NPFFT
+ elif backend == "fftw":
+ fft_cls = FFTW
+ elif backend in ["opencl", "clfft"]:
+ # Late import for creating context only if needed
+ from .clfft import CLFFT
+
+ fft_cls = CLFFT
+ elif backend in ["cuda", "cufft"]:
+ # Late import for creating context only if needed
+ from .cufft import CUFFT
+
+ fft_cls = CUFFT
+ else:
+ raise ValueError("Unknown backend %s, available are %s" % (backend, backends))
+ F = fft_cls(
+ shape=shape,
+ dtype=dtype,
+ template=template,
+ shape_out=shape_out,
+ axes=axes,
+ normalize=normalize,
+ **kwargs,
+ )
+ return F
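
A minimal, self-contained usage sketch of this factory with the numpy backend (the input array is just example data):

import numpy as np
from silx.math.fft import FFT

image = np.random.rand(256, 256).astype(np.float32)  # example input

# shape and dtype are inferred from the template array
f = FFT(template=image, backend="numpy", normalize="rescale")
spectrum = f.fft(image)        # R2C transform: shape (256, 129)
roundtrip = f.ifft(spectrum)   # with "rescale", ifft(fft(x)) ~= x
assert np.allclose(roundtrip, image, atol=1e-3)
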
diff --git a/src/silx/math/fft/fftw.py b/src/silx/math/fft/fftw.py
new file mode 100644
index 0000000..69edbb6
--- /dev/null
+++ b/src/silx/math/fft/fftw.py
@@ -0,0 +1,376 @@
+#!/usr/bin/env python
+# /*##########################################################################
+#
+# Copyright (c) 2018-2022 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+import os
+from sys import executable as sys_executable
+from socket import gethostname
+from tempfile import gettempdir
+from pathlib import Path
+import numpy as np
+from .basefft import BaseFFT, check_version
+
+try:
+ import pyfftw
+
+ __have_fftw__ = True
+except ImportError:
+ __have_fftw__ = False
+
+
+# Check pyfftw version
+__required_pyfftw_version__ = "0.10.0"
+if __have_fftw__:
+ __have_fftw__ = check_version(pyfftw, __required_pyfftw_version__)
+
+
+class FFTW(BaseFFT):
+ """Initialize a FFTW plan.
+
+ Please see FFT class for parameters help.
+
+ FFTW-specific parameters
+ -------------------------
+
+ :param bool check_alignment:
+ If set to True and "data" is provided, this will enforce the input data
+ to be "byte aligned", which might imply extra memory usage.
+ :param int num_threads:
+ Number of threads for computing FFT.
+ """
+
+ def __init__(
+ self,
+ shape=None,
+ dtype=None,
+ template=None,
+ shape_out=None,
+ axes=None,
+ normalize="rescale",
+ check_alignment=False,
+ num_threads=1,
+ ):
+ if not (__have_fftw__):
+ raise ImportError(
+ "Please install pyfftw >= %s to use the FFTW back-end"
+ % __required_pyfftw_version__
+ )
+ super().__init__(
+ shape=shape,
+ dtype=dtype,
+ template=template,
+ shape_out=shape_out,
+ axes=axes,
+ normalize=normalize,
+ )
+ self.check_alignment = check_alignment
+ self.num_threads = num_threads
+ self.backend = "fftw"
+
+ self.allocate_arrays()
+ self.set_fftw_flags()
+ self.compute_forward_plan()
+ self.compute_inverse_plan()
+ self.refs = {
+ "data_in": self.data_in,
+ "data_out": self.data_out,
+ }
+
+        # About normalization with norm="none": behavior differs between pyfftw versions:
+ # --------------- pyfftw 0.12 ---------------
+ # FFT :
+ # normalise_idft --> 1
+ # not normalise_idft --> 1
+ # IFFT :
+ # normalise_idft --> 1 / N
+ # not normalise_idft --> 1
+ # --------------- pyfftw 0.13 ---------------
+ # FFT :
+ # normalise_idft --> 1
+ # not normalise_idft --> 1 / N (this normalization is incorrect, doc says contrary)
+ # IFFT :
+ # normalise_idft --> 1 / N
+ # not normalise_idft --> 1
+
+ # Solution :
+ # select 'normalise_idft' for FFT and 'not normalise_idft' for IFFT
+        # => behavior is the same in both versions :)
+
+ def set_fftw_flags(self):
+ self.fftw_flags = ("FFTW_MEASURE",) # TODO
+ self.fftw_planning_timelimit = None # TODO
+
+        # To skip normalization with norm="none", flip 'normalise_idft'
+        # so that no scaling is applied anywhere (see comments above).
+        #
+        # The 'ortho' flag enables orthogonal normalization:
+        #     ortho = True : forward -> 1/sqrt(N), backward -> 1/sqrt(N)
+
+ self.fftw_norm_modes = {
+ "rescale": (
+ {"ortho": False, "normalise_idft": True}, # fft
+ {"ortho": False, "normalise_idft": True}, # ifft
+ ),
+ "ortho": (
+ {"ortho": True, "normalise_idft": False}, # fft
+ {"ortho": True, "normalise_idft": False}, # ifft
+ ),
+ "none": (
+ {"ortho": False, "normalise_idft": True}, # fft
+ {"ortho": False, "normalise_idft": False}, # ifft
+ ),
+ }
+ if self.normalize not in self.fftw_norm_modes:
+ raise ValueError(
+ "Unknown normalization mode %s. Possible values are %s"
+ % (self.normalize, self.fftw_norm_modes.keys())
+ )
+ self.fftw_norm_mode = self.fftw_norm_modes[self.normalize]
+
+ def _allocate(self, shape, dtype):
+ return pyfftw.zeros_aligned(shape, dtype=dtype)
+
+ def check_array(self, array, shape, dtype, copy=True):
+ if array.shape != shape:
+ raise ValueError(
+ "Invalid data shape: expected %s, got %s" % (shape, array.shape)
+ )
+ if array.dtype != dtype:
+ raise ValueError(
+ "Invalid data type: expected %s, got %s" % (dtype, array.dtype)
+ )
+
+ def set_data(self, self_array, array, shape, dtype, copy=True, name=None):
+ """
+        :param numpy.ndarray self_array: array owned by the current instance
+            (either self.data_in or self.data_out).
+        :param numpy.ndarray array: data to set.
+        :param tuple shape: shape of the array.
+        :param numpy.dtype dtype: type of the array.
+        :param bool copy: whether to copy the array.
+        :param str name: name of the array.
+
+ Copies are avoided when possible.
+ """
+ self.check_array(array, shape, dtype)
+ if id(self.refs[name]) == id(array):
+ # nothing to do: fft is performed on self.data_in or self.data_out
+ arr_to_use = self.refs[name]
+        elif self.check_alignment and not pyfftw.is_byte_aligned(array):
+            # If the array is not properly aligned,
+            # copy it into self.data_in or self.data_out
+ self_array[:] = array[:]
+ arr_to_use = self_array
+ else:
+ # If the array is properly aligned, use it directly
+ if copy:
+ arr_to_use = np.copy(array)
+ else:
+ arr_to_use = array
+ return arr_to_use
+
+ def compute_forward_plan(self):
+ self.plan_forward = pyfftw.FFTW(
+ self.data_in,
+ self.data_out,
+ axes=self.axes,
+ direction="FFTW_FORWARD",
+ flags=self.fftw_flags,
+ threads=self.num_threads,
+ planning_timelimit=self.fftw_planning_timelimit,
+ )
+
+ def compute_inverse_plan(self):
+ self.plan_inverse = pyfftw.FFTW(
+ self.data_out,
+ self.data_in,
+ axes=self.axes,
+ direction="FFTW_BACKWARD",
+ flags=self.fftw_flags,
+ threads=self.num_threads,
+ planning_timelimit=self.fftw_planning_timelimit,
+ )
+
+ def fft(self, array, output=None):
+ """
+ Perform a (forward) Fast Fourier Transform.
+
+ :param numpy.ndarray array:
+ Input data. Must be consistent with the current context.
+ :param numpy.ndarray output:
+ Optional output data.
+ """
+ data_in = self.set_input_data(array, copy=False)
+ data_out = self.set_output_data(output, copy=False)
+ self.plan_forward.update_arrays(data_in, data_out)
+ # execute.__call__ does both update_arrays() and normalization
+ self.plan_forward( # [0] --> fft
+ ortho=self.fftw_norm_mode[0]["ortho"],
+ normalise_idft=self.fftw_norm_mode[0]["normalise_idft"],
+ )
+ self.plan_forward.update_arrays(self.refs["data_in"], self.refs["data_out"])
+ return data_out
+
+ def ifft(self, array, output=None):
+ """
+        Perform an inverse Fast Fourier Transform.
+
+ :param numpy.ndarray array:
+ Input data. Must be consistent with the current context.
+ :param numpy.ndarray output:
+ Optional output data.
+ """
+ data_in = self.set_output_data(array, copy=False)
+ data_out = self.set_input_data(output, copy=False)
+ self.plan_inverse.update_arrays(
+ data_in, data_out
+ ) # TODO why in/out when it is out/in everywhere else in the function
+ # execute.__call__ does both update_arrays() and normalization
+ self.plan_inverse( # [1] --> ifft
+ ortho=self.fftw_norm_mode[1]["ortho"],
+ normalise_idft=self.fftw_norm_mode[1]["normalise_idft"],
+ )
+ self.plan_inverse.update_arrays(self.refs["data_out"], self.refs["data_in"])
+ return data_out
+
+
+def get_wisdom_metadata():
+ """
+ Get metadata on the current platform.
+    Re-using FFTW wisdom gives varying performance benefits depending on whether the
+    plans are re-used on the same machine/architecture/etc.
+ For more information: https://www.fftw.org/fftw3_doc/Caveats-in-Using-Wisdom.html
+ """
+ return {
+ # "venv"
+ "executable": sys_executable,
+ # encapsulates sys.platform, platform.machine(), platform.architecture(), platform.libc_ver(), ...
+ "hostname": gethostname(),
+ "available_threads": len(os.sched_getaffinity(0)),
+ }
+
+
+def export_wisdom(fname, on_existing="overwrite"):
+ """
+ Export the current FFTW wisdom to a file.
+
+ :param str fname:
+ Path to the file where the wisdom is to be exported
+ :param str on_existing:
+        What to do when the target file already exists.
+ Possible options are:
+ - raise: raise an error and exit
+ - overwrite: overwrite the file with the current wisdom
+ - append: Import the already existing wisdom, and dump the newly combined wisdom to this file
+ """
+ if os.path.isfile(fname):
+ if on_existing == "raise":
+ raise ValueError("File already exists: %s" % fname)
+ if on_existing == "append":
+ import_wisdom(fname, on_mismatch="ignore") # ?
+ current_wisdom = pyfftw.export_wisdom()
+ res = get_wisdom_metadata()
+ for i, w in enumerate(current_wisdom):
+ res[str(i)] = np.array(w)
+ np.savez_compressed(fname, **res)
+
+
+def import_wisdom(fname, match=["hostname"], on_mismatch="warn"):
+ """
+    Import FFTW wisdom from a .npz file.
+
+ :param str fname:
+ Path to the .npz file containing FFTW wisdom
+ :param list match:
+ List of elements that must match when importing wisdom.
+        If match=["hostname"] (default), this function will only load wisdom that was saved
+ on the current machine, and discard everything else.
+ If match=["hostname", "executable"], wisdom will only be loaded if the file was
+ created on the same machine and by the same python executable.
+ :param str on_mismatch:
+ What to do when the file wisdom does not match the current platform.
+ Available options:
+ - "raise": raise an error (crash)
+ - "warn": print a warning, don't crash
+ - "ignore": do nothing
+ """
+
+ def handle_mismatch(item, loaded_value, current_value):
+ msg = (
+ "Platform configuration mismatch: %s: currently have '%s', loaded '%s'"
+ % (item, current_value, loaded_value)
+ )
+ if on_mismatch == "raise":
+ raise ValueError(msg)
+ if on_mismatch == "warn":
+ print(msg)
+
+ wis_metadata = get_wisdom_metadata()
+ loaded_wisdom = np.load(fname)
+ for metadata_name in match:
+ if metadata_name not in wis_metadata:
+ raise ValueError(
+ "Cannot match metadata '%s'. Available are: %s"
+                % (metadata_name, str(wis_metadata.keys()))
+ )
+ if loaded_wisdom[metadata_name] != wis_metadata[metadata_name]:
+ handle_mismatch(
+ metadata_name, loaded_wisdom[metadata_name], wis_metadata[metadata_name]
+ )
+ return
+ w = tuple(
+ loaded_wisdom[k][()]
+ for k in loaded_wisdom.keys()
+ if k not in wis_metadata.keys()
+ )
+ pyfftw.import_wisdom(w)
+
+
+def get_wisdom_file(
+ directory=None,
+ name_template="fftw_wisdom_{whoami}_{hostname}.npz",
+ create_dirs=True,
+):
+ """
+ Get a file path for storing FFTW wisdom.
+
+ :param str directory:
+ Directory where the file is created. By default, files are written in a temporary directory.
+ :param str name_template:
+ File name pattern. The following patterns can be used:
+ - {whoami}: current username
+ - {hostname}: machine name
+ :param bool create_dirs:
+ Whether to create (possibly nested) directories if needed.
+ """
+ directory = directory or gettempdir()
+ file_basename = name_template.format(whoami=os.getlogin(), hostname=gethostname())
+ out_file = os.path.join(directory, file_basename)
+ if create_dirs:
+ Path(os.path.dirname(out_file)).mkdir(parents=True, exist_ok=True)
+ return out_file
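
A short sketch of how the wisdom helpers above fit together (assumes pyfftw, a Linux host for os.sched_getaffinity and a login session for os.getlogin; the array is placeholder data):

import numpy as np
from silx.math.fft.fftw import FFTW, export_wisdom, import_wisdom, get_wisdom_file

data = np.random.rand(512, 512).astype(np.float32)  # placeholder data

fname = get_wisdom_file()      # per-user/per-host file in the temp directory
plan = FFTW(template=data)     # planning accumulates FFTW wisdom
export_wisdom(fname)           # persist it for later runs

# In a later session on the same machine:
import_wisdom(fname, match=["hostname"], on_mismatch="warn")
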
diff --git a/src/silx/math/fft/npfft.py b/src/silx/math/fft/npfft.py
new file mode 100644
index 0000000..3fe0754
--- /dev/null
+++ b/src/silx/math/fft/npfft.py
@@ -0,0 +1,144 @@
+#!/usr/bin/env python
+# /*##########################################################################
+#
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+import numpy as np
+import warnings
+from packaging.version import Version
+
+from .basefft import BaseFFT
+
+
+class NPFFT(BaseFFT):
+ """Initialize a numpy plan.
+
+ Please see FFT class for parameters help.
+ """
+
+ def __init__(
+ self,
+ shape=None,
+ dtype=None,
+ template=None,
+ shape_out=None,
+ axes=None,
+ normalize="rescale",
+ ):
+ super().__init__(
+ shape=shape,
+ dtype=dtype,
+ template=template,
+ shape_out=shape_out,
+ axes=axes,
+ normalize=normalize,
+ )
+ self.backend = "numpy"
+ self.real_transform = False
+ if template is not None and np.isrealobj(template):
+ self.real_transform = True
+ # For numpy functions.
+
+ self.set_fft_norm()
+ self.set_fft_functions()
+ self.compute_plans()
+
+ def set_fft_norm(self):
+        # "backward" and "forward" indicate the direction in which the
+        # normalisation is done. Default is "backward".
+
+        # rescale is the default norm with numpy, so no keyword is needed
+ # if normalize == "rescale": # normalisation 1/N on ifft
+ self.numpy_args_fft = {}
+ self.numpy_args_ifft = {}
+
+ if self.normalize == "ortho": # normalization 1/sqrt(N) on both fft & ifft
+ self.numpy_args_fft = {"norm": "ortho"}
+ self.numpy_args_ifft = {"norm": "ortho"}
+
+ elif self.normalize == "none": # no normalisation on both fft & ifft
+ if Version(np.version.version) < Version("1.20"):
+ # "backward" & "forward" keywords were introduced in 1.20 and we support numpy >= 1.8
+ warnings.warn(
+                    "Numpy version %s does not support disabling normalization. Effective normalization will be 'rescale'"
+ % (np.version.version)
+ ) # default 'rescale' normalization
+ else:
+ self.numpy_args_fft = {"norm": "backward"}
+ self.numpy_args_ifft = {"norm": "forward"}
+
+ def set_fft_functions(self):
+ # (fwd, inv) = _fft_functions[is_real][ndim]
+ self._fft_functions = {
+ True: {
+ 1: (np.fft.rfft, np.fft.irfft),
+ 2: (np.fft.rfft2, np.fft.irfft2),
+ 3: (np.fft.rfftn, np.fft.irfftn),
+ },
+ False: {
+ 1: (np.fft.fft, np.fft.ifft),
+ 2: (np.fft.fft2, np.fft.ifft2),
+ 3: (np.fft.fftn, np.fft.ifftn),
+ },
+ }
+
+ def _allocate(self, shape, dtype):
+ return np.zeros(shape, dtype=dtype)
+
+ def compute_plans(self):
+ ndim = len(self.shape)
+ funcs = self._fft_functions[self.real_transform][np.minimum(ndim, 3)]
+
+ # Set norm
+ # self.numpy_args_fft & self.numpy_args_ifft already set in set_fft_norm
+
+ # Batched transform
+ if (self.user_axes is not None) and len(self.user_axes) < ndim:
+ funcs = self._fft_functions[self.real_transform][np.minimum(ndim - 1, 3)]
+ self.numpy_args_fft["axes"] = self.user_axes
+ self.numpy_args_ifft["axes"] = self.user_axes
+ # Special case of batched 1D transform on 2D data
+ if ndim == 2:
+ assert len(self.user_axes) == 1
+ self.numpy_args_fft["axis"] = self.user_axes[0]
+ self.numpy_args_fft.pop("axes")
+ self.numpy_args_ifft["axis"] = self.user_axes[0]
+ self.numpy_args_ifft.pop("axes")
+ self.numpy_funcs = funcs
+
+ def fft(self, array):
+ """
+ Perform a (forward) Fast Fourier Transform.
+
+ :param numpy.ndarray array:
+ Input data. Must be consistent with the current context.
+ """
+ return self.numpy_funcs[0](array, **self.numpy_args_fft)
+
+ def ifft(self, array):
+ """
+        Perform an inverse Fast Fourier Transform.
+
+ :param numpy.ndarray array:
+ Input data. Must be consistent with the current context.
+ """
+ return self.numpy_funcs[1](array, **self.numpy_args_ifft)
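
To illustrate the mapping done in set_fft_norm above, a small sketch of the three normalization modes on the numpy backend (numpy >= 1.20 is needed for "none", as noted in the code; the signal is arbitrary):

import numpy as np
from silx.math.fft import FFT

x = np.random.rand(128)  # arbitrary 1D real signal
N = x.size

for norm in ("rescale", "ortho", "none"):
    f = FFT(template=x, backend="numpy", normalize=norm)
    y = f.ifft(f.fft(x))
    # "rescale" and "ortho" give back x; "none" gives N * x
    expected = x if norm != "none" else N * x
    assert np.allclose(y, expected)
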
diff --git a/src/silx/math/fft/test/__init__.py b/src/silx/math/fft/test/__init__.py
new file mode 100644
index 0000000..d076ee3
--- /dev/null
+++ b/src/silx/math/fft/test/__init__.py
@@ -0,0 +1,22 @@
+# /*##########################################################################
+# Copyright (C) 2016-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
diff --git a/src/silx/math/fft/test/test_fft.py b/src/silx/math/fft/test/test_fft.py
new file mode 100644
index 0000000..abe7842
--- /dev/null
+++ b/src/silx/math/fft/test/test_fft.py
@@ -0,0 +1,423 @@
+#!/usr/bin/env python
+# /*##########################################################################
+#
+# Copyright (c) 2018-2022 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test of the FFT module"""
+
+from os import path
+import logging
+import numpy as np
+import unittest
+from packaging.version import Version
+import pytest
+from tempfile import TemporaryDirectory
+
+try:
+ try:
+ from scipy.misc import ascent
+    except ImportError:
+ from scipy.datasets import ascent
+ __have_scipy = True
+except ImportError:
+ __have_scipy = False
+from silx.utils.testutils import ParametricTestCase
+from silx.math.fft.fft import FFT
+from silx.math.fft.clfft import __have_clfft__
+from silx.math.fft.cufft import __have_cufft__
+from silx.math.fft.fftw import (
+ __have_fftw__,
+ import_wisdom,
+ export_wisdom,
+ get_wisdom_file,
+)
+
+if __have_cufft__:
+ import atexit
+ import pycuda.driver as cuda
+ from pycuda.tools import clear_context_caches
+
+
+def get_cuda_context(device_id=None, cleanup_at_exit=True):
+ """
+ Create or get a CUDA context.
+ """
+ current_ctx = cuda.Context.get_current()
+ # If a context already exists, use this one
+ # TODO what if the device used is different from device_id ?
+ if current_ctx is not None:
+ return current_ctx
+ # Otherwise create a new context
+ cuda.init()
+
+ if device_id is None:
+ device_id = 0
+ # Use the Context obtained by retaining the device's primary context,
+ # which is the one used by the CUDA runtime API (ex. scikit-cuda).
+ # Unlike Context.make_context(), the newly-created context is not made current.
+ context = cuda.Device(device_id).retain_primary_context()
+ context.push()
+
+ # Register a clean-up function at exit
+ def _finish_up(context):
+ if context is not None:
+ context.pop()
+ context = None
+ clear_context_caches()
+
+ if cleanup_at_exit:
+ atexit.register(_finish_up, context)
+ return context
+
+
+logger = logging.getLogger(__name__)
+
+
+class TransformInfos(object):
+ def __init__(self):
+ self.dimensions = [
+ "1D",
+ "batched_1D",
+ "2D",
+ "batched_2D",
+ "3D",
+ ]
+ self.modes = {
+ "R2C": np.float32,
+ "R2C_double": np.float64,
+ "C2C": np.complex64,
+ "C2C_double": np.complex128,
+ }
+ self.sizes = {
+ "1D": [(128,), (127,)],
+ "2D": [(128, 128), (128, 127), (127, 128), (127, 127)],
+ "3D": [
+ (64, 64, 64),
+ (64, 64, 63),
+ (64, 63, 64),
+ (63, 64, 64),
+ (64, 63, 63),
+ (63, 64, 63),
+ (63, 63, 64),
+ (63, 63, 63),
+ ],
+ }
+ self.axes = {
+ "1D": None,
+ "batched_1D": (-1,),
+ "2D": None,
+ "batched_2D": (-2, -1),
+ "3D": None,
+ }
+ self.sizes["batched_1D"] = self.sizes["2D"]
+ self.sizes["batched_2D"] = self.sizes["3D"]
+
+
+class Data(object):
+ def __init__(self):
+ self.data = ascent().astype("float32")
+ self.data1d = self.data[:, 0] # non-contiguous data
+ self.data3d = np.tile(self.data[:64, :64], (64, 1, 1))
+ self.data_refs = {
+ 1: self.data1d,
+ 2: self.data,
+ 3: self.data3d,
+ }
+
+
+@unittest.skipUnless(__have_scipy, "scipy is missing")
+@pytest.mark.usefixtures("test_options_class_attr")
+class TestFFT(ParametricTestCase):
+ """Test cuda/opencl/fftw backends of FFT"""
+
+ def setUp(self):
+ self.tol = {
+ np.dtype("float32"): 1e-3,
+ np.dtype("float64"): 1e-9,
+ np.dtype("complex64"): 1e-3,
+ np.dtype("complex128"): 1e-9,
+ }
+ self.transform_infos = TransformInfos()
+ self.test_data = Data()
+
+ @staticmethod
+ def calc_mae(arr1, arr2):
+ """
+ Compute the Max Absolute Error between two arrays
+ """
+ return np.max(np.abs(arr1 - arr2))
+
+ @unittest.skipIf(
+ not __have_cufft__, "cuda back-end requires pycuda and scikit-cuda"
+ )
+ def test_cuda(self):
+ get_cuda_context()
+
+ # Error is higher when using cuda. fast_math mode ?
+ self.tol[np.dtype("float32")] *= 2
+
+ self.__run_tests(backend="cuda")
+
+ @unittest.skipIf(not __have_clfft__, "opencl back-end requires pyopencl and gpyfft")
+ def test_opencl(self):
+ from silx.opencl.common import ocl
+
+ if ocl is not None:
+ self.__run_tests(backend="opencl", ctx=ocl.create_context())
+
+ @unittest.skipIf(not __have_fftw__, "fftw back-end requires pyfftw")
+ def test_fftw(self):
+ self.__run_tests(backend="fftw")
+
+ def __run_tests(self, backend, **extra_args):
+ """Run all tests with the given backend
+
+ :param str backend:
+ :param dict extra_args: Additional arguments to provide to FFT
+ """
+ for trdim in self.transform_infos.dimensions:
+ for mode in self.transform_infos.modes:
+ for size in self.transform_infos.sizes[trdim]:
+ with self.subTest(trdim=trdim, mode=mode, size=size):
+ self.__test(backend, trdim, mode, size, **extra_args)
+
+ def __test(self, backend, trdim, mode, size, **extra_args):
+ """Compare given backend with numpy for given conditions"""
+ logger.debug(
+ "backend: %s, trdim: %s, mode: %s, size: %s",
+ backend,
+ trdim,
+ mode,
+ str(size),
+ )
+        if trdim == "3D" and self.test_options.TEST_LOW_MEM:
+ self.skipTest("low mem")
+
+ ndim = len(size)
+ input_data = self.test_data.data_refs[ndim].astype(
+ self.transform_infos.modes[mode]
+ )
+ tol = self.tol[np.dtype(input_data.dtype)]
+ if trdim == "3D":
+ tol *= 10 # Error is relatively high in high dimensions
+        # It seems that cuda has problems with C2C batched 1D
+ if trdim == "batched_1D" and backend == "cuda" and mode == "C2C":
+ tol *= 10
+
+ # Python < 3.5 does not want to mix **extra_args with existing kwargs
+ fft_args = {
+ "template": input_data,
+ "axes": self.transform_infos.axes[trdim],
+ "backend": backend,
+ }
+ fft_args.update(extra_args)
+ F = FFT(**fft_args)
+ F_np = FFT(
+ template=input_data, axes=self.transform_infos.axes[trdim], backend="numpy"
+ )
+
+ # Forward FFT
+ res = F.fft(input_data)
+ res_np = F_np.fft(input_data)
+ mae = self.calc_mae(res, res_np)
+        all_close = np.allclose(res, res_np, atol=tol, rtol=tol)
+ self.assertTrue(
+ all_close,
+ "FFT %s:%s, MAE(%s, numpy) = %f (tol = %.2e)"
+ % (mode, trdim, backend, mae, tol),
+ )
+
+ # Inverse FFT
+ res2 = F.ifft(res)
+ mae = self.calc_mae(res2, input_data)
+ self.assertTrue(
+ mae < tol, "IFFT %s:%s, MAE(%s, numpy) = %f" % (mode, trdim, backend, mae)
+ )
+
+    # Test normalizations. silx FFT has three normalization modes:
+    # - "rescale" (default). FFT is unscaled, IFFT is scaled by 1/N.
+    #   This corresponds to numpy norm=None, i.e. norm="backward".
+    # - "ortho": FFT/IFFT are both scaled with 1/sqrt(N) so that FFT is unitary.
+    # - "none": neither FFT nor IFFT is scaled, so IFFT(FFT(array)) = N*array
+
+ norms_backends_support = {
+ "numpy": {
+ "supported_normalizations": ["rescale", "ortho", "none"],
+ },
+ "fftw": {
+ "supported_normalizations": ["rescale", "ortho", "none"],
+ },
+ "opencl": {
+ "supported_normalizations": ["rescale"],
+ },
+ "cuda": {
+ "supported_normalizations": ["rescale", "none"],
+ },
+ }
+
+ @staticmethod
+ def _compute_numpy_normalized_fft(data, axes, silx_normalization_mode):
+ if silx_normalization_mode in ["rescale", "none"]:
+ return np.fft.rfftn(data, axes=axes, norm=None)
+ elif silx_normalization_mode == "ortho":
+ return np.fft.rfftn(data, axes=axes, norm="ortho")
+ else:
+ raise ValueError("Unknown normalization mode %s" % silx_normalization_mode)
+
+ @staticmethod
+ def _compute_numpy_normalized_ifft(data, axes, silx_normalization_mode):
+ if silx_normalization_mode == "rescale":
+ return np.fft.irfftn(data, axes=axes, norm=None)
+ elif silx_normalization_mode == "ortho":
+ return np.fft.irfftn(data, axes=axes, norm="ortho")
+ elif silx_normalization_mode == "none":
+ res = np.fft.irfftn(data, axes=axes, norm=None)
+ # This assumes a FFT on all the axes, won't work on batched FFT
+ N = res.size
+ return res * N
+ else:
+ raise ValueError("Unknown normalization mode %s" % silx_normalization_mode)
+
+ @unittest.skipIf(not __have_fftw__, "fftw back-end requires pyfftw")
+ def test_norms_fftw(self):
+ return self._test_norms_with_backend("fftw")
+
+ @unittest.skipIf(
+ Version(np.version.version) <= Version("1.19.5"),
+ "normalization does not work for numpy <= 1.19.5",
+ )
+ def test_norms_numpy(self):
+ return self._test_norms_with_backend("numpy")
+
+ @unittest.skipIf(not __have_clfft__, "opencl back-end requires pyopencl and gpyfft")
+ def test_norms_opencl(self):
+ from silx.opencl.common import ocl
+
+ if ocl is not None:
+ return self._test_norms_with_backend("opencl")
+
+ @unittest.skipIf(
+ not __have_cufft__, "cuda back-end requires pycuda and scikit-cuda"
+ )
+ def test_norms_cuda(self):
+ get_cuda_context()
+ return self._test_norms_with_backend("cuda")
+
+ def _test_norms_with_backend(self, backend_name):
+ backend_params = self.norms_backends_support[backend_name]
+
+ data = self.test_data.data
+ tol = self.tol[np.dtype(data.dtype)]
+
+ for norm in backend_params["supported_normalizations"]:
+ fft = FFT(template=data, backend=backend_name, normalize=norm)
+ res = fft.fft(data)
+ ref = self._compute_numpy_normalized_fft(data, fft.axes, norm)
+ assert np.allclose(
+ res, ref, atol=tol, rtol=tol
+ ), "Something wrong with %s norm=%s" % (backend_name, norm)
+
+ res2 = fft.ifft(res)
+ ref2 = self._compute_numpy_normalized_ifft(ref, fft.axes, norm)
+ # unscaled IFFT yields very large values. Use a relatively high "atol"
+ assert np.allclose(
+ res2, ref2, atol=res2.max() / 1e6
+ ), "Something wrong with I%s norm=%s" % (backend_name, norm)
+
+
+@unittest.skipUnless(__have_scipy, "scipy is missing")
+class TestNumpyFFT(ParametricTestCase):
+ """
+ Test the Numpy backend individually.
+ """
+
+ def setUp(self):
+ transforms = {
+ "1D": {
+ True: (np.fft.rfft, np.fft.irfft),
+ False: (np.fft.fft, np.fft.ifft),
+ },
+ "2D": {
+ True: (np.fft.rfft2, np.fft.irfft2),
+ False: (np.fft.fft2, np.fft.ifft2),
+ },
+ "3D": {
+ True: (np.fft.rfftn, np.fft.irfftn),
+ False: (np.fft.fftn, np.fft.ifftn),
+ },
+ }
+ transforms["batched_1D"] = transforms["1D"]
+ transforms["batched_2D"] = transforms["2D"]
+ self.transforms = transforms
+ self.transform_infos = TransformInfos()
+ self.test_data = Data()
+
+ def test(self):
+ """Test the numpy backend against native fft.
+
+ Results should be exactly the same.
+ """
+ for trdim in self.transform_infos.dimensions:
+ for mode in self.transform_infos.modes:
+ for size in self.transform_infos.sizes[trdim]:
+ with self.subTest(trdim=trdim, mode=mode, size=size):
+ self.__test(trdim, mode, size)
+
+ def __test(self, trdim, mode, size):
+ logger.debug("trdim: %s, mode: %s, size: %s", trdim, mode, str(size))
+ ndim = len(size)
+ input_data = self.test_data.data_refs[ndim].astype(
+ self.transform_infos.modes[mode]
+ )
+ np_fft, np_ifft = self.transforms[trdim][np.isrealobj(input_data)]
+
+ F = FFT(
+ template=input_data, axes=self.transform_infos.axes[trdim], backend="numpy"
+ )
+ # Test FFT
+ res = F.fft(input_data)
+ ref = np_fft(input_data)
+ self.assertTrue(np.allclose(res, ref))
+
+ # Test IFFT
+ res2 = F.ifft(res)
+ ref2 = np_ifft(ref)
+ self.assertTrue(np.allclose(res2, ref2))
+
+
+@pytest.mark.skipif(not (__have_fftw__), reason="Need fftw/pyfftw for this test")
+def test_fftw_wisdom():
+ """
+ Test FFTW wisdom import/export mechanism
+ """
+
+    assert path.isdir(path.dirname(get_wisdom_file()))  # Default: tempfile.gettempdir()
+
+ with TemporaryDirectory(prefix="test_fftw_wisdom") as dname:
+ subdir = path.join(dname, "subdir")
+ get_wisdom_file(directory=subdir, create_dirs=False)
+ assert not (path.isdir(subdir))
+ fname = get_wisdom_file(directory=subdir, create_dirs=True)
+ assert path.isdir(subdir)
+ export_wisdom(fname)
+ assert path.isfile(fname)
+ import_wisdom(fname)