Skip to content
Open
Changes from 1 commit
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
2d70966
Upgrade FiniteDifferences
willtebbutt Aug 22, 2021
f530868
Initial implementation
willtebbutt Aug 22, 2021
31f67d3
Some more work
willtebbutt Aug 22, 2021
4b554c5
Add testset comment
willtebbutt Aug 28, 2021
a51abe9
Examples and notes
willtebbutt Aug 28, 2021
8205ce6
Some work
willtebbutt Aug 28, 2021
bccece2
Tidy up examples
willtebbutt Aug 28, 2021
569f064
Tidy up PR notes
willtebbutt Aug 28, 2021
f321a97
Provide optimised pullback implementations
willtebbutt Aug 28, 2021
8a06561
Change pullback implementation specifications
willtebbutt Aug 28, 2021
72fc725
Add extra methods to avoid config
willtebbutt Aug 28, 2021
fbaaf86
Fix typo
willtebbutt Aug 28, 2021
313590d
Tweak comment in examples
willtebbutt Aug 28, 2021
321fceb
Update notes
willtebbutt Aug 28, 2021
a15f469
Tweak notes
willtebbutt Aug 28, 2021
22efec3
Fix typo:
willtebbutt Aug 28, 2021
2b0853e
Tweak notes
willtebbutt Aug 28, 2021
73a5fb7
Tweak notes
willtebbutt Aug 28, 2021
ddb3237
Tweak notes
willtebbutt Aug 28, 2021
34e43d6
Clarify notes
willtebbutt Aug 28, 2021
c773045
Clarify notes
willtebbutt Aug 28, 2021
eaaeba7
Clarify notes
willtebbutt Aug 28, 2021
d94514d
Tweak notes
willtebbutt Aug 28, 2021
d571ad1
Tweak notes
willtebbutt Aug 28, 2021
d87d5c4
Tweak examples comments
willtebbutt Aug 28, 2021
a7eb01c
Tidy up + add Fill example
willtebbutt Aug 29, 2021
e691b6a
Add SArray example
willtebbutt Aug 29, 2021
d0b0eb0
Add note on Symmetric restructure
willtebbutt Aug 29, 2021
d6eef62
Add UpperTriangular example with my_mul
willtebbutt Aug 29, 2021
a4ce7d2
Add WoodburyPDMat example
willtebbutt Aug 31, 2021
a880752
Note that natural tangent addition is fine
willtebbutt Sep 1, 2021
6cada13
Add another examples comment
willtebbutt Sep 1, 2021
8f06364
Add Kronecker example and fix WoodburPDMat example
willtebbutt Sep 5, 2021
5647388
Add note on side-effects
willtebbutt Sep 5, 2021
6d9d01f
Update notes.md
willtebbutt Sep 14, 2021
20f7727
Update notes.md
willtebbutt Sep 14, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Tidy up + add Fill example
  • Loading branch information
willtebbutt committed Aug 29, 2021
commit a7eb01cf46e94199a2ed8a2f7a018d831725e0b2
268 changes: 186 additions & 82 deletions examples.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,26 @@ function rrule(config::RuleConfig, ::typeof(my_mul), A::AbstractMatrix, B::Abstr
return C, wrap_natural_pullback(config, natural_pullback_for_mul, C, A, B)
end

A = randn(4, 3);
B = Symmetric(randn(3, 3));
C, pb = Zygote.pullback(my_mul, A, B);
let
A = randn(4, 3);
B = Symmetric(randn(3, 3));
C, pb = Zygote.pullback(my_mul, A, B);

@assert C ≈ my_mul(A, B)
@assert C ≈ my_mul(A, B)

dC = randn(4, 3);
dA, dB_zg = pb(dC);
dB = Tangent{typeof(B)}(data=dB_zg.data);
dC = randn(4, 3);
dA, dB_zg = pb(dC);
dB = Tangent{typeof(B)}(data=dB_zg.data);

# Test correctness.
dA_fd, dB_fd_sym = FiniteDifferences.j′vp(central_fdm(5, 1), my_mul, dC, A, B);
# Test correctness.
dA_fd, dB_fd_sym = FiniteDifferences.j′vp(central_fdm(5, 1), my_mul, dC, A, B);

# to_vec doesn't know how to make `Tangent`s, so instead I map it to a `Tangent` manually.
dB_fd = Tangent{typeof(B)}(data=dB_fd_sym.data);
# to_vec doesn't know how to make `Tangent`s, so translate manually.
dB_fd = Tangent{typeof(B)}(data=dB_fd_sym.data);

test_approx(dA, dA_fd)
test_approx(dB, dB_fd)
test_approx(dA, dA_fd)
test_approx(dB, dB_fd)
end



Expand All @@ -59,19 +61,21 @@ function ChainRulesCore.rrule(config::RuleConfig, ::typeof(my_sum), x::AbstractA
return y, wrap_natural_pullback(config, natural_pullback_my_sum, y, x)
end

A = Symmetric(randn(2, 2))
y, pb = Zygote.pullback(my_sum, A)
let
A = Symmetric(randn(2, 2))
y, pb = Zygote.pullback(my_sum, A)

test_approx(y, my_sum(A))
test_approx(y, my_sum(A))

dy = randn()
dA_zg, = pb(dy)
dA = Tangent{typeof(A)}(data=dA_zg.data)
dy = randn()
dA_zg, = pb(dy)
dA = Tangent{typeof(A)}(data=dA_zg.data)

dA_fd_sym, = FiniteDifferences.j′vp(central_fdm(5, 1), my_sum, dy, A)
dA_fd = Tangent{typeof(A)}(data=dA_fd_sym.data)
dA_fd_sym, = FiniteDifferences.j′vp(central_fdm(5, 1), my_sum, dy, A)
dA_fd = Tangent{typeof(A)}(data=dA_fd_sym.data)

test_approx(dA, dA_fd)
test_approx(dA, dA_fd)
end



Expand All @@ -90,66 +94,70 @@ function ChainRulesCore.rrule(
end

# DENSE TEST
a = randn()
x = randn(2, 2)
y, pb = Zygote.pullback(my_scale, a, x)

dy = randn(size(y))
da, dx = pb(dy)
let
a = randn()
x = randn(2, 2)
y, pb = Zygote.pullback(my_scale, a, x)

da_fd, dx_fd = FiniteDifferences.j′vp(central_fdm(5, 1), my_scale, dy, a, x)
dy = randn(size(y))
da, dx = pb(dy)

test_approx(y, my_scale(a, x))
test_approx(da, da_fd)
test_approx(dx, dx_fd)
da_fd, dx_fd = FiniteDifferences.j′vp(central_fdm(5, 1), my_scale, dy, a, x)

test_approx(y, my_scale(a, x))
test_approx(da, da_fd)
test_approx(dx, dx_fd)
end


# DIAGONAL TEST

# `diag` now returns a `Diagonal` as a tangnet, so have to define `my_diag` to make this
# work with Diagonal`s.

my_diag(x) = diag(x)
function ChainRulesCore.rrule(::typeof(my_diag), D::P) where {P<:Diagonal}
my_diag_pullback(d) = NoTangent(), Tangent{P}(diag=d)
return diag(D), my_diag_pullback
end

a = randn()
x = Diagonal(randn(2))
y, pb = Zygote.pullback(my_diag ∘ my_scale, a, x)
let
a = randn()
x = Diagonal(randn(2))
y, pb = Zygote.pullback(my_diag ∘ my_scale, a, x)

ȳ = randn(2)
ā, x̄_zg = pb(ȳ)
x̄ = Tangent{typeof(x)}(diag=x̄_zg.diag)
ȳ = randn(2)
ā, x̄_zg = pb(ȳ)
x̄ = Tangent{typeof(x)}(diag=x̄_zg.diag)

ā_fd, _x̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), my_diag ∘ my_scale, ȳ, a, x)
x̄_fd = Tangent{typeof(x)}(diag=_x̄_fd.diag)
ā_fd, _x̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), my_diag ∘ my_scale, ȳ, a, x)
x̄_fd = Tangent{typeof(x)}(diag=_x̄_fd.diag)

test_approx(y, (my_diag ∘ my_scale)(a, x))
test_approx(ā, ā_fd)
test_approx(x̄, x̄_fd)
test_approx(y, (my_diag ∘ my_scale)(a, x))
test_approx(ā, ā_fd)
test_approx(x̄, x̄_fd)

end


# SYMMETRIC TEST - FAILS BECAUSE PRIVATE ELEMENTS IN LOWER-DIAGONAL ACCESSED IN PRIMAL!
# I would be surprised if we're doing this consistently at the minute though.
let
a = randn()
x = Symmetric(randn(2, 2))
y, pb = Zygote.pullback(my_scale, a, x)

a = randn()
x = Symmetric(randn(2, 2))
y, pb = Zygote.pullback(my_scale, a, x)
dy = Tangent{typeof(y)}(data=randn(2, 2))
da, dx_zg = pb(dy)
dx = Tangent{typeof(x)}(data=dx_zg.data)

dy = Tangent{typeof(y)}(data=randn(2, 2))
da, dx_zg = pb(dy)
dx = Tangent{typeof(x)}(data=dx_zg.data)

da_fd, dx_fd_sym = FiniteDifferences.j′vp(central_fdm(5, 1), my_scale, dy, a, x)
dx_fd = Tangent{typeof(x)}(data=dx_fd_sym.data)

test_approx(y.data, my_scale(a, x).data)
test_approx(da, da_fd)
test_approx(dx, dx_fd)
da_fd, dx_fd_sym = FiniteDifferences.j′vp(central_fdm(5, 1), my_scale, dy, a, x)
dx_fd = Tangent{typeof(x)}(data=dx_fd_sym.data)

test_approx(y.data, my_scale(a, x).data)
test_approx(da, da_fd)
test_approx(dx, dx_fd)
end



Expand Down Expand Up @@ -207,41 +215,137 @@ function ChainRulesCore.rrule(
return z, wrap_natural_pullback(config, natural_pullback_my_dot, z, x, y)
end

let
# Check correctness of `my_dot` rrule. Build `ScaledMatrix` internally to avoid
# technical issues with FiniteDifferences.
V = randn(2, 2)
α = randn()
z̄ = randn()

# Check correctness of `my_dot` rrule. Build `ScaledMatrix` internally to avoid technical
# issues with FiniteDifferences.
V = randn(2, 2)
α = randn()
z̄ = randn()
foo_scal(V, α) = my_dot(ScaledMatrix(V, α), V)

foo_scal(V, α) = my_dot(ScaledMatrix(V, α), V)
z, pb = Zygote.pullback(foo_scal, V, α)
dx_ad = pb(z̄)

z, pb = Zygote.pullback(foo_scal, V, α)
dx_ad = pb(z̄)
dx_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo_scal, z̄, V, α)

dx_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo_scal, z̄, V, α)

test_approx(dx_ad, dx_fd)
test_approx(dx_ad, dx_fd)
end


# A function with a specialised method for ScaledMatrix.
# Specialised method of my_scale for ScaledMatrix
my_scale(a::Real, X::ScaledMatrix) = ScaledMatrix(X.v, X.α * a)

# Verify correctness.
a = randn()
V = randn(2, 2)
α = randn()
z̄ = randn()
let
# Verify correctness.
a = randn()
V = randn(2, 2)
α = randn()
z̄ = randn()

# A more complicated programme involving `my_scale`.
B = randn(2, 2)
foo_my_scale(a, V, α) = my_dot(B, my_scale(a, ScaledMatrix(V, α)))
# A more complicated programme involving `my_scale`.
B = randn(2, 2)
foo_my_scale(a, V, α) = my_dot(B, my_scale(a, ScaledMatrix(V, α)))

