Last active
July 30, 2021 11:19
-
-
Save johnnychen94/b0bf2c336bc6991cf31d81dbb2f86f85 to your computer and use it in GitHub Desktop.
diffwarp: proof of concept on rotation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using ImageTransformations | |
using StaticArrays | |
using Interpolations | |
using ImageCore | |
using ImageShow | |
using TestImages | |
using ChainRules | |
using ChainRules: NoTangent, ZeroTangent, @not_implemented | |
using ChainRulesTestUtils | |
using Zygote | |
# The "cameraman" test image, resized to 32×32 and converted element-wise to `Float64`.
img = map(Float64, imresize(testimage("cameraman"), (32, 32)))
"""
    ϕ(p, θ)

Rotate the 2-D point `p` about the origin by the angle `θ` (in radians) and return
the rotated coordinates as a new 2-element vector. Only `p[1]` and `p[2]` are read.
"""
function ϕ(p, θ)
    s, c = sincos(θ)
    x, y = p[1], p[2]
    return [x * c - y * s, x * s + y * c]
end
"""
    ChainRules.rrule(::typeof(ϕ), p, θ)

Reverse-mode rule for `ϕ`. Returns the rotated point `q` together with a pullback
that maps a cotangent `dLdq` of `q` to `(NoTangent(), dLdp, dLdθ)`.
"""
function ChainRules.rrule(::typeof(ϕ), p, θ)
    s, c = sincos(θ)
    x, y = p[1], p[2]
    q = [x * c - y * s, x * s + y * c]
    function ϕ_pullback(dLdq)
        # ∂q/∂θ: derivative of the rotated coordinates with respect to the angle.
        dqdθ = [-x * s - y * c, x * c - y * s]
        dLdθ = sum(dLdq .* dqdθ)
        # ∂q/∂p is the rotation matrix, so the cotangent is mapped by its transpose.
        g1, g2 = dLdq[1], dLdq[2]
        dLdp = [g1 * c + g2 * s, -g1 * s + g2 * c]
        return NoTangent(), dLdp, dLdθ
    end
    return q, ϕ_pullback
end
"""
    τ(X, q)

Sample the array `X` at the coordinates `q` (splatted as one coordinate per
dimension) using linear B-spline interpolation. The interpolant is extrapolated
with the fill value `zero(eltype(X))`.
"""
function τ(X, q)
    itp = interpolate(X, BSpline(Linear()))
    sampler = extrapolate(itp, zero(eltype(X)))
    return sampler(q...)
end
"""
    ChainRules.rrule(::typeof(τ), X, q)

Reverse-mode rule for `τ`. Returns the interpolated value at `q` and a pullback
that maps the output cotangent `dLdYp` to `(NoTangent(), dLdX, dLdq)`. The
cotangent for `X` is marked as not implemented.
"""
function ChainRules.rrule(::typeof(τ), X, q)
    itp = interpolate(X, BSpline(Linear()))
    sampler = extrapolate(itp, zero(eltype(X)))
    Yp = sampler(q...)
    function τ_pullback(dLdYp)
        # Gradient of the interpolant with respect to the sampling coordinates.
        ∇q = Interpolations.gradient(sampler, q...)
        dLdX = @not_implemented(
            "Interpolations doesn't yet support gradient to coefficients"
        )
        return NoTangent(), dLdX, dLdYp .* ∇q
    end
    return Yp, τ_pullback
end
# Pixel-level tests: choose an angle and a pixel coordinate, then rotate the coordinate.
θ, p = 0.2, [10, 10]
q = ϕ(p, θ)

