Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
removed NaNs, codified behaviour with Ints
  • Loading branch information
jwscook committed Mar 19, 2021
commit 3634f5dabbc2027b529e9af1fe2e57d4c99b9b1b
12 changes: 7 additions & 5 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,14 @@ end

Base.mod(z::Dual, n::Number) = Dual(mod(value(z), n), epsilon(z))

# these two definitions are needed to fix ambiguity warnings
Base.:^(z::Dual, n::Unsigned) = z^Signed(n)
Base.:^(z::Dual, n::Integer) = Dual(value(z)^n, epsilon(z)*n*value(z)^(n-1))
Base.:^(z::Dual, n::Rational) = Dual(value(z)^n, epsilon(z)*n*value(z)^(n-1))
# introduce a boolean !iszero(n) for hard zero behaviour to combat NaNs
pow(z::Dual, n) = Dual(value(z)^n, !iszero(n) * (epsilon(z) * n * value(z)^(n - 1)))
# these last two definitions are needed to fix ambiguity warnings
for T ∈ (:Integer, :Rational, :Number)
@eval Base.:^(z::Dual, n::$T) = pow(z, n)
end


Base.:^(z::Dual, n::Number) = Dual(value(z)^n, epsilon(z)*n*value(z)^(n-1))
NaNMath.pow(z::Dual, n::Number) = Dual(NaNMath.pow(value(z),n), epsilon(z)*n*NaNMath.pow(value(z),n-1))
NaNMath.pow(z::Number, w::Dual) = Dual(NaNMath.pow(z,value(w)), epsilon(w)*NaNMath.pow(z,value(w))*log(z))

Expand Down
23 changes: 23 additions & 0 deletions test/automatic_differentiation_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,29 @@ y = x^3.0
@test value(y) ≈ 2.0^3
@test epsilon(y) ≈ 3.0*2^2

# taking care with divides by zero where there shouldn't be any on paper
for (y, n) ∈ Iterators.product((float(x), Dual(0.0, 1)), (0, 0.0))
z = y^n
@test value(z) == 1
@test !isnan(epsilon(z))
@test epsilon(z) == 0
end

# acting on floats works as expected
for (y, n) ∈ ((float(x), Dual(0.0, 1)), -1:1)
@test float(x)^n == float(x)^float(n)
end

# power_by_squaring error for integers
powwrap(z, n) = Dual(z, 1)^n # needs to be wrapped to make n a literal
@test_throws DomainError powwrap(123, 0)
@test_throws DomainError powwrap(2, -1)
@test_throws DomainError powwrap(123, -1)
# this one doesn't error!
@test powwrap(1, -1) == Dual(1, -1)

@test !isnan(epsilon(Dual(0, 1)^1))

y = Dual(2.0, 1)^UInt64(0)
@test !isnan(epsilon(y))

Expand Down