Skip to content

Instantly share code, notes, and snippets.

@mducle
Created November 14, 2024 18:33
Show Gist options
  • Save mducle/becf9f475489a9b2baadf661eb12a92c to your computer and use it in GitHub Desktop.
Save mducle/becf9f475489a9b2baadf661eb12a92c to your computer and use it in GitHub Desktop.
diff --git a/+sw_tests/+unit_tests/unittest_ndbase_optimisers.m b/+sw_tests/+unit_tests/unittest_ndbase_optimisers.m
index 5f6d9f5..9ec3ffe 100644
--- a/+sw_tests/+unit_tests/unittest_ndbase_optimisers.m
+++ b/+sw_tests/+unit_tests/unittest_ndbase_optimisers.m
@@ -59,25 +59,25 @@ classdef unittest_ndbase_optimisers < sw_tests.unit_tests.unittest_super
function test_optimise_rosen_upper_bound_minimum_not_accessible(testCase, optimiser)
[pars_fit, cost_val, ~] = optimiser([], testCase.rosenbrock, [-1,-1], 'ub', [0, inf]);
testCase.verify_val(pars_fit, [0, 0], 'abs_tol', 1e-3);
- testCase.verify_val(cost_val, 1, 'abs_tol', 2e-3);
+ testCase.verify_val(cost_val, 1, 'abs_tol', 1e-4);
end
function test_optimise_rosen_both_bounds_minimum_accessible(testCase, optimiser)
- [pars_fit, cost_val, ~] = optimiser([], testCase.rosenbrock, [-1,-1], 'lb', [-5, -5], 'ub', [5, 5]);
+ [pars_fit, cost_val, ~] = optimiser([], testCase.rosenbrock, [-1,-1], 'lb', [-2, -2], 'ub', [2, 2]);
testCase.verify_val(pars_fit, testCase.rosenbrock_minimum, 'abs_tol', 1e-3);
testCase.verify_val(cost_val, 0, 'abs_tol', 1e-6);
end
- function test_optimise_rosen_both_bounds_minimum_not_accessible(testCase)
+ function test_optimise_rosen_both_bounds_minimum_not_accessible(testCase, optimiser)
% note intital guess is outside bounds
- [pars_fit, cost_val, ~] = ndbase.simplex([], testCase.rosenbrock, [-1,-1], 'lb', [-0.5, -0.5], 'ub', [0, 0]);
+ [pars_fit, cost_val, ~] = optimiser([], testCase.rosenbrock, [-1,-1], 'lb', [-0.5, -0.5], 'ub', [0, 0]);
testCase.verify_val(pars_fit, [0, 0], 'abs_tol', 1e-3);
testCase.verify_val(cost_val, 1, 'abs_tol', 1e-6);
end
- function test_optimise_rosen_parameter_fixed_minimum_not_accessible(testCase)
+ function test_optimise_rosen_parameter_fixed_minimum_not_accessible(testCase, optimiser)
% note intital guess is outside bounds
- [pars_fit, cost_val, ~] = ndbase.simplex([], testCase.rosenbrock, [-1,-1], 'lb', [0, -0.5], 'ub', [0, 0]);
+ [pars_fit, cost_val, ~] = optimiser([], testCase.rosenbrock, [-1,-1], 'lb', [0, -0.5], 'ub', [0, 0]);
testCase.verify_val(pars_fit, [0, 0], 'abs_tol', 1e-3);
testCase.verify_val(cost_val, 1, 'abs_tol', 1e-6);
end
diff --git a/swfiles/+ndbase/cost_function_wrapper.m b/swfiles/+ndbase/cost_function_wrapper.m
index a94d391..eda44c0 100644
--- a/swfiles/+ndbase/cost_function_wrapper.m
+++ b/swfiles/+ndbase/cost_function_wrapper.m
@@ -72,6 +72,11 @@ classdef cost_function_wrapper < handle & matlab.mixin.SetGet
fix_tol = 1e-10
end
+ properties (Access = private)
+ ub
+ lb
+ end
+
methods
function obj = cost_function_wrapper(fhandle, params, options)
arguments
@@ -110,13 +115,21 @@ classdef cost_function_wrapper < handle & matlab.mixin.SetGet
% validate size of bounds
lb = options.lb;
ub = options.ub;
- if ~isempty(lb) && numel(lb) ~= numel(params)
- error("ndbase:cost_function_wrapper:WrongInput", ...
- "Lower bounds must be empty or have same size as parameter vector.");
+ obj.lb = -inf * ones(size(params));
+ obj.ub = inf * ones(size(params));
+ if ~isempty(lb)
+ if numel(lb) ~= numel(params)
+ error("ndbase:cost_function_wrapper:WrongInput", ...
+ "Lower bounds must be empty or have same size as parameter vector.");
+ end
+ obj.lb = lb;
end
- if ~isempty(ub) && numel(ub) ~= numel(params)
- error("ndbase:cost_function_wrapper:WrongInput", ...
- "Upper bounds must be empty or have same size as parameter vector.");
+ if ~isempty(ub)
+ if numel(ub) ~= numel(params)
+ error("ndbase:cost_function_wrapper:WrongInput", ...
+ "Upper bounds must be empty or have same size as parameter vector.");
+ end
+ obj.ub = ub;
end
if ~isempty(lb) && ~isempty(ub) && any(ub<lb)
error("ndbase:cost_function_wrapper:WrongInput", ...
@@ -191,6 +204,19 @@ classdef cost_function_wrapper < handle & matlab.mixin.SetGet
if isempty(obj.bound_to_free_funcs{ipar})
pars(ipar) = pars_bound(ipar);
else
+ if pars_bound(ipar) <= obj.lb(ipar)
+ if isfinite(obj.ub(ipar))
+ pars_bound(ipar) = (obj.ub(ipar) + obj.lb(ipar)) / 2;
+ else
+ pars_bound(ipar) = pars_bound(ipar) + max(abs(obj.lb(ipar)), 1) / 2;
+ end
+ elseif pars_bound(ipar) >= obj.ub(ipar)
+ if isfinite(obj.lb(ipar))
+ pars_bound(ipar) = (obj.ub(ipar) + obj.lb(ipar)) / 2;
+ else
+ pars_bound(ipar) = pars_bound(ipar) - max(abs(obj.ub(ipar)), 1) / 2;
+ end
+ end
pars(ipar) = obj.bound_to_free_funcs{ipar}(pars_bound(ipar));
end
end
diff --git a/swfiles/+ndbase/lm4.m b/swfiles/+ndbase/lm4.m
index 5e55776..d321daa 100644
--- a/swfiles/+ndbase/lm4.m
+++ b/swfiles/+ndbase/lm4.m
@@ -98,7 +98,7 @@ inpForm.defval = [inpForm.defval {1e-8 1e-8 1e-8 1e-2}];
inpForm.size = [inpForm.size {[1 1] [1 1] [1 1] [1 1]}];
inpForm.fname = [inpForm.fname {'nu_up', 'nu_dn', 'resid_handle'}];
-inpForm.defval = [inpForm.defval {10 0.3, false}];
+inpForm.defval = [inpForm.defval {5 0.3, false}];
inpForm.size = [inpForm.size {[1 1] [1 1], [1 1]}];
param = sw_readparam(inpForm, varargin{:});
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment