Created
July 2, 2021 14:48
-
-
Save stephanmg/784b6df00f906a56dd08f6ee44cb9a1c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Amici.py
def check_gradients(self, x: np.ndarray = None, *args, **kwargs):
    """Check gradients against finite differences, using PEtab defaults.

    When ``x`` is not given and ``self.amici_object_builder`` has a
    non-None ``petab_problem``, ``petab_problem.x_nominal_scaled`` is used
    as ``x`` and, unless the caller passed ``x_free`` explicitly,
    ``petab_problem.x_free_indices`` is used as ``x_free``. Everything is
    then forwarded to ``ObjectiveBase.check_gradients``.

    Parameters
    ----------
    x:
        Parameter vector at which to check the gradient. If None, the PEtab
        defaults described above are used when available.
    *args, **kwargs:
        Forwarded to ``ObjectiveBase.check_gradients`` (e.g. ``x_free``,
        ``rtol``, ``atol``, ``mode``, ``multi_eps``).

    Returns
    -------
    bool
        Whatever ``ObjectiveBase.check_gradients`` returns.
    """
    # Take ``x_free`` out of kwargs. Previously it was only bound inside the
    # PEtab branch (UnboundLocalError otherwise). A caller-supplied value also
    # collided with the explicit ``x_free=`` keyword below (TypeError).
    x_free = kwargs.pop('x_free', None)
    petab_problem = getattr(self.amici_object_builder, 'petab_problem', None)
    if x is None and petab_problem is not None:
        x = petab_problem.x_nominal_scaled
        if x_free is None:
            x_free = petab_problem.x_free_indices
    return ObjectiveBase.check_gradients(
        self, *args, x=x, x_free=x_free, **kwargs)
# Base.py
def check_gradients(
    self,
    *args,
    x: np.ndarray = None,
    x_free: Sequence[int] = None,
    rtol: float = 1e-2,
    atol: float = 1e-3,
    mode: str = None,
    multi_eps=None,
    **kwargs,
) -> bool:
    """Check if gradients match finite differences (FDs).

    Parameters
    ----------
    *args:
        Extra positional arguments forwarded to ``check_grad_multi_eps``
        after the parameter vector.
    x:
        Parameter vector at which the gradient is checked. Required.
    x_free:
        Indices of the parameter entries to check, forwarded as
        ``x_indices``. ``None`` forwards nothing, i.e. no index selection.
    rtol:
        Relative error tolerance.
    atol:
        Absolute error tolerance.
    mode:
        Objective mode to check (function values or residuals). ``None``
        checks both ``MODE_FUN`` and ``MODE_RES``.
    multi_eps:
        FD step widths to try; defaults to ``1e-3, 1e-4, ..., 1e-8``.
    **kwargs:
        Forwarded to ``check_grad_multi_eps``.

    Returns
    -------
    bool
        True if, for every checked mode, every row of the returned table
        satisfies ``rel_err < rtol`` or ``abs_err < atol``. False otherwise,
        including when ``check_grad_multi_eps`` raises ``RuntimeError`` or
        ``ValueError``.

    Raises
    ------
    ValueError
        If ``x`` is None.
    """
    if x is None:
        # np.asarray(None) would silently produce a 0-d object array.
        raise ValueError(
            'A parameter vector `x` is required to check gradients.')
    par = np.asarray(x)

    # Pass the full vector and select the free entries via ``x_indices``
    # instead of slicing: ``par[x_free]`` changed the point the objective is
    # evaluated at, and ``par[None]`` added an axis when x_free was None.
    # NOTE(review): assumes check_grad_multi_eps forwards ``x_indices`` to
    # ``check_grad`` as pyPESTO's ObjectiveBase does -- confirm.
    grad_kwargs = dict(kwargs)
    if x_free is not None:
        grad_kwargs['x_indices'] = x_free

    modes = [MODE_FUN, MODE_RES] if mode is None else [mode]
    if multi_eps is None:
        multi_eps = np.array([10**(-i) for i in range(3, 9)])

    dfs = []
    for check_mode in modes:
        try:
            dfs.append(self.check_grad_multi_eps(
                par, *args, **grad_kwargs,
                mode=check_mode, multi_eps=multi_eps))
        except (RuntimeError, ValueError):
            # Might happen in case the PEtab problem is not well defined or
            # forward sensitivities fail for the specified tolerances.
            return False

    # Each entry passes if it meets either the relative or absolute bound.
    return all(
        bool(np.all((df.rel_err.values < rtol) | (df.abs_err.values < atol)))
        for df in dfs
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.