# Run `test_rrule` on both hand-written rules (type-inference checks disabled).
test_rrule(ϕ, p, θ; check_inferred=false) # FIXME
test_rrule(τ, img, q; check_inferred=false)
# Value of the global `img` sampled at the global coordinate `p` rotated by `θ`.
f_single_pixel(θ) = τ(img, ϕ(p, θ))
# Differentiate, with respect to the angle, the difference between the warped sample
# and the original intensity at pixel `p`.
img_p = img[p...]
Zygote.gradient(α -> f_single_pixel(α) - img_p, θ) # (-0.08925456466517724,)
# Now let's put the warp together
"""
    simple_rotate(X, θ)

Return an array allocated with `similar(X)` whose entry at every index `p` is `X`
sampled with `τ` at the rotated coordinate `ϕ(p, θ)`.
"""
function simple_rotate(X, θ)
    warped = similar(X)
    for I in CartesianIndices(warped)
        # `ϕ` works on a coordinate vector, so unpack the Cartesian index first.
        warped[I] = τ(X, ϕ(collect(Tuple(I)), θ))
    end
    return warped
end
"""
    ChainRules.rrule(::typeof(simple_rotate), X, θ)

Reverse-mode rule for `simple_rotate`. Returns the rotated array `Y` and a pullback
that maps the output cotangent `dLdY` (indexed like `Y`) to
`(NoTangent(), dLdX, dLdθ)`.

`dLdθ` is the sum over all pixels `p` of the chain-rule product through
`Y[p] = τ(X, ϕ(p, θ))`. The cotangent for `X` is marked as not implemented.
"""
function ChainRules.rrule(::typeof(simple_rotate), X, θ)
    Y = simple_rotate(X, θ)
    function gradient_simple_rotate(dLdY)
        dLdθ = zero(eltype(dLdY))
        lk = ReentrantLock()
        Threads.@threads for I in CartesianIndices(Y)
            # Keep the Cartesian index `I` (used to read this pixel's cotangent)
            # separate from the coordinate vector `p` (fed to `ϕ`).
            p = collect(Tuple(I))
            # The forward pass samples `X` at q = ϕ(p, θ), so the interpolation
            # gradient has to be taken at `q` as well.
            q, dϕ = rrule(ϕ, p, θ)
            _, dτ = rrule(τ, X, q)
            _, _, dLdq = dτ(dLdY[I])
            # Use a per-pixel name so the shared accumulator is only touched under
            # the lock. No `return` here: inside `@threads` it would leave the
            # threaded loop body early and skip the accumulation.
            _, _, dLdθ_pixel = dϕ(dLdq)
            lock(lk) do
                dLdθ += dLdθ_pixel
            end
        end
        dLdX = @not_implemented(
            "Interpolations doesn't yet support gradient to coefficients"
        )
        return NoTangent(), dLdX, dLdθ
    end
    return Y, gradient_simple_rotate
end
# Show the rotated image, then call the hand-written rule directly with an all-ones
# cotangent before checking it with `test_rrule`.
Gray.(simple_rotate(img, θ))
imgr, gradient_simple_rotate = rrule(simple_rotate, img, θ)
gradient_simple_rotate(fill(one(eltype(img)), size(img)))
test_rrule(simple_rotate, img, θ; check_inferred=false) # FIXME
# do some simple experiment: start from θ = 0.5 and run 100 gradient-descent steps on
# the squared difference between the rotated image and the original.
g(θ) = simple_rotate(img, θ)
g(θ) .|> Gray
θ = 0.5
lr = 1e-1
# Concretely typed frame buffer (same array type as `Gray.(img)`) instead of an
# untyped `[]`, so no re-collection is needed before building the animation.
outs = typeof(Gray.(img))[]
for _ in 1:100
    dθ, = Zygote.gradient(θ) do θ
        sum(abs2, g(θ) - img)
    end
    # `global` makes the update hit the top-level `θ` when this file is run as a
    # script; without it the assignment would create a new loop-local `θ`.
    global θ = θ - lr * dθ
    out = g(θ)
    println("loss: ", sum(abs2, out - img), ", dθ=", dθ)
    push!(outs, Gray.(out))
end
ImageShow.gif(outs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment