Last active
April 16, 2019 20:46
-
-
Save willtebbutt/39205ab845b22e6452a42705eac8d254 to your computer and use it in GitHub Desktop.
Toy tape-based reverse-mode AD with minimal Cassette usage.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# | |
# This uses the Nabla.jl-style interception mechanism whereby
# we wrap things that are to be differentiated w.r.t. in a
# thin wrapper. There are lots of things that you can't
# propagate derivative information through with this kind of
# approach without quite a lot of extra machinery, but the
# examples at the bottom do work.
# | |
using ChainRules, Cassette | |
using Cassette: @context | |
using ChainRules: rrule, extern, Zero | |
############################## | |
# Types for tracking objects # | |
############################## | |
# Supertype of every tracked value; the parameter is the type of the wrapped
# value stored in the `y` field of the concrete subtypes.
abstract type Node{T} end

# A tracked value with no recorded parents.
#   y    - the wrapped value
#   pos  - index associated with this node on `tape`
#   tape - the tape this node belongs to
struct Leaf{T} <: Node{T}
    y::T
    pos::Int
    tape::Vector{Any}
end

# A tracked value together with the call that produced it.
#   y    - the wrapped result
#   f    - the function that was called
#   xs   - the arguments it was called with
#   Δxs  - the reverse-rule data stored alongside `xs`
#   pos  - index associated with this node on `tape`
#   tape - the tape this node belongs to
struct Branch{T, F, X, R} <: Node{T}
    y::T
    f::F
    xs::X
    Δxs::R
    pos::Int
    tape::Vector{Any}
end
# Helper functions. | |
# `true` for any `Node` (i.e. a tracked value), `false` for everything else.
is_tagged(::Node) = true
is_tagged(::Any) = false
# Strip the tracking wrapper: a `Node` yields its wrapped value `y`; anything
# else is returned unchanged.
untag(node::Node) = node.y
untag(value) = value
# Return the `tape` field of the first argument for which `is_tagged` is true.
#
# Throws an `ArgumentError` when none of the arguments is tagged. Without the
# explicit check, `findfirst` would return `nothing` and indexing the argument
# tuple with it would fail with a far less helpful error.
function get_tape(x...)
    idx = findfirst(is_tagged, x)
    if idx === nothing
        throw(ArgumentError("get_tape: none of the arguments is a tagged node"))
    end
    return x[idx].tape
end
###################################################### | |
# Use Cassette to define the interception mechanisms # | |
###################################################### | |
# Declare the Cassette context type that the `overdub` methods in this file
# dispatch on.
Cassette.@context DiffCtx
# Single-argument interception.
#
# * Untagged argument: call `f` directly, nothing is recorded.
# * Tagged argument and `rrule(f, untag(x))` is not `nothing`: wrap the first
#   element of the rule result in a `Branch`, store the second element as the
#   one-element `Δxs` tuple, push the branch onto the argument's tape and
#   return it.
# * Tagged argument but `rrule` returns `nothing`: let Cassette recurse into `f`.
function Cassette.overdub(ctx::DiffCtx, f, x)
    if !is_tagged(x)
        return f(x)
    end

    result = rrule(f, untag(x))
    if result === nothing
        return Cassette.recurse(ctx, f, x)
    end

    primal, pullback = result
    record = x.tape
    # The new node is about to be appended, so its position is one past the end.
    node = Branch(primal, f, (x,), (pullback,), length(record) + 1, record)
    push!(record, node)
    return node
end
# Variadic interception.
#
# * No tagged argument: call `f` directly, nothing is recorded.
# * At least one tagged argument and `rrule` (called on the untagged arguments)
#   is not `nothing`: wrap the first element of the rule result in a `Branch`
#   whose `xs` are the original (possibly tagged) arguments and whose `Δxs` is
#   the second element of the rule result; push it onto the tape and return it.
# * Otherwise: let Cassette recurse into `f`.
function Cassette.overdub(ctx::DiffCtx, f, x...)
    if !any(is_tagged, x)
        return f(x...)
    end

    result = rrule(f, map(untag, x)...)
    if result === nothing
        return Cassette.recurse(ctx, f, x...)
    end

    primal, pullbacks = result
    # The tape is taken from the first tagged argument.
    record = get_tape(x...)
    # The new node is about to be appended, so its position is one past the end.
    node = Branch(primal, f, x, pullbacks, length(record) + 1, record)
    push!(record, node)
    return node
end
############################################# | |
# Implement reverse-mode AD in not many LoC # | |
############################################# | |
# Run `f(x...)` under `DiffCtx`, recording every intercepted call on a fresh
# tape, and return `(value, pullback)`.
#
# Arguments
#   f    - the function to differentiate
#   x... - the inputs; each one is wrapped in a `Leaf` at tape position `n`
#
# Returns
#   value    - the untagged result of `f(x...)`
#   pullback - a function `ȳ -> (x̄₁, …, x̄ₙ)` that sweeps the tape in reverse and
#              returns one `extern`-ed entry per input
function forward(f, x...)
    tape = Vector{Any}()
    leaves = map(((n, x),) -> Leaf(x, n, tape), enumerate(x))
    append!(tape, leaves)
    y = Cassette.overdub(DiffCtx(), f, leaves...)

    # The result is untagged when no recorded operation produced it (e.g. `f`
    # returns a constant). There is then nothing to sweep: every input gets the
    # same externalised `Zero()` that an unused input gets from the sweep below.
    if !is_tagged(y)
        return y, ȳ -> map(_ -> extern(Zero()), x)
    end

    return y.y, function (ȳ)
        back_tape = Vector{Any}(undef, length(tape))
        fill!(back_tape, Zero())

        # Seed the output node itself. It is not necessarily the last tape entry:
        # `f` may return an input unchanged, or record further calls whose
        # results it then discards.
        back_tape[y.pos] = ȳ

        # A node's position is assigned when it is pushed, after its arguments
        # already exist, so nothing recorded after `y` can feed into it.
        for n in y.pos:-1:1
            node = tape[n]
            node isa Branch || continue
            for (p, arg) in enumerate(node.xs)
                is_tagged(arg) || continue
                back_tape[arg.pos] = ChainRules.accumulate(
                    back_tape[arg.pos],
                    node.Δxs[p],
                    back_tape[n],
                )
            end
        end

        # The leaves occupy the first `length(x)` tape positions.
        return (extern.(back_tape[1:length(x)])...,)
    end
end
# Single-input example function: the sum of the sine and cosine of `x`.
function foo(x)
    s = sin(x)
    c = cos(x)
    return s + c
end
# Trace `foo` at 5.0; `forward` hands back the primal value and a pullback.
(y, back) = forward(foo, 5.0);
# Seed the pullback with 1 to obtain the tuple of input cotangents.
back(1)
# Two-input example function: sine of the first argument plus cosine of the second.
function bar(x, y)
    s = sin(x)
    c = cos(y)
    return s + c
end
# Trace `bar` at (5.0, 4.0); `forward` hands back the primal value and a pullback.
(y, back) = forward(bar, 5.0, 4.0);
# Seed the pullback with 1 to obtain one cotangent per input.
back(1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment