summaryrefslogtreecommitdiff
path: root/lmfit/confidence.py
diff options
context:
space:
mode:
Diffstat (limited to 'lmfit/confidence.py')
-rw-r--r--lmfit/confidence.py70
1 files changed, 33 insertions, 37 deletions
diff --git a/lmfit/confidence.py b/lmfit/confidence.py
index 684d53d..6914a58 100644
--- a/lmfit/confidence.py
+++ b/lmfit/confidence.py
@@ -1,18 +1,17 @@
"""Contains functions to calculate confidence intervals."""
-from collections import OrderedDict
from warnings import warn
import numpy as np
-from scipy.optimize import brentq
+from scipy.optimize import root_scalar
from scipy.special import erf
from scipy.stats import f
from .minimizer import MinimizerException
CONF_ERR_GEN = 'Cannot determine Confidence Intervals'
-CONF_ERR_STDERR = '%s without sensible uncertainty estimates' % CONF_ERR_GEN
-CONF_ERR_NVARS = '%s with < 2 variables' % CONF_ERR_GEN
+CONF_ERR_STDERR = f'{CONF_ERR_GEN} without sensible uncertainty estimates'
+CONF_ERR_NVARS = f'{CONF_ERR_GEN} with < 2 variables'
def f_compare(best_fit, new_fit):
@@ -31,7 +30,7 @@ def f_compare(best_fit, new_fit):
Returns
-------
float
- Value of the calculated probality.
+ Value of the calculated probability.
"""
nfree = best_fit.nfree
@@ -213,7 +212,7 @@ class ConfidenceInterval:
def calc_all_ci(self):
"""Calculate all confidence intervals."""
- out = OrderedDict()
+ out = {}
for p in self.p_names:
out[p] = (self.calc_ci(p, -1)[::-1] +
@@ -234,14 +233,20 @@ class ConfidenceInterval:
para = self.params[para]
# function used to calculate the probability
- calc_prob = lambda val, prob: self.calc_prob(para, val, prob)
+ cache = {}
+
+ def calc_prob(val, target_prob):
+ if val not in cache:
+ cache[val] = self.calc_prob(para, val, 0)
+ return cache[val] - target_prob
+
if self.trace:
x = [i.value for i in self.params.values()]
self.trace_dict[para.name].append(x + [0])
para.vary = False
limit, max_prob = self.find_limit(para, direction)
- start_val = a_limit = float(para.value)
+ a_limit = float(para.value)
ret = []
orig_warn_settings = np.geterr()
np.seterr(all='ignore')
@@ -250,17 +255,12 @@ class ConfidenceInterval:
ret.append((prob, direction*np.inf))
continue
- try:
- val = brentq(calc_prob, a_limit,
- limit, rtol=.5e-4, args=prob)
- except ValueError:
- self.reset_vals()
- try:
- val = brentq(calc_prob, start_val,
- limit, rtol=.5e-4, args=prob)
- except ValueError:
- val = np.nan
-
+ sol = root_scalar(calc_prob, method='toms748', bracket=sorted([limit, a_limit]), rtol=.5e-4, args=(prob,))
+ if sol.converged:
+ val = sol.root
+ else:
+ val = np.nan
+ break
a_limit = val
ret.append((prob, val))
@@ -276,7 +276,7 @@ class ConfidenceInterval:
def find_limit(self, para, direction):
"""Find a value for given parameter so that prob(val) > sigmas."""
if self.verbose:
- print('Calculating CI for ' + para.name)
+ print(f'Calculating CI for {para.name}')
self.reset_vals()
# determine starting step
@@ -294,7 +294,7 @@ class ConfidenceInterval:
max_prob = max(self.probs)
while old_prob < max_prob:
- i = i + 1
+ i += 1
limit += step * direction
if limit > para.max:
limit = para.max
@@ -307,30 +307,26 @@ class ConfidenceInterval:
rel_change = (new_prob - old_prob) / max(new_prob, old_prob, 1e-12)
old_prob = new_prob
if self.verbose:
- msg = "P({}={}) = {}, max. prob={}"
- print(msg.format(para.name, limit, new_prob, max_prob))
+ print(f'P({para.name}={limit}) = {new_prob}, '
+ f'max. prob={max_prob}')
# check for convergence
- if bound_reached:
- if new_prob < max(self.probs):
- errmsg = ("Bound reached with "
- "prob({}={}) = {} < max(sigmas)"
- ).format(para.name, limit, new_prob)
- warn(errmsg)
- break
+ if bound_reached and new_prob < max(self.probs):
+ errmsg = (f'Bound reached with prob({para.name}={limit}) '
+ f'= {new_prob} < max(sigmas)')
+ warn(errmsg)
+ break
if i > self.maxiter:
- errmsg = f"maxiter={self.maxiter} reached "
- errmsg += ("and prob({}={}) = {} < "
- "max(sigmas).".format(para.name, limit, new_prob))
+ errmsg = (f'maxiter={self.maxiter} reached and prob('
+ f'{para.name}={limit}) = {new_prob} < max(sigmas)')
warn(errmsg)
break
if rel_change < self.min_rel_change:
- errmsg = "rel_change={} < {} ".format(rel_change,
- self.min_rel_change)
- errmsg += ("at iteration {} and prob({}={}) = {} < max"
- "(sigmas).".format(i, para.name, limit, new_prob))
+ errmsg = (f'rel_change={rel_change} < {self.min_rel_change} '
+ f'at iteration {i} and prob({para.name}={limit}) = '
+ f'{new_prob} < max(sigmas)')
warn(errmsg)
break