Last active
January 10, 2024 10:36
-
-
Save ramboldio/925aa3eeea505a0ddf3221cf4b5e94c1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# You will find a step-by-step guide to this example in the docs and the | |
# corresponding jupyter notebook on our github repository. | |
using DifferentialEquations | |
using LightGraphs | |
using NetworkDynamics | |
using Plots | |
using LinearAlgebra | |
using DiffEqFlux, DiffEqSensitivity, Flux, OrdinaryDiffEq, Zygote | |
### Defining a graph
# `const` on these globals: they are read inside the edge/vertex functions
# below, and non-const globals are type-unstable in Julia's hot paths.
const N = 10   # number of nodes
const k = 3    # average degree
g = barabasi_albert(N, k, seed=43) # a little more exciting than a bare random graph
# Physical parameters of the spring–damper network
const c = 10000.0       # spring stiffness (also the ODE parameter later)
const d = 100.0         # damping coefficient
const m = 100.0         # vertex mass
const s_unstreched = 0.2  # unstretched spring length (only used by the commented-out spring law)
const grav = [-9.81, 0.0, 0.0]  # constant gravitational acceleration vector
### Functions for edges and vertices
# the edge and vertex functions are mutating, hence
# by convention `!` is appended to their names

# Split a 6-element vertex state [x, y, z, vx, vy, vz] into its halves.
# Named one-line functions instead of lambdas assigned to globals: the
# original bindings were non-const globals, which defeats specialization.
unpack_displacement(v) = [v[1], v[2], v[3]]
unpack_velocity(v) = [v[4], v[5], v[6]]
# e=edges, v_s=source-vertices, v_d=destination-vertices, p=parameters, t=time
function diffusionedge!(e, v_s, v_d, p, t)
    # Spring–damper coupling force along an edge, written into `e` in place.
    # `p` carries the spring stiffness; the damping coefficient `d` is read
    # from global scope (NOTE(review): consider passing it via `p` as well).
    stiffness = p
    Δr = unpack_displacement(v_s) - unpack_displacement(v_d)
    Δv = unpack_velocity(v_s) - unpack_velocity(v_d)
    # Zero-rest-length spring — the simple (& faster) case. For a nonzero
    # rest length use: (Δr + (s_unstreched / norm(Δr)) .* Δr) * stiffness
    e .= Δr * stiffness + Δv * d
    nothing
end
# dv=derivative of vertex variables, v=vertex,
# e_s=source edges, e_d=destination edges, p=parameters, t=time
function diffusionvertex!(dv, v, e_s, e_d, p, t)
    # Newton's second law: acceleration = (net edge force)/m + gravity.
    # `reduce` with a zero init handles vertices that have no incoming or
    # outgoing edges (the original allocated a fresh `fsum` closure per call
    # and an unused `r_displacement` local — both removed).
    net_force = reduce(+, e_d, init=zeros(3)) - reduce(+, e_s, init=zeros(3))
    a_acceleration = net_force / m + grav
    # States 1:3 are positions (derivative = velocity), 4:6 are velocities
    # (derivative = acceleration). Slice-assign instead of splatting into a
    # temporary 6-vector.
    dv[1:3] .= unpack_velocity(v)
    dv[4:6] .= a_acceleration
    nothing
