Last active
March 18, 2025 21:14
-
-
Save kiranshila/8d1d506bb523baaa3dba787ed307f2ab to your computer and use it in GitHub Desktop.
Rust/Enzyme example with NLOpt FFI
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//! Following the NLopt tutorial: https://nlopt.readthedocs.io/en/latest/NLopt_Tutorial/
#![feature(autodiff)]

use std::autodiff::autodiff;

use nlopt::*;
#[autodiff(df, Reverse, Duplicated, Active)] | |
fn f(x: &[f64]) -> f64 { | |
x[1].sqrt() | |
} | |
/// Coefficients of one cubic constraint, `(a * x[0] + b)^3 - x[1]`.
#[derive(Debug, Clone, Copy, PartialEq)]
struct ConsParam {
    /// Multiplier applied to `x[0]` inside the cube.
    a: f64,
    /// Offset added to `a * x[0]` inside the cube.
    b: f64,
}
#[autodiff(dcons, Reverse, Duplicated, Const, Active)] | |
fn cons(x: &[f64], p: &ConsParam) -> f64 { | |
let inner = p.a * x[0] + p.b; | |
(inner * inner * inner) - x[1] | |
} | |
fn nlopt_obj(x: &[f64], gradient: Option<&mut [f64]>, _p: &mut ()) -> f64 { | |
if let Some(grad) = gradient { | |
grad.fill(0.0); | |
df(x, grad, 1.0) | |
} else { | |
f(x) | |
} | |
} | |
fn nlopt_cons(x: &[f64], gradient: Option<&mut [f64]>, p: &mut ConsParam) -> f64 { | |
if let Some(grad) = gradient { | |
grad.fill(0.0); | |
dcons(x, grad, p, 1.0) | |
} else { | |
cons(x, p) | |
} | |
} | |
fn main() -> Result<(), FailState> { | |
let mut opt = Nlopt::new(Algorithm::Mma, 2, nlopt_obj, Target::Minimize, ()); | |
// Set the bounds | |
opt.set_lower_bounds(&[f64::NEG_INFINITY, 0.0])?; | |
// Set the constraints | |
opt.add_inequality_constraint(nlopt_cons, ConsParam { a: 2.0, b: 0.0 }, 1.0e-8)?; | |
opt.add_inequality_constraint(nlopt_cons, ConsParam { a: -1.0, b: 1.0 }, 1.0e-8)?; | |
// Set the tolerance limit | |
opt.set_xtol_rel(1.0e-4)?; | |
// Set the initial guess | |
let mut x = [1.234, 5.678]; | |
// Solve with timing | |
let earlier = std::time::Instant::now(); | |
let res = opt.optimize(&mut x); | |
let dur = std::time::Instant::now().duration_since(earlier); | |
println!("Result: {:?}", res); | |
println!("X vals: {:?}", &x[..]); | |
println!("Opt time: {:?}", dur); | |
Ok(()) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment