Last active
April 13, 2020 14:36
-
-
Save mschauer/6f8e0f57ce0b15eef9f86390d0631df4 to your computer and use it in GitHub Desktop.
Reference implementation of backward filtering, forward guiding, following https://arxiv.org/abs/1712.03807
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using LinearAlgebra | |
using Random | |
using GaussianDistributions | |
using GaussianDistributions: logpdf | |
# Split a (mean, covariance)-like object into its two components.
# Generic fallback: anything indexable at positions 1 and 2, e.g. a tuple.
function pair(u)
    return u[1], u[2]
end
# `Gaussian`: read the mean `μ` and covariance `Σ` fields directly.
function pair(p::Gaussian)
    return p.μ, p.Σ
end
"""
    skiplast(r)

Return `r` without its last element.

The slice starts at `firstindex(r)` rather than a hard-coded `1`, so collections
whose indices do not start at 1 are handled as well. An empty `r` yields an
empty result.
"""
skiplast(r) = r[firstindex(r):lastindex(r)-1]
# --- time grid: s = 0, dt, 2dt, ..., T ---
dt = 0.01
T = 10.0
s = 0:dt:T

# --- model dX = b(x)dt + σ(x)dW ---
function b(x)
    return -0.1x - 2.5sin(x*2pi) + 0.5
end
function σ(x)
    return 0.9
end

# --- linear approximation b̃(x) = B*x + β of b, constant approximation σ̃ of σ ---
B = -0.1
β = 0.5
function b̃(x)
    return B*x + β
end
σ̃ = 0.9

# --- observation times t: every 10th grid point, starting with the first ---
ti = 1:10:length(s)
t = s[ti]

# --- observation scheme Y ∼ N(L*X, σϵ^2) ---
L = 1.0
σϵ = 0.2
Σ = σϵ*σϵ'
# Kalman correction step, https://en.wikipedia.org/wiki/Kalman_filter#Update
"""
    correct(u, v, H) -> (x, P), yres, S

Correction step of a Kalman filter.

`u` is the prediction (mean and covariance) and `v` the observation `y` with its
noise covariance `R`; both are split with `pair`, so tuples and `Gaussian`s are
accepted. `H` is the observation operator.

Returns the corrected mean and covariance `(x, P)`, the innovation residual
`yres = y - H*x` and the innovation covariance `S = H*P*H' + R`.
See https://en.wikipedia.org/wiki/Kalman_filter#Update.
"""
function correct(u, v, H)
    xpred, Ppred = pair(u)
    y, R = pair(v)
    yres = y - H*xpred               # innovation residual
    S = H*Ppred*H' + R               # innovation covariance
    gain = Ppred*H'*inv(S)           # Kalman gain
    xpost = xpred + gain*yres
    # covariance update in Joseph form
    A = I - gain*H
    Ppost = A*Ppred*A' + gain*R*gain'
    return (xpost, Ppost), yres, S
end
# Sample the model
"""
    forwardsample(s, ti, x) -> xs, ys

Simulate a trajectory `xs` on the time grid `s`, started in `x`, with the
Euler-Maruyama scheme, together with noisy observations `ys` taken at the grid
indices contained in `ti` (i.e. at times `s[ti]`).

Uses the global step `dt`, the model functions `b`, `σ` and the observation
parameters `L`, `σϵ`.
"""
function forwardsample(s, ti, x)
    path = typeof(x)[]
    obs = typeof(L*x)[]
    # record a noisy observation of the given state
    observe!(state) = push!(obs, L*state + σϵ*randn())
    for i in skiplast(eachindex(s))
        i in ti && observe!(x)
        push!(path, x)
        x = x + b(x)*dt + σ(x)*sqrt(dt)*randn()   # Euler-Maruyama step
    end
    # the final grid point is stored (and possibly observed) but not stepped from
    push!(path, x)
    lastindex(s) in ti && observe!(x)
    return path, obs
end
# Compute marginal approximate filtering distributions given data `ys` backwards
"""
    backwardfilter(s, ti, ys, πT) -> ps, p0

Backward filtering on the grid `s`, starting from the prior `πT` (mean and
covariance) at the final time, assuming that `ys` contains observations at times
`t = s[ti]` with `y ∼ N(L X[t], Σ)`.

`ps[i]` is the pair `(ν, P)` stored for grid index `i`; `p0` is the pair at the
first grid point after its observation (if any) has been incorporated.
The last grid index must be an observation index.
"""
function backwardfilter(s, ti, ys, πT)
    @assert lastindex(s) in ti
    k = length(ys)                       # index of the observation absorbed last
    p = first(correct(πT, (ys[k], Σ), L))
    ps = [p]
    ν, P = pair(p)
    ã = σ̃*σ̃'
    for i in (lastindex(s) - 1):-1:1
        # one explicit step of size dt backwards in time for covariance and mean
        P = P - dt*(B*P + P*B' - ã)
        ν = ν - dt*(B*ν + β)
        push!(ps, (ν, P))
        i in ti || continue
        # grid point carries an observation: Kalman correction with the next-earlier y
        k -= 1
        ν, P = pair(first(correct((ν, P), (ys[k], Σ), L)))
    end
    return reverse!(ps), (ν, P)
end
""" | |
forwardguiding(s, x, ps) -> xs, ll | |
Forward sample a guided trajectory `xs` starting in `x` and compute its
log-likelihood `ll`.
""" | |
function forwardguiding(s, x, ps)
    # Simulates the guided path on grid `s` from `x`; `ps[i]` holds the backward
    # filter pair (ν, P) for grid index i. Returns the path and the accumulated
    # log-likelihood. Uses the global step `dt` and the global model definitions.
    path = typeof(x)[]
    ll = 0.0
    ã = σ̃*σ̃'
    for i in skiplast(eachindex(s))
        push!(path, x)
        ν, P = pair(ps[i])
        Pinv = inv(P)
        r = Pinv*(ν - x)                 # guiding term
        a = σ(x)*σ(x)'
        # accumulate log-likelihood: drift mismatch minus diffusivity mismatch
        ll += dot(b(x) - b̃(x), r)*dt - 0.5*tr((a - ã)*(Pinv - r*r'))*dt
        # evolution guided by observations
        x = x + b(x)*dt + a*r*dt + σ(x)*sqrt(dt)*randn()
    end
    push!(path, x)
    return path, ll
end
Random.seed!(123)

# First generate data from the model for illustration
π0 = Gaussian(0.0, 1.0)
x0 = rand(π0)
xs, ys = forwardsample(s, ti, x0) # sample trajectory

# run backwards filter given the observations ys
πT = Gaussian(0.0, 10.0) # prior for the backward filter
ps, p0 = backwardfilter(s, ti, ys, πT)

# sample trajectories and their importance weight
K = 10
x̂s = Vector{typeof(xs)}(undef, K)   # guided paths have the same type as `xs`
ll = zeros(K)
for k in 1:K
    # a loop-local name avoids assigning to the global `x0` from inside the loop
    xstart = rand(Gaussian(p0...)) # sample from p0
    x̂s[k], ll[k] = forwardguiding(s, xstart, ps)
    # correct for having used backward prior πT instead of our actual prior π0
    ll[k] += logpdf(π0, x̂s[k][1]) - logpdf(πT, x̂s[k][end])
end
lmax = maximum(exp.(ll)) # maximum of importance weights (can underflow to 0)
# Weights relative to the largest one, formed in log space: equal to
# exp(ll[k])/lmax, but stays finite when every exp(ll[k]) underflows.
relweights = exp.(ll .- maximum(ll))

# Plot samples of the latent trajectories colored according to importance weight
using Plots
pl = Plots.scatter(t, ys, color=:orange, markersize=2., label="obs",legend=:outertopright) # observations
for k in 1:K
    Plots.plot!(pl, s, x̂s[k], color=:maroon, lw = 0.6, alpha = relweights[k], label="sample $k") # samples
end
Plots.plot!(pl, s, xs, color=:lightseagreen, label="x true") # ground truth
display(pl)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using LinearAlgebra | |
using Random | |
using GaussianDistributions | |
using GaussianDistributions: logpdf | |
using Parameters | |
# Split a (mean, covariance)-like object into its two components.
# Generic fallback: anything indexable at positions 1 and 2, e.g. a tuple.
function pair(u)
    return u[1], u[2]
end
# `Gaussian`: read the mean `μ` and covariance `Σ` fields directly.
function pair(p::Gaussian)
    return p.μ, p.Σ
end
"""
    skiplast(r)

Return `r` without its last element.

The slice starts at `firstindex(r)` rather than a hard-coded `1`, so collections
whose indices do not start at 1 are handled as well. An empty `r` yields an
empty result.
"""
skiplast(r) = r[firstindex(r):lastindex(r)-1]
# --- model dX = b(x)dt + σ(x)dW ---
# argument M contains the model parameters (see the `Model` struct)
function b(x, M)
    return -0.1x - M.θ*sin(x*2pi) + 0.5
end
function σ(x, M)
    return 0.9
end

# --- linear approximation of b and constant approximation of σ ---
function b̃(x, M)
    return M.B*x + M.β
end
function σ̃(M)
    return M.σ̃
end

# --- time grid: s = 0, dt, 2dt, ..., T ---
dt = 0.01
T = 10.0
s = 0:dt:T

# --- observation times t: every 10th grid point, starting with the first ---
ti = 1:10:length(s)
t = s[ti]
# Container for all model parameters, built with keyword arguments and defaults,
# e.g. `Model(θ = 2.5)`.
@with_kw struct Model{R} @deftype R # in 1d all parameters can be of the same type R
    # --- unknown parameter (coefficient of the sine term in the drift b) ---
    θ = 2.5

    # --- parameters for linear approximation of b and constant approximation of σ ---
    B = -0.1
    β = 0.5
    σ̃ = 0.9

    # --- observation scheme Y ∼ N(L*X, σϵ^2) ---
    L = 1.0
    σϵ = 0.2
    Σ = σϵ*σϵ'    # observation noise covariance, derived from σϵ by default
end
# Kalman correction step, https://en.wikipedia.org/wiki/Kalman_filter#Update
"""
    correct(u, v, H, c = 0.0) -> (x, P), c, yres, S

Correction step of a Kalman filter.

`u` is the prediction (mean and covariance) and `v` the observation `y` with its
noise covariance `R`; both are split with `pair`, so tuples and `Gaussian`s are
accepted. `H` is the observation operator. `c` is a running constant from which
`logpdf(Gaussian(zero(y), R), y)` is subtracted.

Returns, in this order: the corrected mean and covariance `(x, P)`, the updated
constant `c`, the innovation residual `yres = y - H*x` and the innovation
covariance `S = H*P*H' + R`.
See https://en.wikipedia.org/wiki/Kalman_filter#Update.
"""
function correct(u, v, H, c = 0.0)
    xpred, Ppred = pair(u)
    y, R = pair(v)
    yres = y - H*xpred               # innovation residual
    S = H*Ppred*H' + R               # innovation covariance
    gain = Ppred*H'*inv(S)           # Kalman gain
    xpost = xpred + gain*yres
    # covariance update in Joseph form
    A = I - gain*H
    Ppost = A*Ppred*A' + gain*R*gain'
    cnew = c - logpdf(Gaussian(zero(y), R), y)
    return (xpost, Ppost), cnew, yres, S
end
# Sample the model
"""
    forwardsample(M, s, ti, x) -> xs, ys

Simulate a trajectory `xs` of model `M` on the time grid `s`, started in `x`,
with the Euler-Maruyama scheme, together with noisy observations `ys` taken at
the grid indices contained in `ti` (i.e. at times `s[ti]`).

The step size is taken from the grid, `s[i+1] - s[i]`; the observation
parameters `L` and `σϵ` are read from `M`.
"""
function forwardsample(M, s, ti, x)
    @unpack L, σϵ = M
    path = typeof(x)[]
    obs = typeof(L*x)[]
    # record a noisy observation of the given state
    observe!(state) = push!(obs, L*state + σϵ*randn())
    for i in skiplast(eachindex(s))
        h = s[i+1] - s[i]
        i in ti && observe!(x)
        push!(path, x)
        x = x + b(x, M)*h + σ(x, M)*sqrt(h)*randn()   # Euler-Maruyama step
    end
    # the final grid point is stored (and possibly observed) but not stepped from
    push!(path, x)
    lastindex(s) in ti && observe!(x)
    return path, obs
end
# Compute marginal approximate filtering distributions given data `ys` backwards
"""
    backwardfilter(M, s, ti, ys, πT, c = 0.0) -> ps, p0, c

Backward filtering on the grid `s` for model `M`, starting from the prior `πT`
(mean and covariance) at the final time, assuming that `ys` contains
observations at times `t = s[ti]` with `y ∼ N(L X[t], Σ)`.

`ps[i]` is the pair `(ν, P)` stored for grid index `i`; `p0` is the pair at the
first grid point after its observation (if any) has been incorporated.
`exp(-c)` is the integration constant from Theorem 3.3; `c` is accumulated over
all backward steps and all Kalman corrections, starting from the argument `c`.
The last grid index must be an observation index.
"""
function backwardfilter(M, s, ti, ys, πT, c = 0.0)
    @unpack L, Σ, B, β, σ̃ = M
    @assert lastindex(s) in ti
    j = length(ys)
    # `correct` returns `(x, P), c, yres, S`: the updated constant is the SECOND
    # value. (Taking the third one would overwrite `c` with the residual `yres`.)
    p, c = correct(πT, (ys[j], Σ), L, c)
    ps = [p]
    ν, P = pair(p)
    for i in eachindex(s)[end-1:-1:1]
        dt = s[i+1] - s[i]
        # one explicit step of size dt backwards in time for covariance and mean
        P = P - dt*(B*P + P*B' - σ̃*σ̃')
        ν = ν - dt*(B*ν + β)
        # contribution of this step to the constant, in terms of H = inv(P), F = H*ν
        H = inv(P)
        F = H*ν
        c += β*F*dt + 0.5*F'*σ̃*σ̃'*F*dt - 0.5*sum(H .* (σ̃*σ̃'))*dt
        push!(ps, (ν, P))
        if i in ti
            # grid point carries an observation: Kalman correction with the
            # next-earlier y, carrying the constant along
            j = j - 1
            p, c = correct((ν, P), (ys[j], Σ), L, c)
            ν, P = pair(p)
        end
    end
    return reverse!(ps), (ν, P), c
end
""" | |
forwardguiding(M, s, x, ps, Z) -> xs, ll | |
Forward sample a guided trajectory `xs` starting in `x` and compute its
log-likelihood `ll` with innovations `Z = randn(length(s))`.
""" | |
function forwardguiding(M, s, x, ps, Z=randn(length(s)))
    # Simulates the guided path of model `M` on grid `s` from `x`, driven by the
    # innovations `Z`; `ps[i]` holds the backward filter pair (ν, P) for grid
    # index i. Returns the path and the accumulated log-likelihood.

    # Log-likelihood increment over one step of size `h`. The step is passed in
    # explicitly: a `dt` assigned inside the loop below is local to the loop, so
    # a closure defined out here would read the global `dt` instead.
    llstep(x, r, P, h) = dot(b(x, M) - b̃(x, M), r)*h - 0.5*tr((σ(x, M)*σ(x, M)' - σ̃(M)*σ̃(M)')*(inv(P) - r*r'))*h
    xs = typeof(x)[]
    ll = 0.0
    for i in skiplast(eachindex(s))
        dt = s[i+1] - s[i]
        push!(xs, x)
        ν, P = pair(ps[i])
        r = inv(P)*(ν - x)               # guiding term
        ll += llstep(x, r, P, dt)        # accumulate log-likelihood
        # evolution guided by observations
        x = x + b(x, M)*dt + σ(x, M)*σ(x, M)'*r*dt + σ(x, M)*sqrt(dt)*Z[i]
    end
    push!(xs, x)
    return xs, ll
end
""" | |
randomwalkmcmc(s, ti, ys, θ0, iters, ρ = 0.9, σθ = 0.01) | |
Infer parameter θ using Metropolis-Hastings with joint update of | |
innovations (Crank Nicolson with parameter ρ) and parameter θ (Gaussian random walk | |
with stepsize σθ) | |
""" | |
function randomwalkmcmc(s, ti, ys, θ0, iters, ρ = 0.9, σθ = 0.01)
    # Metropolis-Hastings chain for θ, started in θ0 and run for `iters` steps.
    # Returns the chain `θs` (length iters + 1, including θ0) and the acceptance
    # rate. Reads the priors `π0` and `πT` from global scope.

    # Full log-weight of a guided path: guided log-likelihood `ll` plus the prior
    # correction and the terms depending on the starting point and the constant.
    # The same expression is used for the current state and for every proposal,
    # so that the acceptance ratio compares like with like.
    function logweight(path, ll, xstart, pstart, const0)
        νstart, Pstart = pstart
        ll += logpdf(π0, path[1]) - logpdf(πT, path[end])
        ll += -const0 + (-0.5*xstart' + νstart')*inv(Pstart)*xstart # constant may change if σ depends on parameter
        return ll
    end

    θ = θ0
    θs = [θ]
    # sample initial latent path under the model with θ = θ0
    # (not under a global model, which may hold a different θ)
    Minit = Model(θ = θ)
    ps, p0, c = backwardfilter(Minit, s, ti, ys, πT)
    x0 = rand(Gaussian(p0...))
    Z = randn(length(s))
    x̂, ll = forwardguiding(Minit, s, x0, ps, Z)
    ll = logweight(x̂, ll, x0, p0, c)
    acc = 0
    for iter in 1:iters
        # random walk proposal for parameter
        θᵒ = θ + σθ*randn()
        # compute filtering density for guiding under the proposed parameter;
        # kept in separate variables so a rejected proposal leaves no trace
        Mᵒ = Model(θ = θᵒ)
        psᵒ, p0ᵒ, cᵒ = backwardfilter(Mᵒ, s, ti, ys, πT)
        # independent proposal for starting point
        x0ᵒ = rand(Gaussian(p0ᵒ...))
        # random walk proposal for innovations (Crank Nicolson)
        Zᵒ = ρ*Z + sqrt(1 - ρ^2)*randn(length(s))
        # compute latent path and its log-weight
        x̂ᵒ, llᵒ = forwardguiding(Mᵒ, s, x0ᵒ, psᵒ, Zᵒ)
        llᵒ = logweight(x̂ᵒ, llᵒ, x0ᵒ, p0ᵒ, cᵒ)
        # Metropolis-Hastings accept/reject for joint proposal of starting point, path, parameter
        if rand() < exp(llᵒ - ll)
            θ, ll, x0, x̂, Z = θᵒ, llᵒ, x0ᵒ, x̂ᵒ, Zᵒ
            acc += 1
        end
        push!(θs, θ)
    end
    return θs, acc/iters
end
Random.seed!(123)

# Set true model
θtrue = 2.5
M = Model(θ = θtrue)

# First generate data from the model for illustration
π0 = Gaussian(0.0, 1.0)
x0 = rand(π0)
xs, ys = forwardsample(M, s, ti, x0) # sample trajectory

# run backwards filter given the observations ys
πT = Gaussian(0.0, 10.0) # prior for the backward filter
ps, p0, c = backwardfilter(M, s, ti, ys, πT)

# sample trajectories and their importance weight
K = 10
x̂s = Vector{typeof(xs)}(undef, K)   # guided paths have the same type as `xs`
ll = zeros(K)
for k in 1:K
    # a loop-local name avoids assigning to the global `x0` from inside the loop
    xstart = rand(Gaussian(p0...)) # sample from p0
    x̂s[k], ll[k] = forwardguiding(M, s, xstart, ps)
    # correct for having used backward prior πT instead of our actual prior π0
    ll[k] += logpdf(π0, x̂s[k][1]) - logpdf(πT, x̂s[k][end])
end
lmax = maximum(exp.(ll)) # maximum of importance weights (can underflow to 0)
# Weights relative to the largest one, formed in log space: equal to
# exp(ll[k])/lmax, but stays finite when every exp(ll[k]) underflows.
relweights = exp.(ll .- maximum(ll))

# inference for parameter θ
θ = 0.2θtrue # start somewhere wrong
iters = 50000
ρ = 0.9 # random walk parameter for innovation update (Crank Nicolson scheme)
σθ = 0.03 # stepsize randomwalk parameter
θs, a = @time randomwalkmcmc(s, ti, ys, θ, iters, ρ, σθ)
println("Acceptance rate: ", a)

# Plot samples of the latent trajectories colored according to importance weight
using Plots
pl = Plots.scatter(t, ys, color=:orange, markersize=2., label="obs",legend=:outertopright) # observations
for k in 1:K
    Plots.plot!(pl, s, x̂s[k], color=:maroon, lw = 0.6, alpha = relweights[k], label="sample $k") # samples
end
Plots.plot!(pl, s, xs, color=:lightseagreen, label="x true") # ground truth
display(pl)

# Plot samples of the mcmc chain for θ (every 10th iteration)
pl2 = Plots.plot(0:10:iters, θs[1:10:end], label = "theta, samples")
Plots.plot!(pl2, 0:10:iters, fill(θtrue, length(0:10:iters)), label= "theta, true")
display(pl2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Just to be sure, in line 190, is this \log \tilde\rho(0,x_0), expressed in Hfc-parametrisation?