Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
improve testing, remove type instability
  • Loading branch information
jwscook committed Mar 21, 2021
commit a27ec35f3fef52d423e55c9e08abe35678ed6a7b
32 changes: 20 additions & 12 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,14 +245,16 @@ Base.:/(z::Number, w::Dual) = Dual(z/value(w), -z*epsilon(w)/value(w)^2)
Base.:/(z::Dual, x::Number) = Dual(value(z)/x, epsilon(z)/x)

for f in [:(Base.:^), :(NaNMath.pow)]
@eval function ($f)(z::Dual, w::Dual)
if epsilon(w) == 0.0
return $f(z, value(w))
end
@eval function ($f)(z::Dual{T1}, w::Dual{T2}) where {T1, T2}
T = promote_type(T1, T2) # for type stability in ? : statements
val = $f(value(z), value(w))

du = epsilon(z) * value(w) * $f(value(z), value(w) - 1) +
epsilon(w) * $f(value(z), value(w)) * log(value(z))
ezvw = epsilon(z) * value(w) # for using in ? : statement
du1 = iszero(ezvw) ? zero(T) : ezvw * $f(value(z), value(w) - 1)
ew = epsilon(w) # for using in ? : statement
# the float is for type stability because log promotes to floats
du2 = iszero(ew) ? zero(float(T)) : ew * val * log(value(z))
du = du1 + du2

Dual(val, du)
end
Expand All @@ -261,15 +263,21 @@ end
Base.mod(z::Dual, n::Number) = Dual(mod(value(z), n), epsilon(z))

# introduce a boolean !iszero(n) for hard zero behaviour to combat NaNs
pow(z::Dual, n) = Dual(value(z)^n, !iszero(n) * (epsilon(z) * n * value(z)^(n - 1)))
# these last two definitions are needed to fix ambiguity warnings
for T ∈ (:Integer, :Rational, :Number)
@eval Base.:^(z::Dual, n::$T) = pow(z, n)
function pow(z::Dual, n::AbstractFloat)
return Dual(value(z)^n, !iszero(n) * (epsilon(z) * n * value(z)^(n - 1)))
end
function pow(z::Dual{T}, n::Integer) where T
iszero(n) && return Dual(one(T), zero(T)) # avoid DomainError Int^(negative Int)
return Dual(value(z)^n, epsilon(z) * n * value(z)^(n - 1))
end
# these first two definitions are needed to fix ambiguity warnings
for T1 ∈ (:Integer, :Rational, :Number)
@eval Base.:^(z::Dual{T}, n::$T1) where T = pow(z, n)
end


NaNMath.pow(z::Dual, n::Number) = Dual(NaNMath.pow(value(z),n), epsilon(z)*n*NaNMath.pow(value(z),n-1))
NaNMath.pow(z::Number, w::Dual) = Dual(NaNMath.pow(z,value(w)), epsilon(w)*NaNMath.pow(z,value(w))*log(z))
NaNMath.pow(z::Dual{T}, n::Number) where T = Dual(NaNMath.pow(value(z),n), epsilon(z)*n*NaNMath.pow(value(z),n-1))
NaNMath.pow(z::Number, w::Dual{T}) where T = Dual(NaNMath.pow(z,value(w)), epsilon(w)*NaNMath.pow(z,value(w))*log(z))

Base.inv(z::Dual) = dual(inv(value(z)),-epsilon(z)/value(z)^2)

Expand Down
32 changes: 25 additions & 7 deletions test/automatic_differentiation_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,39 @@ end

# acting on floats works as expected
for (y, n) ∈ ((float(x), Dual(0.0, 1)), -1:1)
@test float(x)^n == float(x)^float(n)
@test float(y)^n == float(y)^float(n)
end

@test !isnan(epsilon(Dual(0, 1)^1))
@test Dual(0, 1)^1 == Dual(0, 1)

# power_by_squaring error for integers
powwrap(z, n) = Dual(z, 1)^n # needs to be wrapped to make n a literal
@test_throws DomainError powwrap(123, 0)
# needs to be wrapped to make n a literal
powwrap(z, n, epspart=0) = Dual(z, epspart)^n
@test_throws DomainError powwrap(0, -1)
@test_throws DomainError powwrap(2, -1)
@test_throws DomainError powwrap(123, -1)
# this one doesn't error!
@test powwrap(1, -1) == Dual(1, -1)
@test_throws DomainError powwrap(123, -1) # etc
# these ones don't DomainError
@test powwrap(0, 0, 0) == Dual(1, 0) # special case is handled
@test powwrap(0, 0, 1) == Dual(1, 0) # special case is handled
@test powwrap(1, 0) == Dual(1, 1)
@test powwrap(123, 0) == Dual(1, 1)
for i ∈ -3:3
@test powwrap(1, i) == Dual(1, i)
end

@test !isnan(epsilon(Dual(0, 1)^1))
# this no longer throws 1/0 DomainError
@test powwrap(0, Dual(0, 1)) == Dual(1, 0)
# this never did DomainError because it starts off with a float
@test 0.0^Dual(0, 1) == Dual(1.0, NaN)
# and Dual^Dual uses a log and is now type stable
# because the log promotes ints to floats for all values
@test typeof(value(powwrap(0, Dual(0, 1)))) == Float64
@test Dual(0, 1)^Dual(0, 1) == Dual(1, 0)

y = Dual(2.0, 1)^UInt64(0)
@test !isnan(epsilon(y))
@test epsilon(y) == 0

y = sin(x)+exp(x)
@test value(y) ≈ sin(2)+exp(2)
Expand Down