diff --git a/src/dual.jl b/src/dual.jl index 0cfb7182..8610bd0f 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -410,11 +410,13 @@ end @ambiguous @inline Base.atan2(y::Real, x::Dual) = calc_atan2(y, x) @ambiguous @inline Base.atan2(y::Dual, x::Real) = calc_atan2(y, x) -@inline function Base.fma(x::Dual, y::Dual, z::Dual) - vx, vy = value(x), value(y) - result = fma(vx, vy, value(z)) - return Dual(result, - _mul_partials(partials(x), partials(y), vy, vx) + partials(z)) +@generated function Base.fma{N}(x::Dual{N}, y::Dual{N}, z::Dual{N}) + ex = Expr(:tuple, [:(fma(value(x), partials(y)[$i], fma(value(y), partials(x)[$i], partials(z)[$i]))) for i in 1:N]...) + return quote + $(Expr(:meta, :inline)) + v = fma(value(x), value(y), value(z)) + Dual(v, $ex) + end end @inline function Base.fma(x::Dual, y::Dual, z::Real) @@ -423,10 +425,13 @@ end return Dual(result, _mul_partials(partials(x), partials(y), vy, vx)) end -@inline function Base.fma(x::Dual, y::Real, z::Dual) - vx = value(x) - result = fma(vx, y, value(z)) - return Dual(result, partials(x) * y + partials(z)) +@generated function Base.fma{N}(x::Dual{N}, y::Real, z::Dual{N}) + ex = Expr(:tuple, [:(fma(partials(x)[$i], y, partials(z)[$i])) for i in 1:N]...) + return quote + $(Expr(:meta, :inline)) + v = fma(value(x), y, value(z)) + Dual(v, $ex) + end end @inline Base.fma(x::Real, y::Dual, z::Dual) = fma(y, x, z)