From 8ee80d104cc2fbb52dfb1cbdfa38475e7a06f8f8 Mon Sep 17 00:00:00 2001 From: Kristoffer Carlsson Date: Sat, 18 Mar 2017 21:00:51 +0100 Subject: [PATCH 1/2] fix wrong partials multiplied in FMA --- src/dual.jl | 2 +- test/DualTest.jl | 28 +++++++++++++++------------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/dual.jl b/src/dual.jl index dfed4927..213bd4c1 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -420,7 +420,7 @@ end @inline function Base.fma(x::Dual, y::Dual, z::Real) vx, vy = value(x), value(y) result = fma(vx, vy, z) - return Dual(result, _mul_partials(partials(x), partials(y), vx, vy)) + return Dual(result, _mul_partials(partials(x), partials(y), vy, vx)) end @inline function Base.fma(x::Dual, y::Real, z::Dual) diff --git a/test/DualTest.jl b/test/DualTest.jl index ef96341b..a80bdcca 100644 --- a/test/DualTest.jl +++ b/test/DualTest.jl @@ -15,6 +15,8 @@ samerng() = MersenneTwister(1) # exponent by one intrand(T) = T == Int ? rand(2:10) : rand(T) +dualapprox(A, B) = value(A) ≈ value(B) && partials(A) ≈ partials(B) + # fix testing issue with Base.hypot(::Int...) undefined in 0.4 if v"0.4" <= VERSION < v"0.5" Base.hypot(x::Int, y::Int) = Base.hypot(Float64(x), Float64(y)) @@ -387,20 +389,20 @@ for N in (0,3), M in (0,4), T in (Int, Float32) @test partials(NaNMath.pow(Dual(-2.0, 1.0), Dual(2.0, 0.0)), 1) == -4.0 - @test fma(FDNUM, FDNUM2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), + @test dualapprox(fma(FDNUM, FDNUM2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), PRIMAL*PARTIALS2 + PRIMAL2*PARTIALS + - PARTIALS3) - @test fma(FDNUM, FDNUM2, PRIMAL3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), - PRIMAL*PARTIALS2 + PRIMAL2*PARTIALS) - @test fma(PRIMAL, FDNUM2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), - PRIMAL*PARTIALS2 + PARTIALS3) - @test fma(PRIMAL, FDNUM2, PRIMAL3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), - PRIMAL*PARTIALS2) - @test fma(FDNUM, PRIMAL2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), - PRIMAL2*PARTIALS + PARTIALS3) - @test fma(FDNUM, PRIMAL2, PRIMAL3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), - PRIMAL2*PARTIALS) - @test fma(PRIMAL, PRIMAL2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), PARTIALS3) + PARTIALS3)) + @test dualapprox(fma(FDNUM, FDNUM2, PRIMAL3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), + PRIMAL*PARTIALS2 + PRIMAL2*PARTIALS)) + @test dualapprox(fma(PRIMAL, FDNUM2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), + PRIMAL*PARTIALS2 + PARTIALS3)) + @test dualapprox(fma(PRIMAL, FDNUM2, PRIMAL3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), + PRIMAL*PARTIALS2)) + @test dualapprox(fma(FDNUM, PRIMAL2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), + PRIMAL2*PARTIALS + PARTIALS3)) + @test dualapprox(fma(FDNUM, PRIMAL2, PRIMAL3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), + PRIMAL2*PARTIALS)) + @test dualapprox(fma(PRIMAL, PRIMAL2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), PARTIALS3)) # Unary Functions # #-----------------# From 2a6d95aa43c3ebebbddaf588cf34403ba1bd9781 Mon Sep 17 00:00:00 2001 From: Kristoffer Carlsson Date: Sat, 18 Mar 2017 21:16:33 +0100 Subject: [PATCH 2/2] use more fma --- src/dual.jl | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/dual.jl b/src/dual.jl index 213bd4c1..8610bd0f 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -410,11 +410,13 @@ end @ambiguous @inline Base.atan2(y::Real, x::Dual) = calc_atan2(y, x) @ambiguous @inline Base.atan2(y::Dual, x::Real) = calc_atan2(y, x) -@inline function Base.fma(x::Dual, y::Dual, z::Dual) - vx, vy = value(x), value(y) - result = fma(vx, vy, value(z)) - return Dual(result, - _mul_partials(partials(x), partials(y), vx, vy) + partials(z)) +@generated function Base.fma{N}(x::Dual{N}, y::Dual{N}, z::Dual{N}) + ex = Expr(:tuple, [:(fma(value(x), partials(y)[$i], fma(value(y), partials(x)[$i], partials(z)[$i]))) for i in 1:N]...) + return quote + $(Expr(:meta, :inline)) + v = fma(value(x), value(y), value(z)) + Dual(v, $ex) + end end @inline function Base.fma(x::Dual, y::Dual, z::Real) @@ -423,10 +425,13 @@ end return Dual(result, _mul_partials(partials(x), partials(y), vy, vx)) end -@inline function Base.fma(x::Dual, y::Real, z::Dual) - vx = value(x) - result = fma(vx, y, value(z)) - return Dual(result, partials(x) * y + partials(z)) +@generated function Base.fma{N}(x::Dual{N}, y::Real, z::Dual{N}) + ex = Expr(:tuple, [:(fma(partials(x)[$i], y, partials(z)[$i])) for i in 1:N]...) + return quote + $(Expr(:meta, :inline)) + v = fma(value(x), y, value(z)) + Dual(v, $ex) + end end @inline Base.fma(x::Real, y::Dual, z::Dual) = fma(y, x, z)