From 4b3eb129d7c1dd05c57acd65ea3133deada12b7b Mon Sep 17 00:00:00 2001 From: Kristoffer Carlsson Date: Sat, 18 Mar 2017 21:00:51 +0100 Subject: [PATCH 1/2] fix wrong partials multiplied in FMA --- test/DualTest.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/DualTest.jl b/test/DualTest.jl index 53dc610e..1ede4cad 100644 --- a/test/DualTest.jl +++ b/test/DualTest.jl @@ -15,6 +15,8 @@ samerng() = MersenneTwister(1) # exponent by one intrand(T) = T == Int ? rand(2:10) : rand(T) +dualapprox(A, B) = value(A) ≈ value(B) && partials(A) ≈ partials(B) + # fix testing issue with Base.hypot(::Int...) undefined in 0.4 if v"0.4" <= VERSION < v"0.5" Base.hypot(x::Int, y::Int) = Base.hypot(Float64(x), Float64(y)) From edec71b952de207c12ed95f32c237ca839ed78a4 Mon Sep 17 00:00:00 2001 From: Kristoffer Carlsson Date: Sat, 18 Mar 2017 21:16:33 +0100 Subject: [PATCH 2/2] use more fma --- src/dual.jl | 23 ++++++++++++++--------- test/DualTest.jl | 2 -- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/dual.jl b/src/dual.jl index 0cfb7182..8610bd0f 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -410,11 +410,13 @@ end @ambiguous @inline Base.atan2(y::Real, x::Dual) = calc_atan2(y, x) @ambiguous @inline Base.atan2(y::Dual, x::Real) = calc_atan2(y, x) -@inline function Base.fma(x::Dual, y::Dual, z::Dual) - vx, vy = value(x), value(y) - result = fma(vx, vy, value(z)) - return Dual(result, - _mul_partials(partials(x), partials(y), vy, vx) + partials(z)) +@generated function Base.fma{N}(x::Dual{N}, y::Dual{N}, z::Dual{N}) + ex = Expr(:tuple, [:(fma(value(x), partials(y)[$i], fma(value(y), partials(x)[$i], partials(z)[$i]))) for i in 1:N]...) + return quote + $(Expr(:meta, :inline)) + v = fma(value(x), value(y), value(z)) + Dual(v, $ex) + end end @inline function Base.fma(x::Dual, y::Dual, z::Real) @@ -423,10 +425,13 @@ end return Dual(result, _mul_partials(partials(x), partials(y), vy, vx)) end -@inline function Base.fma(x::Dual, y::Real, z::Dual) - vx = value(x) - result = fma(vx, y, value(z)) - return Dual(result, partials(x) * y + partials(z)) +@generated function Base.fma{N}(x::Dual{N}, y::Real, z::Dual{N}) + ex = Expr(:tuple, [:(fma(partials(x)[$i], y, partials(z)[$i])) for i in 1:N]...) + return quote + $(Expr(:meta, :inline)) + v = fma(value(x), y, value(z)) + Dual(v, $ex) + end end @inline Base.fma(x::Real, y::Dual, z::Dual) = fma(y, x, z) diff --git a/test/DualTest.jl b/test/DualTest.jl index 1ede4cad..53dc610e 100644 --- a/test/DualTest.jl +++ b/test/DualTest.jl @@ -15,8 +15,6 @@ samerng() = MersenneTwister(1) # exponent by one intrand(T) = T == Int ? rand(2:10) : rand(T) -dualapprox(A, B) = value(A) ≈ value(B) && partials(A) ≈ partials(B) - # fix testing issue with Base.hypot(::Int...) undefined in 0.4 if v"0.4" <= VERSION < v"0.5" Base.hypot(x::Int, y::Int) = Base.hypot(Float64(x), Float64(y))