z, pb = Zygote.pullback(foo_my_scale, a, V, α)
da, dV, dα = pb(z̄)
z, pb = Zygote.pullback(foo_my_scale, a, V, α)
da, dV, dα = pb(z̄)

da_fd, dV_fd, dα_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo_my_scale, z̄, a, V, α)
da_fd, dV_fd, dα_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo_my_scale, z̄, a, V, α)

test_approx(da, da_fd)
test_approx(dV, dV_fd)
test_approx(dα, dα_fd)
end




# Example 5: Fill

using FillArrays

# What you would implement:
# destucture(x::Fill) = collect(x)

function pullback_of_destructure(config::RuleConfig, x::P) where {P<:Fill}
pullback_destructure_Fill(X̄::AbstractArray) = Tangent{P}(value=sum(X̄))
return pullback_destructure_Fill
end

test_approx(da, da_fd)
test_approx(dV, dV_fd)
test_approx(dα, dα_fd)
# There are multiple equivalent choices for Restructure here. I present two options below,
# both yield the correct answer.
# To understand why there is a range of options here, recall that the input to
# (::Restructure)(x::AbstractArray) must be an array that is equal (in the `getindex` sense)
# to a `Fill`. Moreover, `(::Restructure)` simply has to promise that the `Fill` it outputs
# is equal to the `Fill` from which the `Restructure` is constructed (in the structural
# sense). Since the elements of `x` must all be equal, any affine combination will do (
# weighted sum, whose weights sum to 1).
# While there is no difference in the primal for `(::Restructure)`, the pullback is quite
# different, depending upon your choice. Since we don't ever want to evaluate the primal for
# `(::Restructure)`, just the pullback, we are free to choose whatever definition of
# `(::Restructure)` makes its pullback pleasant. In particular, defining `(::Restructure)`
# to take the mean of its argument yields a pleasant pullback (see below).

# Restucture option 1:

# Restructure(x::P) where {P<:Fill} = Restructure{P, typeof(x.axes)}(x.axes)
# (r::Restructure{<:Fill})(x::AbstractArray) = Fill(x[1], r.data)

function pullback_of_restructure(config::RuleConfig, x::Fill)
println("Option 1")
function pullback_restructure_Fill(x̄::Tangent)
X̄ = zeros(size(x))
X̄[1] = x̄.value
return X̄
end
return pullback_restructure_Fill
end

# Restructure option 2:

# Restructure(x::P) where {P<:Fill} = Restructure{P, typeof(x.axes)}(x.axes)
# (r::Restructure{<:Fill})(x::AbstractArray) = Fill(mean(x), r.data)

function pullback_of_restructure(config::RuleConfig, x::Fill)
println("Option 2")
pullback_restructure_Fill(x̄::Tangent) = Fill(x̄.value / length(x), x.axes)
return pullback_restructure_Fill
end


let
A = randn(2, 3)
v = randn()

# Build the Fill inside because FiniteDifferenes doesn't play nicely with Fills, even
# if one adds a `to_vec` call.
foo_my_mul(A, v) = my_mul(A, Fill(v, 3, 4))
C, pb = Zygote.pullback(foo_my_mul, A, v)

@assert C ≈ foo_my_mul(A, v)

C̄ = randn(2, 4);
Ā, v̄ = pb(C̄);

Ā_fd, v̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo_my_mul, C̄, A, v);

test_approx(Ā, Ā_fd)
test_approx(v̄, v̄_fd)
end

let
foo_my_scale(a, v) = my_sum(my_scale(a, Fill(v, 3, 4)))
a = randn()
v = randn()

c, pb = Zygote.pullback(foo_my_scale, a, v)
@assert c ≈ foo_my_scale(a, v)

c̄ = randn()
ā, v̄ = pb(c̄)

ā_fd, v̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo_my_scale, c̄, a, v);

test_approx(ā, ā_fd)
test_approx(v̄, v̄_fd)
end