diff options
Diffstat (limited to 'src/silx/math/fft/basefft.py')
-rw-r--r-- | src/silx/math/fft/basefft.py | 37 |
1 files changed, 21 insertions, 16 deletions
diff --git a/src/silx/math/fft/basefft.py b/src/silx/math/fft/basefft.py index c608fde..6e9fac8 100644 --- a/src/silx/math/fft/basefft.py +++ b/src/silx/math/fft/basefft.py @@ -23,7 +23,7 @@ # # ###########################################################################*/ import numpy as np -from pkg_resources import parse_version +from packaging.version import Version def check_version(package, required_version): @@ -37,8 +37,8 @@ def check_version(package, required_version): ver = getattr(package, "version") except Exception: return False - req_v = parse_version(required_version) - ver_v = parse_version(ver) + req_v = Version(required_version) + ver_v = Version(ver) return ver_v >= req_v @@ -46,6 +46,7 @@ class BaseFFT(object): """ Base class for all FFT backends. """ + def __init__(self, **kwargs): self.__get_args(**kwargs) @@ -82,25 +83,20 @@ class BaseFFT(object): np.dtype("float32"): np.complex64, np.dtype("float64"): np.complex128, np.dtype("complex64"): np.complex64, - np.dtype("complex128"): np.complex128 - } - dp = { - np.dtype("float32"): np.float64, - np.dtype("complex64"): np.complex128 + np.dtype("complex128"): np.complex128, } + dp = {np.dtype("float32"): np.float64, np.dtype("complex64"): np.complex128} self.dtype_in = np.dtype(self.dtype) if self.dtype_in not in dtypes_mapping: - raise ValueError("Invalid input data type: got %s" % - self.dtype_in - ) + raise ValueError("Invalid input data type: got %s" % self.dtype_in) self.dtype_out = dtypes_mapping[self.dtype_in] def __calc_shape(self): # TODO allow for C2C even for real input data (?) if self.dtype_in in [np.float32, np.float64]: - last_dim = self.shape[-1]//2 + 1 + last_dim = self.shape[-1] // 2 + 1 # FFTW convention - self.shape_out = self.shape[:-1] + (self.shape[-1]//2 + 1,) + self.shape_out = self.shape[:-1] + (self.shape[-1] // 2 + 1,) else: self.shape_out = self.shape @@ -121,7 +117,7 @@ class BaseFFT(object): raise ValueError("This should be implemented by back-end FFT") def allocate_arrays(self): - if not(self.data_allocated): + if not (self.data_allocated): self.data_in = self._allocate(self.shape, self.dtype_in) self.data_out = self._allocate(self.shape_out, self.dtype_out) self.data_allocated = True @@ -130,13 +126,22 @@ class BaseFFT(object): if data is None: return self.data_in else: - return self.set_data(self.data_in, data, self.shape, self.dtype_in, copy=copy, name="data_in") + return self.set_data( + self.data_in, data, self.shape, self.dtype_in, copy=copy, name="data_in" + ) def set_output_data(self, data, copy=True): if data is None: return self.data_out else: - return self.set_data(self.data_out, data, self.shape_out, self.dtype_out, copy=copy, name="data_out") + return self.set_data( + self.data_out, + data, + self.shape_out, + self.dtype_out, + copy=copy, + name="data_out", + ) def fft(self, array, **kwargs): raise ValueError("This should be implemented by back-end FFT") |