Created November 14, 2024 18:33
Save mducle/becf9f475489a9b2baadf661eb12a92c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/+sw_tests/+unit_tests/unittest_ndbase_optimisers.m b/+sw_tests/+unit_tests/unittest_ndbase_optimisers.m | |
index 5f6d9f5..9ec3ffe 100644 | |
--- a/+sw_tests/+unit_tests/unittest_ndbase_optimisers.m | |
+++ b/+sw_tests/+unit_tests/unittest_ndbase_optimisers.m | |
@@ -59,25 +59,25 @@ classdef unittest_ndbase_optimisers < sw_tests.unit_tests.unittest_super | |
        function test_optimise_rosen_upper_bound_minimum_not_accessible(testCase, optimiser) | |
            [pars_fit, cost_val, ~] = optimiser([], testCase.rosenbrock, [-1,-1], 'ub', [0, inf]); | |
            testCase.verify_val(pars_fit, [0, 0], 'abs_tol', 1e-3); | |
-            testCase.verify_val(cost_val, 1, 'abs_tol', 2e-3); | |
+            testCase.verify_val(cost_val, 1, 'abs_tol', 1e-4); % expected optimum [0, 0] has cost 1 | |
        end | |
        function test_optimise_rosen_both_bounds_minimum_accessible(testCase, optimiser) | |
-            [pars_fit, cost_val, ~] = optimiser([], testCase.rosenbrock, [-1,-1], 'lb', [-5, -5], 'ub', [5, 5]); | |
+            [pars_fit, cost_val, ~] = optimiser([], testCase.rosenbrock, [-1,-1], 'lb', [-2, -2], 'ub', [2, 2]); % box contains testCase.rosenbrock_minimum | |
            testCase.verify_val(pars_fit, testCase.rosenbrock_minimum, 'abs_tol', 1e-3); | |
            testCase.verify_val(cost_val, 0, 'abs_tol', 1e-6); | |
        end | |
-        function test_optimise_rosen_both_bounds_minimum_not_accessible(testCase) | |
-            % note intital guess is outside bounds | |
-            [pars_fit, cost_val, ~] = ndbase.simplex([], testCase.rosenbrock, [-1,-1], 'lb', [-0.5, -0.5], 'ub', [0, 0]); | |
+        function test_optimise_rosen_both_bounds_minimum_not_accessible(testCase, optimiser) | |
+            % note initial guess is outside bounds | |
+            [pars_fit, cost_val, ~] = optimiser([], testCase.rosenbrock, [-1,-1], 'lb', [-0.5, -0.5], 'ub', [0, 0]); % expected optimum is the [0, 0] corner of the box | |
            testCase.verify_val(pars_fit, [0, 0], 'abs_tol', 1e-3); | |
            testCase.verify_val(cost_val, 1, 'abs_tol', 1e-6); | |
        end | |
-        function test_optimise_rosen_parameter_fixed_minimum_not_accessible(testCase) | |
-            % note intital guess is outside bounds | |
-            [pars_fit, cost_val, ~] = ndbase.simplex([], testCase.rosenbrock, [-1,-1], 'lb', [0, -0.5], 'ub', [0, 0]); | |
+        function test_optimise_rosen_parameter_fixed_minimum_not_accessible(testCase, optimiser) | |
+            % note initial guess is outside bounds | |
+            [pars_fit, cost_val, ~] = optimiser([], testCase.rosenbrock, [-1,-1], 'lb', [0, -0.5], 'ub', [0, 0]); % lb == ub for the first parameter, fixing it at 0 | |
            testCase.verify_val(pars_fit, [0, 0], 'abs_tol', 1e-3); | |
            testCase.verify_val(cost_val, 1, 'abs_tol', 1e-6); | |
        end | |
diff --git a/swfiles/+ndbase/cost_function_wrapper.m b/swfiles/+ndbase/cost_function_wrapper.m | |
index a94d391..eda44c0 100644 | |
--- a/swfiles/+ndbase/cost_function_wrapper.m | |
+++ b/swfiles/+ndbase/cost_function_wrapper.m | |
@@ -72,6 +72,11 @@ classdef cost_function_wrapper < handle & matlab.mixin.SetGet | |
        fix_tol = 1e-10 | |
    end | |
