summaryrefslogtreecommitdiff
path: root/silx/math/fft/npfft.py
diff options
context:
space:
mode:
Diffstat (limited to 'silx/math/fft/npfft.py')
-rw-r--r--silx/math/fft/npfft.py124
1 files changed, 124 insertions, 0 deletions
diff --git a/silx/math/fft/npfft.py b/silx/math/fft/npfft.py
new file mode 100644
index 0000000..20351de
--- /dev/null
+++ b/silx/math/fft/npfft.py
@@ -0,0 +1,124 @@
+#!/usr/bin/env python
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+import numpy as np
+
+from .basefft import BaseFFT
+
+
+class NPFFT(BaseFFT):
+ """Initialize a numpy plan.
+
+ Please see FFT class for parameters help.
+ """
+ def __init__(
+ self,
+ shape=None,
+ dtype=None,
+ template=None,
+ shape_out=None,
+ axes=None,
+ normalize="rescale",
+ ):
+ super(NPFFT, self).__init__(
+ shape=shape,
+ dtype=dtype,
+ template=template,
+ shape_out=shape_out,
+ axes=axes,
+ normalize=normalize,
+ )
+ self.backend = "numpy"
+ self.real_transform = False
+ if template is not None and np.isrealobj(template):
+ self.real_transform = True
+ # For numpy functions.
+ # TODO Issue warning if user wants ifft(fft(data)) = N*data ?
+ if normalize != "ortho":
+ self.normalize = None
+ self.set_fft_functions()
+ #~ self.allocate_arrays() # not needed for this backend
+ self.compute_plans()
+
+
+ def set_fft_functions(self):
+ # (fwd, inv) = _fft_functions[is_real][ndim]
+ self._fft_functions = {
+ True: {
+ 1: (np.fft.rfft, np.fft.irfft),
+ 2: (np.fft.rfft2, np.fft.irfft2),
+ 3: (np.fft.rfftn, np.fft.irfftn),
+ },
+ False: {
+ 1: (np.fft.fft, np.fft.ifft),
+ 2: (np.fft.fft2, np.fft.ifft2),
+ 3: (np.fft.fftn, np.fft.ifftn),
+ }
+ }
+
+
+ def _allocate(self, shape, dtype):
+ return np.zeros(self.queue, shape, dtype=dtype)
+
+
+ def compute_plans(self):
+ ndim = len(self.shape)
+ funcs = self._fft_functions[self.real_transform][np.minimum(ndim, 3)]
+ if np.version.version[:4] in ["1.8.", "1.9."]:
+ # norm keyword was introduced in 1.10 and we support numpy >= 1.8
+ self.numpy_args = {}
+ else:
+ self.numpy_args = {"norm": self.normalize}
+ # Batched transform
+ if (self.user_axes is not None) and len(self.user_axes) < ndim:
+ funcs = self._fft_functions[self.real_transform][np.minimum(ndim-1, 3)]
+ self.numpy_args["axes"] = self.user_axes
+ # Special case of batched 1D transform on 2D data
+ if ndim == 2:
+ assert len(self.user_axes) == 1
+ self.numpy_args["axis"] = self.user_axes[0]
+ self.numpy_args.pop("axes")
+ self.numpy_funcs = funcs
+
+
+ def fft(self, array):
+ """
+ Perform a (forward) Fast Fourier Transform.
+
+ :param numpy.ndarray array:
+ Input data. Must be consistent with the current context.
+ """
+ return self.numpy_funcs[0](array, **self.numpy_args)
+
+
+ def ifft(self, array):
+ """
+ Perform a (inverse) Fast Fourier Transform.
+
+ :param numpy.ndarray array:
+ Input data. Must be consistent with the current context.
+ """
+ return self.numpy_funcs[1](array, **self.numpy_args)
+