oxinabox/c6ad25c468b3108f8a799bda66c147f8
Created August 10, 2021 12:16
Sketch: Extension of `rrule` to take in the activity (i.e. whether you want to get the derivative with respect to each argument)
using ChainRulesCore  # provides rrule, NoTangent, AbstractZero

abstract type ActivityMarked{T} end
struct Active{T} <: ActivityMarked{T}
    val::T
end
struct Dead{T} <: ActivityMarked{T}
    val::T
end
active(x) = false
active(x::Active) = true
strip_activity(x::ActivityMarked) = x.val
"""
`active_rrule` is like `rrule`, but rather than passing in primals, you pass in either `Active(primal)` or `Dead(primal)`, depending on whether you want to be able to AD with respect to it.
If it is `Dead`, then for its derivative we return `NoTangent()` (in the examples that follow),
or perhaps some new `DidNotRequestTangent <: AbstractZero`.
"""
function active_rrule end
# Fallback:
# if in doubt, fall back to assuming all arguments are active
# and use a plain rrule
active_rrule(f, args...) = rrule(f, strip_activity.(args)...)
# If all arguments are dead, then this is easy:
# the pullback returns NoTangent() for `f` and for every argument
all_dead_pullback(n) = _ -> ntuple(_ -> NoTangent(), n + 1)
function active_rrule(f, args::Dead...)
    @assert fieldcount(typeof(f)) === 0  # ignoring functors for now
    return f(strip_activity.(args)...), all_dead_pullback(length(args))
end
# Now define, for some function `foo`:
function foo end
function active_rrule(::typeof(foo), a::Active, b)
    # slow way to get cotangents for both a and b
end
function active_rrule(::typeof(foo), a::Dead, b::Active)
    # fast way to get the cotangent for b; a returns NoTangent()
end
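To make the two `foo` methods above concrete, here is a minimal sketch under stated assumptions: `foo(a, b) = sum(a * b)` for matrices is a hypothetical stand-in, its `rrule` is hand-written for illustration, and it assumes ChainRulesCore.jl is installed. The all-`Active` call goes through the generic fallback (the "slow" path computing both cotangents), while the `Dead`/`Active` method skips the work for `a` entirely.

```julia
using ChainRulesCore

# Activity markers (repeated from the sketch so this block runs on its own).
abstract type ActivityMarked{T} end
struct Active{T} <: ActivityMarked{T}
    val::T
end
struct Dead{T} <: ActivityMarked{T}
    val::T
end
strip_activity(x::ActivityMarked) = x.val

# Fallback: assume everything is active and use the plain rrule.
active_rrule(f, args...) = rrule(f, strip_activity.(args)...)

# All arguments dead: no derivative work at all.
all_dead_pullback(n) = _ -> ntuple(_ -> NoTangent(), n + 1)
function active_rrule(f, args::Dead...)
    @assert fieldcount(typeof(f)) == 0  # functors not handled here
    return f(strip_activity.(args)...), all_dead_pullback(length(args))
end

# Hypothetical example function: foo(a, b) = sum(a * b) for matrices a and b.
foo(a, b) = sum(a * b)

# Hand-written plain rrule: cotangents for both a and b (the "slow" path).
function ChainRulesCore.rrule(::typeof(foo), a::AbstractMatrix, b::AbstractMatrix)
    function foo_pullback(dy)
        J = fill(dy, size(a, 1), size(b, 2))  # cotangent of the product a * b
        return NoTangent(), J * b', a' * J
    end
    return foo(a, b), foo_pullback
end

# `a` is Dead: only compute the cotangent for b (the "fast" path).
function active_rrule(::typeof(foo), a::Dead, b::Active)
    A, B = strip_activity(a), strip_activity(b)
    foo_pullback_b(dy) = (NoTangent(), NoTangent(), A' * fill(dy, size(A, 1), size(B, 2)))
    return foo(A, B), foo_pullback_b
end
```

Calling `active_rrule(foo, Dead(A), Active(B))` dispatches to the specialised method, which returns `NoTangent()` for `a` without ever forming `J * b'`; `active_rrule(foo, Dead(A), Dead(B))` hits the all-dead method and does no derivative work.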
To reduce ambiguities, we should also have an all-active case that hits `rrule`.
Not sure what you mean by that. Could you elaborate please?
We could define:
active_rrule(f, args::Active...) = rrule(f, strip_activity.(args)...)
so that if everything is active, it will just call the `rrule`, no matter what the function is.
Ahhh I see. Yeah, that makes sense.
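A minimal sketch of the proposed all-active method, assuming ChainRulesCore.jl is installed; `times` and its hand-written `rrule` are hypothetical, included only so there is something to call. The method takes `f` unwrapped and forwards the stripped arguments to `rrule`, whatever `f` is.

```julia
using ChainRulesCore

# Activity markers (repeated from the sketch so this block runs on its own).
abstract type ActivityMarked{T} end
struct Active{T} <: ActivityMarked{T}
    val::T
end
struct Dead{T} <: ActivityMarked{T}
    val::T
end
strip_activity(x::ActivityMarked) = x.val

# Generic fallback from the sketch.
active_rrule(f, args...) = rrule(f, strip_activity.(args)...)
# Proposed all-active case: whatever `f` is, go straight to its rrule.
active_rrule(f, args::Active...) = rrule(f, strip_activity.(args)...)

# Hypothetical function with a hand-written rrule, for illustration only.
times(a, b) = a * b
function ChainRulesCore.rrule(::typeof(times), a::Real, b::Real)
    times_pullback(dy) = (NoTangent(), dy * b, dy * a)
    return times(a, b), times_pullback
end

y, pb = active_rrule(times, Active(2.0), Active(3.0))
```

Here the all-`Active` call selects the new method rather than the generic fallback; a mixed call such as `active_rrule(times, Active(2.0), Dead(3.0))` still reaches the fallback.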