oxinabox/c6ad25c468b3108f8a799bda66c147f8, created August 10, 2021 12:16
Sketch: Extension of `rrule` to take in the activity (i.e. whether you want to get the derivative with respect to each argument)
using ChainRulesCore  # provides rrule, NoTangent, AbstractZero

abstract type ActivityMarked{T} end
struct Active{T} <: ActivityMarked{T}
    val::T
end
struct Dead{T} <: ActivityMarked{T}
    val::T
end

active(x) = false
active(x::Active) = true
strip_activity(x::ActivityMarked) = x.val

"""
`active_rrule` is like `rrule`, but rather than passing in primals, you pass in either `Active(primal)` or `Dead(primal)`, depending on whether you want to be able to AD with respect to it.
If it is `Dead`, then for its derivative we return `NoTangent()` (in the examples that follow),
or perhaps some new `DidNotRequestTangent <: AbstractZero`.
"""
function active_rrule end

# Fallback:
# if in doubt, fall back to assuming all are active
# and use a plain rrule
active_rrule(f, args...) = rrule(f, strip_activity.(args)...)

# If all are dead then this is easy: every tangent (including the one for `f`) is NoTangent()
all_dead_pullback(n) = _ -> ntuple(_ -> NoTangent(), n + 1)
function active_rrule(f, args::Dead...)
    @assert fieldcount(typeof(f)) == 0  # ignoring functors (callable structs with fields) for now
    return f(strip_activity.(args)...), all_dead_pullback(length(args))
end

# Now define, for some function `foo` we are writing rules for:
function foo end  # hypothetical placeholder
function active_rrule(::typeof(foo), a::Active, b)
    # slow way to get cotangent for both a and b
end
function active_rrule(::typeof(foo), a::Dead, b::Active)
    # fast way to get cotangent for b, and a returns NoTangent()
end
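A minimal usage sketch of the core definitions, assuming ChainRulesCore.jl is installed. The marker types, fallback and all-dead method are repeated so the block runs on its own; `mul` and its `rrule` are hypothetical stand-ins for a real rule. A mixed call goes through the plain-`rrule` fallback, while an all-`Dead` call never touches `rrule` at all.

```julia
using ChainRulesCore

# Activity markers (repeated so this block stands alone)
abstract type ActivityMarked{T} end
struct Active{T} <: ActivityMarked{T}
    val::T
end
struct Dead{T} <: ActivityMarked{T}
    val::T
end
strip_activity(x::ActivityMarked) = x.val

# Fallback: treat everything as active and defer to the plain rrule
active_rrule(f, args...) = rrule(f, strip_activity.(args)...)

# All-dead case: run the primal only; every tangent (including f's) is NoTangent()
all_dead_pullback(n) = _ -> ntuple(_ -> NoTangent(), n + 1)
function active_rrule(f, args::Dead...)
    @assert fieldcount(typeof(f)) == 0  # ignoring functors for now
    return f(strip_activity.(args)...), all_dead_pullback(length(args))
end

# Hypothetical function with an ordinary ChainRulesCore rrule
mul(a, b) = a * b
function ChainRulesCore.rrule(::typeof(mul), a, b)
    mul_pullback(ȳ) = (NoTangent(), ȳ * b, ȳ * a)
    return mul(a, b), mul_pullback
end

y_mixed, pb_mixed = active_rrule(mul, Active(2.0), Dead(3.0))  # via rrule
y_dead, pb_dead = active_rrule(mul, Dead(2.0), Dead(3.0))      # rrule never called
pb_mixed(1.0)  # (NoTangent(), 3.0, 2.0): the plain fallback still computes the Dead tangent
pb_dead(1.0)   # (NoTangent(), NoTangent(), NoTangent())
```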
Probably the fallback to `rrule` should also drop any tangents for anything that is `Dead`, so that if they were thunks, it can be sure that nobody ever unthunks them.
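One way that could look, as a sketch assuming ChainRulesCore.jl: the fallback still calls the plain `rrule`, but its pullback swaps the tangent of every `Dead` argument for `NoTangent()`, so a thunk built for it is discarded without being unthunked. This only saves work the `rrule` deferred with `@thunk`; anything it computes eagerly is still computed. `mul` is a hypothetical example whose `rrule` returns thunks.

```julia
using ChainRulesCore

abstract type ActivityMarked{T} end
struct Active{T} <: ActivityMarked{T}
    val::T
end
struct Dead{T} <: ActivityMarked{T}
    val::T
end
strip_activity(x::ActivityMarked) = x.val

# Fallback that still uses the plain rrule, but whose pullback replaces the
# tangent of every Dead argument with NoTangent().
function active_rrule(f, args...)
    res = rrule(f, strip_activity.(args)...)
    res === nothing && return nothing  # no rrule defined for this call
    y, pb = res
    function activity_pullback(ȳ)
        ∂f, ∂args... = pb(ȳ)  # first entry is the tangent for f itself
        ∂kept = map((a, ∂a) -> a isa Dead ? NoTangent() : ∂a, args, ∂args)
        return (∂f, ∂kept...)
    end
    return y, activity_pullback
end

# Hypothetical function whose rrule returns thunked tangents
mul(a, b) = a * b
function ChainRulesCore.rrule(::typeof(mul), a, b)
    mul_pullback(ȳ) = (NoTangent(), @thunk(ȳ * b), @thunk(ȳ * a))
    return mul(a, b), mul_pullback
end

y, pb = active_rrule(mul, Active(2.0), Dead(3.0))
∂s = pb(1.0)
unthunk(∂s[2])  # 3.0, the tangent for the Active argument
∂s[3]           # NoTangent(): the thunk for the Dead argument was dropped
```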
👍
To reduce ambiguities, we should also have an all-active case that hits `rrule`.
Not sure what you mean by that. Could you elaborate please?
We could define:
active_rrule(f, args::Active...) = rrule(f, strip_activity.(args)...)
So that if everything is active, it will just call the `rrule`, no matter what the function is.
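A small runnable sketch of that suggestion, assuming ChainRulesCore.jl; `mul` is a hypothetical example function. Here `f` is passed unmarked, as in the fallback, since only the arguments carry `Active`/`Dead` markers.

```julia
using ChainRulesCore

abstract type ActivityMarked{T} end
struct Active{T} <: ActivityMarked{T}
    val::T
end
struct Dead{T} <: ActivityMarked{T}
    val::T
end
strip_activity(x::ActivityMarked) = x.val

# General fallback, as in the sketch
active_rrule(f, args...) = rrule(f, strip_activity.(args)...)

# All-active case: nothing can be skipped, so go straight to the plain rrule,
# whatever `f` is.
active_rrule(f, args::Active...) = rrule(f, strip_activity.(args)...)

# Hypothetical function with an ordinary rrule
mul(a, b) = a * b
function ChainRulesCore.rrule(::typeof(mul), a, b)
    mul_pullback(ȳ) = (NoTangent(), ȳ * b, ȳ * a)
    return mul(a, b), mul_pullback
end

y, pb = active_rrule(mul, Active(2.0), Active(3.0))  # hits the all-active method
pb(1.0)  # (NoTangent(), 3.0, 2.0)
```

Whether this method removes or adds ambiguities depends on the rule-specific methods defined alongside it (for example, one specialized on `::typeof(foo)` in its first argument but not on every later argument), so running `Test.detect_ambiguities` on the defining module is a cheap check.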
Ahhh I see. Yeah, that makes sense.
Yeah, this seems like a nice solution. It's particularly nice in that it would be opt-in for AD systems: if they don't care about activity and just want regular `rrule`s, they can avoid this infrastructure entirely. It would also be really easy to test these: you just check for consistency with the regular `rrule`. (Probably we'd want to insist that `active_rrule`s have a fallback `rrule` for the all-active case, to prevent accidentally creating `active_rrule`s without `rrule`s, which would prevent Zygote and Diffractor from using them.)