summaryrefslogtreecommitdiff
path: root/src/silx/math/fft/basefft.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/silx/math/fft/basefft.py')
-rw-r--r--src/silx/math/fft/basefft.py146
1 files changed, 146 insertions, 0 deletions
diff --git a/src/silx/math/fft/basefft.py b/src/silx/math/fft/basefft.py
new file mode 100644
index 0000000..854ca37
--- /dev/null
+++ b/src/silx/math/fft/basefft.py
@@ -0,0 +1,146 @@
+#!/usr/bin/env python
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+import numpy as np
+from pkg_resources import parse_version
+
+
+def check_version(package, required_version):
+ """
+ Check whether a given package version is superior or equal to required_version.
+ """
+ try:
+ ver = getattr(package, "__version__")
+ except AttributeError:
+ try:
+ ver = getattr(package, "version")
+ except Exception:
+ return False
+ req_v = parse_version(required_version)
+ ver_v = parse_version(ver)
+ return ver_v >= req_v
+
+
+class BaseFFT(object):
+ """
+ Base class for all FFT backends.
+ """
+ def __init__(self, **kwargs):
+ self.__get_args(**kwargs)
+
+ if self.shape is None and self.dtype is None and self.template is None:
+ raise ValueError("Please provide either (shape and dtype) or template")
+ if self.template is not None:
+ self.shape = self.template.shape
+ self.dtype = self.template.dtype
+ self.user_data = self.template
+ self.data_allocated = False
+ self.__calc_axes()
+ self.__set_dtypes()
+ self.__calc_shape()
+
+ def __get_args(self, **kwargs):
+ expected_args = {
+ "shape": None,
+ "dtype": None,
+ "template": None,
+ "shape_out": None,
+ "axes": None,
+ "normalize": "rescale",
+ }
+ for arg_name, default_val in expected_args.items():
+ if arg_name not in kwargs:
+ # Base class was not instantiated properly
+ raise ValueError("Please provide argument %s" % arg_name)
+ setattr(self, arg_name, default_val)
+ for arg_name, arg_val in kwargs.items():
+ setattr(self, arg_name, arg_val)
+
+ def __set_dtypes(self):
+ dtypes_mapping = {
+ np.dtype("float32"): np.complex64,
+ np.dtype("float64"): np.complex128,
+ np.dtype("complex64"): np.complex64,
+ np.dtype("complex128"): np.complex128
+ }
+ dp = {
+ np.dtype("float32"): np.float64,
+ np.dtype("complex64"): np.complex128
+ }
+ self.dtype_in = np.dtype(self.dtype)
+ if self.dtype_in not in dtypes_mapping:
+ raise ValueError("Invalid input data type: got %s" %
+ self.dtype_in
+ )
+ self.dtype_out = dtypes_mapping[self.dtype_in]
+
+ def __calc_shape(self):
+ # TODO allow for C2C even for real input data (?)
+ if self.dtype_in in [np.float32, np.float64]:
+ last_dim = self.shape[-1]//2 + 1
+ # FFTW convention
+ self.shape_out = self.shape[:-1] + (self.shape[-1]//2 + 1,)
+ else:
+ self.shape_out = self.shape
+
+ def __calc_axes(self):
+ default_axes = tuple(range(len(self.shape)))
+ if self.axes is None:
+ self.axes = default_axes
+ self.user_axes = None
+ else:
+ self.user_axes = self.axes
+ # Handle possibly negative axes
+ self.axes = tuple(np.array(default_axes)[np.array(self.user_axes)])
+
+ def _allocate(self, shape, dtype):
+ raise ValueError("This should be implemented by back-end FFT")
+
+ def set_data(self, dst, src, shape, dtype, copy=True):
+ raise ValueError("This should be implemented by back-end FFT")
+
+ def allocate_arrays(self):
+ if not(self.data_allocated):
+ self.data_in = self._allocate(self.shape, self.dtype_in)
+ self.data_out = self._allocate(self.shape_out, self.dtype_out)
+ self.data_allocated = True
+
+ def set_input_data(self, data, copy=True):
+ if data is None:
+ return self.data_in
+ else:
+ return self.set_data(self.data_in, data, self.shape, self.dtype_in, copy=copy, name="data_in")
+
+ def set_output_data(self, data, copy=True):
+ if data is None:
+ return self.data_out
+ else:
+ return self.set_data(self.data_out, data, self.shape_out, self.dtype_out, copy=copy, name="data_out")
+
+ def fft(self, array, **kwargs):
+ raise ValueError("This should be implemented by back-end FFT")
+
+ def ifft(self, array, **kwargs):
+ raise ValueError("This should be implemented by back-end FFT")