Skip to content
Open
Changes from 1 commit
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
2d70966
Upgrade FiniteDifferences
willtebbutt Aug 22, 2021
f530868
Initial implementation
willtebbutt Aug 22, 2021
31f67d3
Some more work
willtebbutt Aug 22, 2021
4b554c5
Add testset comment
willtebbutt Aug 28, 2021
a51abe9
Examples and notes
willtebbutt Aug 28, 2021
8205ce6
Some work
willtebbutt Aug 28, 2021
bccece2
Tidy up examples
willtebbutt Aug 28, 2021
569f064
Tidy up PR notes
willtebbutt Aug 28, 2021
f321a97
Provide optimised pullback implementations
willtebbutt Aug 28, 2021
8a06561
Change pullback implementation specifications
willtebbutt Aug 28, 2021
72fc725
Add extra methods to avoid config
willtebbutt Aug 28, 2021
fbaaf86
Fix typo
willtebbutt Aug 28, 2021
313590d
Tweak comment in examples
willtebbutt Aug 28, 2021
321fceb
Update notes
willtebbutt Aug 28, 2021
a15f469
Tweak notes
willtebbutt Aug 28, 2021
22efec3
Fix typo:
willtebbutt Aug 28, 2021
2b0853e
Tweak notes
willtebbutt Aug 28, 2021
73a5fb7
Tweak notes
willtebbutt Aug 28, 2021
ddb3237
Tweak notes
willtebbutt Aug 28, 2021
34e43d6
Clarify notes
willtebbutt Aug 28, 2021
c773045
Clarify notes
willtebbutt Aug 28, 2021
eaaeba7
Clarify notes
willtebbutt Aug 28, 2021
d94514d
Tweak notes
willtebbutt Aug 28, 2021
d571ad1
Tweak notes
willtebbutt Aug 28, 2021
d87d5c4
Tweak examples comments
willtebbutt Aug 28, 2021
a7eb01c
Tidy up + add Fill example
willtebbutt Aug 29, 2021
e691b6a
Add SArray example
willtebbutt Aug 29, 2021
d0b0eb0
Add note on Symmetric restructure
willtebbutt Aug 29, 2021
d6eef62
Add UpperTriangular example with my_mul
willtebbutt Aug 29, 2021
a4ce7d2
Add WoodburyPDMat example
willtebbutt Aug 31, 2021
a880752
Note that natural tangent addition is fine
willtebbutt Sep 1, 2021
6cada13
Add another examples comment
willtebbutt Sep 1, 2021
8f06364
Add Kronecker example and fix WoodburPDMat example
willtebbutt Sep 5, 2021
5647388
Add note on side-effects
willtebbutt Sep 5, 2021
6d9d01f
Update notes.md
willtebbutt Sep 14, 2021
20f7727
Update notes.md
willtebbutt Sep 14, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Tidy up examples
  • Loading branch information
willtebbutt committed Aug 28, 2021
commit bccece20c6e1c20130910c6a6c18d98b2f71b9c4
241 changes: 85 additions & 156 deletions examples.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,12 @@ using FiniteDifferences
using LinearAlgebra
using Zygote

import ChainRulesCore: rrule
import ChainRulesCore: rrule, pullback_of_destructure, pullback_of_restructure

using ChainRulesCore:
pullback_of_destructure,
pullback_of_restructure,
RuleConfig,
wrap_natural_pullback
using ChainRulesCore: RuleConfig, wrap_natural_pullback

# All of the examples here involve new functions (`my_mul` etc) so that it's possible to
# ensure that Zygote's existing adjoints don't get in the way.
# ensure that Zygote's / ChainRules' existing adjoints don't get in the way.

# Example 1: matrix-matrix multiplication.

Expand Down Expand Up @@ -46,25 +42,21 @@ test_approx(dB, dB_fd)



# Example 2: something where the output isn't a matrix.
# Thin stand-in for `sum`, kept as its own function so that rules can be attached to it
# without existing adjoints for `sum` getting in the way.
function my_sum(x::AbstractArray)
    return sum(x)
end
# Pullbacks for `Real`s, so that they play nicely with the utility functionality.

# Reverse rule for `my_sum` restricted to dense `Array` inputs.
#
# Returns the primal `sum(x)` together with a pullback that maps a real cotangent `dy` to
# `(NoTangent(), dy * ones(size(x)))`, i.e. `dy` spread over an array the size of `x`.
function ChainRulesCore.rrule(::typeof(my_sum), x::Array)
    primal = sum(x)
    function my_sum_strict_pullback(dy::Real)
        dx = dy * ones(size(x))
        return (NoTangent(), dx)
    end
    return primal, my_sum_strict_pullback
end
# For a `Real` argument the destructure pullback is simply `identity`.
function ChainRulesCore.pullback_of_destructure(config::RuleConfig, x::Real)
    return identity
end

function ChainRulesCore.rrule(::typeof(my_sum), x::AbstractArray)
x_dense, destructure_pb = ChainRulesCore.rrule(destructure, x)
y, my_sum_strict_pullback = ChainRulesCore.rrule(my_sum, x_dense)
# For a `Real` argument the restructure pullback is simply `identity`.
function ChainRulesCore.pullback_of_restructure(config::RuleConfig, x::Real)
    return identity
end

function my_sum_generic_pullback(dy::Real)
_, dx_dense = my_sum_strict_pullback(dy)
_, dx = destructure_pb(dx_dense)
return NoTangent(), dx
end

return y, my_sum_generic_pullback
# Example 2: something where the output isn't a matrix.

"""
    my_sum(x::AbstractArray)

Return the sum of every element of `x` by forwarding to `sum`.
"""
function my_sum(x::AbstractArray)
    total = sum(x)
    return total
end

# Reverse rule for `my_sum` on any `AbstractArray`.
#
# The inner pullback turns a real cotangent `ȳ` into an array of `size(x)` filled with `ȳ`.
# It is then passed, with the config, the primal output and the input, to
# `wrap_natural_pullback`, whose result is returned as the pullback.
# NOTE(review): presumably the wrapper converts between natural and structural tangents
# for structured `x` — confirm against `wrap_natural_pullback`.
function ChainRulesCore.rrule(config::RuleConfig, ::typeof(my_sum), x::AbstractArray)
    y = my_sum(x)
    function natural_pullback_my_sum(ȳ::Real)
        x̄ = fill(ȳ, size(x))
        return NoTangent(), x̄
    end
    wrapped_pullback = wrap_natural_pullback(config, natural_pullback_my_sum, y, x)
    return y, wrapped_pullback
end

A = Symmetric(randn(2, 2))
Expand All @@ -89,36 +81,59 @@ test_approx(dA, dA_fd)

# Scale the matrix `x` by the real scalar `a`.
function my_scale(a::Real, x::AbstractMatrix)
    return a * x
end

function ChainRulesCore.rrule(::typeof(my_inv), x::Matrix)
# Reverse rule for `my_scale(a, x)`.
#
# The inner pullback takes a matrix cotangent `ȳ` and returns `dot(ȳ, x)` as the cotangent
# for `a` and `a * ȳ` as the cotangent for `x`. That pullback is handed to
# `wrap_natural_pullback` together with the config, the output `y` and both inputs.
function ChainRulesCore.rrule(
    config::RuleConfig, ::typeof(my_scale), a::Real, x::AbstractMatrix,
)
    y = my_scale(a, x)
    function natural_pullback_my_scale(ȳ::AbstractMatrix)
        ā = dot(ȳ, x)
        x̄ = a * ȳ
        return NoTangent(), ā, x̄
    end
    wrapped_pullback = wrap_natural_pullback(config, natural_pullback_my_scale, y, a, x)
    return y, wrapped_pullback
end

y, pb = ChainRulesCore.rrule(inv, x)
# DENSE TEST
# Inputs: a random real scalar and a random dense 2x2 matrix.
a = randn()
x = randn(2, 2)
# Differentiate `my_scale` with Zygote, keeping the primal output `y` and the pullback `pb`.
y, pb = Zygote.pullback(my_scale, a, x)

# We know that a * x isa Array. Any AbstractArray is an okay tangent for an Array.
function my_scale_pullback(ȳ::AbstractArray)
return NoTangent(), dot(ȳ, x), ȳ * a
end
return a * x, my_scale_pullback
end
# Random cotangent with the same size as `y`, pulled back through `pb`.
dy = randn(size(y))
da, dx = pb(dy)

# Finite-difference reference values for the same cotangents.
# NOTE(review): presumably `j′vp` computes the vector-Jacobian product of `my_scale` at
# `(a, x)` — confirm against the FiniteDifferences documentation.
da_fd, dx_fd = FiniteDifferences.j′vp(central_fdm(5, 1), my_scale, dy, a, x)

# Compare the primal output and both cotangents with their reference values.
test_approx(y, my_scale(a, x))
test_approx(da, da_fd)
test_approx(dx, dx_fd)

function ChainRulesCore.rrule(::typeof(my_scale), a::Real, x::AbstractMatrix)
x_dense, destructure_x_pb = ChainRulesCore.rrule(destructure, x)
y_dense, my_scale_strict_pb = ChainRulesCore.rrule(my_scale, a, x_dense)
y = my_scale(a, x)
y_reconstruct, restructure_pb = ChainRulesCore.rrule(Restructure(y), y_dense)

function my_scale_generic_pullback(dy)
_, dy_dense = restructure_pb(dy)
_, da, dx_dense = my_scale_strict_pb(dy_dense)
_, dx = destructure_x_pb(dx_dense)
return NoTangent(), da, dx
end

return y_reconstruct, my_scale_generic_pullback
# DIAGONAL TEST

# `diag` now returns a `Diagonal` as a tangent, so have to define `my_diag` to make this
# work with `Diagonal`s.
# Thin stand-in for `diag`, kept as its own function so a dedicated rule can be attached.
function my_diag(x)
    return diag(x)
end
# Reverse rule for `my_diag` on a `Diagonal` input of concrete type `P`.
#
# Returns `diag(D)` and a pullback that wraps the incoming cotangent `d` in a structural
# `Tangent{P}` whose only field is `diag`.
function ChainRulesCore.rrule(::typeof(my_diag), D::P) where {P<:Diagonal}
    function my_diag_pullback(d)
        D̄ = Tangent{P}(diag=d)
        return NoTangent(), D̄
    end
    primal = diag(D)
    return primal, my_diag_pullback
end

# NOTE(review): presumably needed so that Zygote picks up the rules defined above — confirm.
Zygote.refresh()
# Inputs: a random real scalar and a random 2x2 `Diagonal` matrix.
a = randn()
x = Diagonal(randn(2))
# Differentiate the composition `my_diag ∘ my_scale`; `pb` is the pullback.
y, pb = Zygote.pullback(my_diag ∘ my_scale, a, x)

# Random length-2 cotangent, pulled back through `pb`.
ȳ = randn(2)
ā, x̄_zg = pb(ȳ)
# Re-wrap the `diag` field of Zygote's result in a `Tangent` for the comparison below.
x̄ = Tangent{typeof(x)}(diag=x̄_zg.diag)

# Finite-difference reference values, with the `x` part wrapped in a `Tangent` the same way.
ā_fd, _x̄_fd = FiniteDifferences.j′vp(central_fdm(5, 1), my_diag ∘ my_scale, ȳ, a, x)
x̄_fd = Tangent{typeof(x)}(diag=_x̄_fd.diag)

# Compare the primal output and both cotangents with their reference values.
test_approx(y, (my_diag ∘ my_scale)(a, x))
test_approx(ā, ā_fd)
test_approx(x̄, x̄_fd)

# SYMMETRIC TEST


# SYMMETRIC TEST - FAILS BECAUSE HIDDEN ELEMENTS IN LOWER-DIAGONAL ACCESSED IN PRIMAL!
# I would be surprised if we're doing this consistently at the minute though.

a = randn()
x = Symmetric(randn(2, 2))
Expand All @@ -135,31 +150,12 @@ test_approx(y.data, my_scale(a, x).data)
test_approx(da, da_fd)
test_approx(dx, dx_fd)

# DENSE TEST
# Materialise `x` as a dense array and differentiate `my_scale` on that copy.
x_dense = collect(x)
y, pb = Zygote.pullback(my_scale, a, x_dense)

# Random cotangent with the same size as `y`, pulled back through `pb`.
dy = randn(size(y))
da, dx = pb(dy)

# Finite-difference reference values for the same cotangents.
da_fd, dx_fd = FiniteDifferences.j′vp(central_fdm(5, 1), my_scale, dy, a, x_dense)

# Compare the primal output and both cotangents with their reference values.
test_approx(y, my_scale(a, x_dense))
test_approx(da, da_fd)
test_approx(dx, dx_fd)





# Example 4: ScaledMatrix

using ChainRulesCore
using ChainRulesCore: Restructure, destructure, Restructure
using ChainRulesTestUtils
using FiniteDifferences
using LinearAlgebra
using Zygote
# Example 4: ScaledMatrix. This is an interesting example because I truly had no idea how to
# specify a natural tangent for this before.

# Implement AbstractArray interface.
struct ScaledMatrix <: AbstractMatrix{Float64}
Expand All @@ -172,13 +168,29 @@ Base.getindex(x::ScaledMatrix, p::Int, q::Int) = x.α * x.v[p, q]
Base.size(x::ScaledMatrix) = size(x.v)


# Implement destructure and restructure.
# Implement destructure and restructure pullbacks.

# Build the destructure pullback for a `ScaledMatrix` of concrete type `P`.
#
# The returned closure maps a dense cotangent `X̄` to a structural `Tangent{P}` with
# `v = X̄ * x.α` and `α = dot(X̄, x.v)`.
function pullback_of_destructure(config::RuleConfig, x::P) where {P<:ScaledMatrix}
    function pullback_destructure_ScaledMatrix(X̄::AbstractArray)
        v̄ = X̄ * x.α
        ᾱ = dot(X̄, x.v)
        return Tangent{P}(v = v̄, α = ᾱ)
    end
    return pullback_destructure_ScaledMatrix
end

# Build the restructure pullback for a `ScaledMatrix`.
#
# The returned closure maps a structural `Tangent` `x̄` to `x̄.v / x.α`; only the `v` field
# of the tangent is read.
function pullback_of_restructure(config::RuleConfig, x::ScaledMatrix)
    pullback_restructure_ScaledMatrix(x̄::Tangent) = x̄.v / x.α
    return pullback_restructure_ScaledMatrix
end

# Dense representation of a `ScaledMatrix`: the stored matrix `v` scaled by `α`.
function ChainRulesCore.destructure(x::ScaledMatrix)
    return x.α * x.v
end
# What `destructure` and `restructure` would look like if they were implemented. The
# pullbacks above were derived from these definitions.
# ChainRulesCore.destructure(x::ScaledMatrix) = x.α * x.v

# Build a `Restructure` for a `ScaledMatrix`, storing only its scale `α` as the data.
function ChainRulesCore.Restructure(x::P) where {P<:ScaledMatrix}
    return Restructure{P, Float64}(x.α)
end
# ChainRulesCore.Restructure(x::P) where {P<:ScaledMatrix} = Restructure{P, Float64}(x.α)

# Rebuild a `ScaledMatrix` from the dense array `x`: divide elementwise by the stored
# scale `r.data` and pair the result with that same scale.
function (r::Restructure{<:ScaledMatrix})(x::AbstractArray)
    return ScaledMatrix(x ./ r.data, r.data)
end
# (r::Restructure{<:ScaledMatrix})(x::AbstractArray) = ScaledMatrix(x ./ r.data, r.data)



Expand All @@ -190,17 +202,9 @@ my_dot(x::AbstractArray, y::AbstractArray) = dot(x, y)
function ChainRulesCore.rrule(
config::RuleConfig, ::typeof(my_dot), x::AbstractArray, y::AbstractArray,
)
_, destructure_x_pb = rrule_via_ad(config, destructure, x)
_, destructure_y_pb = rrule_via_ad(config, destructure, y)

function pullback_my_dot(z̄::Real)
x̄_dense = z̄ * y
ȳ_dense = z̄ * x
_, x̄ = destructure_x_pb(x̄_dense)
_, ȳ = destructure_y_pb(ȳ_dense)
return NoTangent(), x̄, ȳ
end
return my_dot(x, y), pullback_my_dot
z = my_dot(x, y)
natural_pullback_my_dot(z̄::Real) = NoTangent(), z̄ * y, z̄ * x
return z, wrap_natural_pullback(config, natural_pullback_my_dot, z, x, y)
end


Expand All @@ -220,29 +224,9 @@ dx_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo_scal, z̄, V, α)
test_approx(dx_ad, dx_fd)


# A function with a specialised rule for ScaledMatrix.
# Generic method: scale the array `X` by the real scalar `a`.
function my_scale(a::Real, X::AbstractArray)
    return a * X
end
# A function with a specialised method for ScaledMatrix.
# Specialised method: scaling a `ScaledMatrix` leaves `X.v` untouched and only multiplies
# the stored scale `X.α` by `a`.
function my_scale(a::Real, X::ScaledMatrix)
    return ScaledMatrix(X.v, X.α * a)
end

# Generic rrule.
# Generic reverse rule for `my_scale(a, X)` built directly from `destructure` and
# `Restructure`.
#
# The pullback converts the incoming cotangent `Ȳ` with the pullback of `Restructure(Y)`,
# computes `dot(Ȳ_dense, X)` for `a` and `Ȳ_dense * a` for `X`, and passes the latter
# through the pullback of `destructure` before returning it.
function ChainRulesCore.rrule(
    config::RuleConfig, ::typeof(my_scale), a::Real, X::AbstractArray,
)
    _, pb_destructure_X = rrule_via_ad(config, destructure, X)
    Y = my_scale(a, X)
    _, pb_restructure_Y = rrule_via_ad(config, Restructure(Y), collect(Y))

    function pullback_my_scale(Ȳ)
        _, Ȳ_dense = pb_restructure_Y(Ȳ)
        ā = dot(Ȳ_dense, X)
        _, X̄ = pb_destructure_X(Ȳ_dense * a)
        return NoTangent(), ā, X̄
    end

    return Y, pullback_my_scale
end

# Verify correctness.
a = randn()
V = randn(2, 2)
Expand All @@ -261,58 +245,3 @@ da_fd, dV_fd, dα_fd = FiniteDifferences.j′vp(central_fdm(5, 1), foo_my_scale,
test_approx(da, da_fd)
test_approx(dV, dV_fd)
test_approx(dα, dα_fd)





# Utility functionality.

# This will often make life really easy. Just requires that pullback_of_restructure is
# defined for C, and pullback_of_destructure for A and B. Could be generalised to make
# different assumptions (e.g. some arguments don't require destructuring, output doesn't
# require restructuring, etc). Would need to be generalised to arbitrary numbers of
# arguments (clearly doable -- at worst requires a generated function).
# Wrap `natural_pullback` (a pullback for a two-argument function with output `C` and
# inputs `A`, `B`) between the restructure pullback of `C` and the destructure pullbacks
# of `A` and `B`.
#
# The three enclosing pullbacks are built once, up front. The returned closure refers only
# to them and to `natural_pullback`, not to `C`, `A` or `B` directly.
function wrap_natural_pullback(natural_pullback, C, A, B)
    pb_destructure_A = pullback_of_destructure(A)
    pb_destructure_B = pullback_of_destructure(B)
    pb_restructure_C = pullback_of_restructure(C)

    # Incoming cotangent -> natural cotangent -> natural pullback -> per-argument
    # destructure pullbacks.
    function generic_pullback(C̄)
        _, natural_C̄ = pb_restructure_C(C̄)
        f̄, natural_Ā, natural_B̄ = natural_pullback(natural_C̄)
        _, Ā = pb_destructure_A(natural_Ā)
        _, B̄ = pb_destructure_B(natural_B̄)
        return f̄, Ā, B̄
    end

    return generic_pullback
end

# Sketch of rrule for my_mul making use of utility functionality.
# Sketch of the `my_mul` rule written with `wrap_natural_pullback`.
#
# The primal is `A * B`. The natural pullback maps a cotangent `C̄_natural` to
# `C̄_natural * B'` for `A` and `A' * C̄_natural` for `B`.
function rrule(::typeof(my_mul), A::AbstractMatrix, B::AbstractMatrix)
    C = A * B

    function my_mul_natural_pullback(C̄_natural)
        return NoTangent(), C̄_natural * B', A' * C̄_natural
    end

    wrapped_pullback = wrap_natural_pullback(my_mul_natural_pullback, C, A, B)
    return C, wrapped_pullback
end



# Order in which to present stuff.
# 1. Fully worked-through example (matrix-matrix) multiplication:
# a. Most stupid implementation.
# b. Optimal manual implementation.
# c. Optimal implementation using utility functionality.