summaryrefslogtreecommitdiff
path: root/tests/test_lineshapes.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_lineshapes.py')
-rw-r--r--tests/test_lineshapes.py124
1 files changed, 124 insertions, 0 deletions
diff --git a/tests/test_lineshapes.py b/tests/test_lineshapes.py
new file mode 100644
index 0000000..8c440f4
--- /dev/null
+++ b/tests/test_lineshapes.py
@@ -0,0 +1,124 @@
+"""Tests for lineshape functions."""
+
+import inspect
+
+import numpy as np
+import pytest
+
+import lmfit
+
+
+@pytest.mark.parametrize("lineshape", lmfit.lineshapes.functions)
+def test_no_ZeroDivisionError_and_finite_output(lineshape):
+ """Tests for finite output and ZeroDivisionError is not raised."""
+ xvals = np.linspace(0, 10, 100)
+
+ func = getattr(lmfit.lineshapes, lineshape)
+ assert callable(func)
+ sig = inspect.signature(func)
+
+ # set the following function arguments:
+ # x = xvals
+ # center = 0.5*(max(xvals)-min(xvals))
+ # center1 = 0.25*(max(xvals)-min(xvals))
+ # center2 = 0.75*(max(xvals)-min(xvals))
+ # form = default value (i.e., 'linear' or 'bose')
+ xvals_mid_range = xvals.mean()
+ zero_pars = [par_name for par_name in sig.parameters.keys() if par_name
+ not in ('x', 'form')]
+
+ for par_zero in zero_pars:
+ fnc_args = []
+ for par in sig.parameters.keys():
+ if par == 'x':
+ fnc_args.append(xvals)
+ elif par == 'center':
+ fnc_args.append(0.5*xvals_mid_range)
+ elif par == 'center1':
+ fnc_args.append(0.25*xvals_mid_range)
+ elif par == 'center2':
+ fnc_args.append(0.75*xvals_mid_range)
+ elif par == par_zero:
+ fnc_args.append(0.0)
+ else:
+ fnc_args.append(sig.parameters[par].default)
+
+ fnc_output = func(*fnc_args)
+ assert len(xvals) == len(fnc_output)
+ assert np.all(np.isfinite(fnc_output))
+
+
+@pytest.mark.parametrize("lineshape", lmfit.lineshapes.functions)
+def test_x_float_value(lineshape):
+ """Test lineshapes when x is not an array but a float."""
+ xval = 7.0
+
+ func = getattr(lmfit.lineshapes, lineshape)
+ sig = inspect.signature(func)
+
+ fnc_args = [xval]
+
+ for par in [par_name for par_name in sig.parameters.keys()
+ if par_name != 'x']:
+ fnc_args.append(sig.parameters[par].default)
+
+ if lineshape in ('step', 'rectangle'):
+ msg = r"'float' object does not support item assignment"
+ with pytest.raises(TypeError, match=msg):
+ fnc_output = func(*fnc_args)
+ else:
+ fnc_output = func(*fnc_args)
+ assert isinstance(fnc_output, float)
+
+
+rising_form = ['erf', 'logistic', 'atan', 'arctan', 'linear', 'unknown']
+@pytest.mark.parametrize("form", rising_form)
+@pytest.mark.parametrize("lineshape", ['step', 'rectangle'])
+def test_form_argument_step_rectangle(form, lineshape):
+ """Test 'form' argument for step- and rectangle-functions."""
+ xvals = np.linspace(0, 10, 100)
+
+ func = getattr(lmfit.lineshapes, lineshape)
+ sig = inspect.signature(func)
+
+ fnc_args = [xvals]
+ for par in [par_name for par_name in sig.parameters.keys()
+ if par_name != 'x']:
+ if par == 'form':
+ fnc_args.append(form)
+ else:
+ fnc_args.append(sig.parameters[par].default)
+
+ if form == 'unknown':
+ msg = r"Invalid value .* for argument .*; should be one of .*"
+ with pytest.raises(ValueError, match=msg):
+ func(*fnc_args)
+ else:
+ fnc_output = func(*fnc_args)
+ assert len(fnc_output) == len(xvals)
+
+
+thermal_form = ['bose', 'maxwell', 'fermi', 'Bose-Einstein', 'unknown']
+@pytest.mark.parametrize("form", thermal_form)
+def test_form_argument_thermal_distribution(form):
+ """Test 'form' argument for thermal_distribution function."""
+ xvals = np.linspace(0, 10, 100)
+
+ func = lmfit.lineshapes.thermal_distribution
+ sig = inspect.signature(lmfit.lineshapes.thermal_distribution)
+
+ fnc_args = [xvals]
+ for par in [par_name for par_name in sig.parameters.keys()
+ if par_name != 'x']:
+ if par == 'form':
+ fnc_args.append(form)
+ else:
+ fnc_args.append(sig.parameters[par].default)
+
+ if form == 'unknown':
+ msg = r"Invalid value .* for argument .*; should be one of .*"
+ with pytest.raises(ValueError, match=msg):
+ func(*fnc_args)
+ else:
+ fnc_output = func(*fnc_args)
+ assert len(fnc_output) == len(xvals)