end
# Quick smoke test: evaluate the vertex function once with empty edge lists
# (exercises the zero-init branch of the force summation) before wiring up
# the full network below.
a = zeros(6)
diffusionvertex!(a, zeros(6), [], [], nothing, 0)
### Constructing the network dynamics
# ODEVertex and StaticEdge are structs that contain additional information on
# the function f!, such as the number of internal variables (dim), the symbols
# of those variables, and whether a mass_matrix should be used.
# VertexFunction/EdgeFunction is an abstract supertype for all vertex/edge
# function structs in NetworkDynamics.
# signature of ODEVertex: (f!, dim, mass_matrix, sym) (mass_matrix and sym are optional)
nd_diffusion_vertex = ODEVertex(f! = diffusionvertex!, dim = 6)
# Anchor vertex: all 6 states are statically held at zero, pinning the node.
# BUG FIX: the original read `f! = f! = (...)`, a duplicated keyword that also
# created a stray global binding `f!` as a side effect.
nd_anchor = StaticVertex(f! = (θ, e_s, e_d, c, t) -> θ .= zeros(6), dim = 6)
# signature of StaticEdge: (f!, dim, sym) (sym is optional)
nd_diffusion_edge = StaticEdge(f! = diffusionedge!, dim = 3)
# setting up the key constructor network_dynamics
# signature of network_dynamics: (vertices!, edges!, g; parallel = false)
# parameter parallel of type bool enables a parallel environment
# returned object nd of type ODEFunction is compatible with the solvers of OrdinaryDiffEq
# The element type must be the abstract VertexFunction so the StaticVertex
# anchor can replace an ODEVertex slot. (Renamed from the misspelled
# `nd_vertecies`; `1:nv(g)` replaces the verbose `range(1, stop=nv(g))`.)
nd_vertices = VertexFunction[nd_diffusion_vertex for _ in 1:nv(g)]
nd_vertices[1] = nd_anchor  # pin the first node as the anchor
nd_edges = [nd_diffusion_edge for _ in 1:ne(g)]
nd = network_dynamics(nd_vertices, nd_edges, g, parallel=true)
# network_dynamics expects a (vertex_params, edge_params) tuple; this wrapper
# forwards a single parameter object to the edges only.
nd_wrapper = (dx, x, p, t) -> nd(dx, x, (nothing, p), t)
### Simulation
x0 = rand(N*6) # random initial conditions (6 states per node)
# ODEProblem is a struct of OrdinaryDiffEq.jl
# signature: (f::ODEFunction, u0, tspan, p) — here p is the spring stiffness c
ode_prob = ODEProblem(nd_wrapper, x0, (0., 5.), c)
# solve has signature: (prob, alg::DiffEqBase.DEAlgorithm, args, kwargs)
# BUG FIX: the solver keyword is `saveat`, not `save_at` — the typo meant the
# dense output was not being thinned to 0.1-spaced samples as intended.
@time sol = solve(ode_prob, Rodas3(), saveat=0.1);
### Plotting Trajectories
# 3D trajectories of selected state triples.
# NOTE(review): node i owns states 6(i-1)+1:6i, so indices 6:8 straddle the
# node-1/node-2 boundary — confirm whether 7:9, 13:15, 19:21 were intended.
plot(sol[6, :], sol[7, :], sol[8, :])
plot!(sol[12, :], sol[13, :], sol[14, :])
plot!(sol[18, :], sol[19, :], sol[20, :])
### DiffEqFlux Parameter Optimization
# Wrap the tunable spring constant in a 1-element array: Flux's implicit
# parameter tracking (Flux.params) tracks mutable arrays, not plain scalars,
# so the original `p = c` could never receive gradient updates.
p = [c]
# Pass the current `p` into `solve` so the loss actually depends on the
# tracked parameter. (The original solved `ode_prob` with its baked-in
# parameter, making the loss independent of `p` — zero gradient, and the
# training failure reported for this gist.)
function predict_rd()
    Array(solve(ode_prob, Rodas5(), p=p[1], saveat=0.1, reltol=1e-4))
end
# Toy loss: drive every state toward 1.
loss_rd() = sum(abs2, x - 1 for x in predict_rd())
loss_rd()
opt = ADAM(0.1)
# Callback: print the current loss after each training step.
cb = function ()
    display(loss_rd())
    #display(plot(solve(remake(prob,p=p),Tsit5(),saveat=0.1),ylim=(0,6)))
end
Flux.train!(loss_rd, Flux.params(p), Iterators.repeated((), 100), opt, cb = cb)
### DiffEqFlux Surrogate Solving
datasize = 30
tspan = (0.0f0,1.5f0)
t = range(tspan[1],tspan[2],length=datasize)
# Ground-truth training data sampled on the grid `t`.
# NOTE(review): `ode_prob` was built with tspan (0., 5.) — saveat=t only
# samples the first 1.5 s of that problem; confirm this is intended.
ode_data = Array(solve(ode_prob,Rodas3(),saveat=t))
# Candidate surrogate network.
# NOTE(review): Dense(2,50)/Dense(50,2) expects a 2-state system, but this
# network has N*6 = 60 states — the layer sizes must match before the
# surrogate can be wired into an ODE.
dudt2 = Chain(x -> x.^3,
Dense(2,50,tanh),
Dense(50,2))
p,re = Flux.destructure(dudt2) # use this p as the initial condition!
dudt(u,p,t) = re(p)(u) # need to restructure for backprop!
# Solve the ODE on the data grid `t` and return the solution as a dense Array.
# NOTE(review): this solves the *original* `ode_prob` — the surrogate net
# `dudt`/`re(p)` is never used, so training below cannot change this output.
predict_n_ode() = Array(solve(ode_prob, Rodas5(), saveat=t))
# Sum-of-squared-errors between the ground-truth data and the prediction.
loss_n_ode() = sum(abs2, ode_data .- predict_n_ode())
loss_n_ode() # n_ode.p stores the initial parameters of the neural ODE
# Training callback: print the current loss and plot the first state's
# predicted trajectory against the data.
cb = function (; doplot=false)
    current = predict_n_ode()
    display(sum(abs2, ode_data .- current))
    # plot current prediction against data
    panel = scatter(t, ode_data[1, :], label="data")
    scatter!(panel, t, current[1, :], label="prediction")
    display(plot(panel))
    return false
end
# Display the ODE with the initial parameter values.
cb()
# Run 1000 gradient steps on the (empty-argument) loss.
data = Iterators.repeated((), 1000)
# NOTE(review): Flux.params(x0, p) registers the initial condition and the net
# parameters, but predict_n_ode() reads neither, so the loss cannot change
# during training — confirm the intended neural-ODE wiring.
Flux.train!(loss_n_ode, Flux.params(x0,p), data, ADAM(0.05), cb = cb)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Current issue (encountered when running the first occurrence of `Flux.train!`):