+    properties (Access = private) | |
+        ub  % per-parameter upper bounds (+inf where unbounded) | |
+        lb  % per-parameter lower bounds (-inf where unbounded) | |
+    end | |
+ | |
    methods | |
        function obj = cost_function_wrapper(fhandle, params, options) | |
            arguments | |
@@ -110,13 +115,21 @@ classdef cost_function_wrapper < handle & matlab.mixin.SetGet | |
            % validate size of bounds | |
            lb = options.lb; | |
            ub = options.ub; | |
-            if ~isempty(lb) && numel(lb) ~= numel(params) | |
-                error("ndbase:cost_function_wrapper:WrongInput", ... | |
-                    "Lower bounds must be empty or have same size as parameter vector."); | |
+            obj.lb = -inf(size(params));  % default: no lower bound on any parameter | |
+            obj.ub = inf(size(params));   % default: no upper bound on any parameter | |
+            if ~isempty(lb) | |
+                if numel(lb) ~= numel(params) | |
+                    error("ndbase:cost_function_wrapper:WrongInput", ... | |
+                        "Lower bounds must be empty or have same size as parameter vector."); | |
+                end | |
+                obj.lb = lb; | |
            end | |
-            if ~isempty(ub) && numel(ub) ~= numel(params) | |
-                error("ndbase:cost_function_wrapper:WrongInput", ... | |
-                    "Upper bounds must be empty or have same size as parameter vector."); | |
+            if ~isempty(ub) | |
+                if numel(ub) ~= numel(params) | |
+                    error("ndbase:cost_function_wrapper:WrongInput", ... | |
+                        "Upper bounds must be empty or have same size as parameter vector."); | |
+                end | |
+                obj.ub = ub; | |
            end | |
            if ~isempty(lb) && ~isempty(ub) && any(ub<lb) | |
                error("ndbase:cost_function_wrapper:WrongInput", ... | |
@@ -191,6 +204,25 @@ classdef cost_function_wrapper < handle & matlab.mixin.SetGet | |
                if isempty(obj.bound_to_free_funcs{ipar}) | |
                    pars(ipar) = pars_bound(ipar); | |
                else | |
+                    % Move a value lying on or outside a bound strictly inside | |
+                    % [lb, ub] before mapping it to the free parameter (presumably | |
+                    % the bound->free transform is ill-defined there - confirm). | |
+                    % Both bounds finite: use the midpoint. One bound infinite: step | |
+                    % max(|bound|, 1)/2 inward from the violated bound itself, so that | |
+                    % values far outside the bound still end up inside it. | |
+                    if pars_bound(ipar) <= obj.lb(ipar) | |
+                        if isfinite(obj.ub(ipar)) | |
+                            pars_bound(ipar) = (obj.ub(ipar) + obj.lb(ipar)) / 2; | |
+                        else | |
+                            pars_bound(ipar) = obj.lb(ipar) + max(abs(obj.lb(ipar)), 1) / 2; | |
+                        end | |
+                    elseif pars_bound(ipar) >= obj.ub(ipar) | |
+                        if isfinite(obj.lb(ipar)) | |
+                            pars_bound(ipar) = (obj.ub(ipar) + obj.lb(ipar)) / 2; | |
+                        else | |
+                            pars_bound(ipar) = obj.ub(ipar) - max(abs(obj.ub(ipar)), 1) / 2; | |
+                        end | |
+                    end | |
                    pars(ipar) = obj.bound_to_free_funcs{ipar}(pars_bound(ipar)); | |
                end | |
            end | |
diff --git a/swfiles/+ndbase/lm4.m b/swfiles/+ndbase/lm4.m | |
index 5e55776..d321daa 100644 | |
--- a/swfiles/+ndbase/lm4.m | |
+++ b/swfiles/+ndbase/lm4.m | |
@@ -98,7 +98,7 @@ inpForm.defval = [inpForm.defval {1e-8 1e-8 1e-8 1e-2}]; | |
inpForm.size = [inpForm.size {[1 1] [1 1] [1 1] [1 1]}]; | |
inpForm.fname = [inpForm.fname {'nu_up', 'nu_dn', 'resid_handle'}]; | |
-inpForm.defval = [inpForm.defval {10 0.3, false}]; | |
+inpForm.defval = [inpForm.defval {5 0.3, false}]; % defaults: nu_up = 5, nu_dn = 0.3, resid_handle = false | |
inpForm.size = [inpForm.size {[1 1] [1 1], [1 1]}]; | |
param = sw_readparam(inpForm, varargin{:}); |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.