Symbolic Gradients of Neural Networks #44

Open
AlCap23 opened this issue Dec 18, 2024 · 1 comment

AlCap23 commented Dec 18, 2024

What kind of problems is it mostly used for? Please describe.

Sometimes the explicit form of a neural network's gradient is required.

A rather specific but good example is the optimal experimental design of UDEs (universal differential equations), where explicit sensitivity equations are needed to augment the system (ideally in functional form). Right now this is hardly possible (or only for very small networks), mostly due to issues compiling the resulting gradient expressions (at least on my machine: M3, 18 GB RAM).
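To make "augmenting the system with sensitivity equations" concrete: for $\dot u = f(u, p)$, the sensitivity $s = \partial u / \partial p$ obeys $\dot s = (\partial f/\partial u)\, s + \partial f/\partial p$, so a network's Jacobians with respect to its input and its parameters enter the augmented right-hand side directly. The sketch below is only a toy, assuming a scalar `rhs(u, p) = -p*u` in place of a neural network and fixed-step explicit Euler in place of a real ODE solver; all names are hypothetical.

```julia
# Toy right-hand side du/dt = rhs(u, p) and its hand-written partial derivatives.
# In a UDE, rhs would contain a neural network, and these two partials would be
# the network's Jacobians w.r.t. its input and its parameters.
rhs(u, p) = -p * u
rhs_du(u, p) = -p
rhs_dp(u, p) = -u

# Explicit Euler on the augmented state (u, s) with s = ∂u/∂p,
# ds/dt = ∂f/∂u * s + ∂f/∂p and s(0) = 0 (u0 does not depend on p)
function euler_with_sensitivity(u0, p, h, nsteps)
    u, s = u0, zero(u0)
    for _ in 1:nsteps
        # Both updates use the state from the previous step
        u, s = u + h * rhs(u, p), s + h * (rhs_du(u, p) * s + rhs_dp(u, p))
    end
    return u, s
end

# Plain Euler without sensitivities, used as a finite-difference reference
function euler(u0, p, h, nsteps)
    u = u0
    for _ in 1:nsteps
        u += h * rhs(u, p)
    end
    return u
end
```

Because the augmented Euler step is exactly the derivative of the plain Euler step with respect to `p`, the returned `s` should agree with a central finite difference of `euler` up to rounding.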

I think this repository would be a good place for it :). Otherwise, I would move it into a separate repository.

Describe the algorithm you’d like

Instead of deriving the gradient of the full chain in a single sweep, a simple Chain of Dense layers can be augmented symbolically, layer by layer, using the chain rule: each layer only needs the Jacobians of its output with respect to its input and its own parameters. In a sense, this provides the forward sensitivity of the full chain in a structured manner.
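Written out in notation introduced here ($y_0 = u$; $p_i = [\operatorname{vec}(W_i); b_i]$ are the parameters of layer $i$; $J^u_i$ and $J^p_i$ are the Jacobians of layer $i$ with respect to its input and its own parameters, evaluated at $y_{i-1}$), the recursion for a chain of Dense layers is:

```math
\begin{aligned}
y_i &= \sigma(W_i\, y_{i-1} + b_i), \qquad y_0 = u,\\
\frac{\partial y_i}{\partial u} &= J^u_i\, \frac{\partial y_{i-1}}{\partial u}, \qquad \frac{\partial y_0}{\partial u} = I,\\
\frac{\partial y_i}{\partial (p_1, \dots, p_i)} &= \Big[\, J^u_i\, \frac{\partial y_{i-1}}{\partial (p_1, \dots, p_{i-1})} \;\; J^p_i \,\Big], \qquad \frac{\partial y_1}{\partial p_1} = J^p_1.
\end{aligned}
```

The last line is what `hcat(∇u_i*dp, ∇p_i)` computes in the MWE.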

A minimal working example (MWE) I have been cooking up:

```julia
using Lux
using Random
using ComponentArrays
using Symbolics
using Symbolics.SymbolicUtils.Code
using Symbolics.SymbolicUtils

# All NNlib activations and their derivatives would be registered here; only swish is shown.
# d/dx [x / (1 + exp(-x))] = (1 + exp(-x) + x*exp(-x)) / (1 + exp(-x))^2
∇swish(x) = (1 + exp(-x) + x * exp(-x)) / (1 + exp(-x))^2
myswish(x) = x / (1 + exp(-x))

@register_symbolic myswish(x::Real)::Real
@register_symbolic ∇swish(x::Real)::Real

# Tell Symbolics that the derivative of `myswish` w.r.t. its first argument is `∇swish`
Symbolics.derivative(::typeof(myswish), (x,)::Tuple, ::Base.Val{1}) = ∇swish(x)

model = Lux.Chain(
    Dense(7, 10, myswish),
    Dense(10, 10, myswish),
    Dense(10, 3, myswish)
)

p0, st = Lux.setup(Random.default_rng(), model)
p0 = ComponentArray(p0)
```
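Since the derivative of `myswish` is written by hand, a cheap sanity check is to compare it against a central finite difference. This plain-Julia sketch repeats the two definitions without the Symbolics registration; `central_diff` is a hypothetical helper.

```julia
# Same definitions as in the MWE, without the Symbolics registration
myswish(x) = x / (1 + exp(-x))
∇swish(x) = (1 + exp(-x) + x * exp(-x)) / (1 + exp(-x))^2

# Hypothetical helper: central finite difference as an independent reference
central_diff(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)

# Largest deviation between the hand-written derivative and the finite difference on [-5, 5]
max_err = maximum(x -> abs(∇swish(x) - central_diff(myswish, x)), range(-5, 5; length = 101))
```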

```julia
using MacroTools

# Maybe this is not needed; I've added it to speed up compilation.
# Turns a symbolic matrix `val` into the expression of a function `(x, p) -> matrix`,
# replacing each input symbol by `x[i]` and each parameter symbol by `p[j]`.
function simplify_expression(val, us, ps)
    varmatcher = let psyms = toexpr.(ps), usyms = toexpr.(us)
        (x) -> begin
            if x ∈ usyms
                id = findfirst(==(x), usyms)
                return :(getindex(x, $(id)))
            end
            if x ∈ psyms
                id = findfirst(==(x), psyms)
                return :(getindex(p, $(id)))
            end
            return x
        end
    end

    returns = [gensym() for i in eachindex(val)]
    body = Expr[]
    # Substitute inputs/parameters in every entry and bind it to a fresh variable
    for i in eachindex(returns)
        subex = toexpr(val[i])
        subex = MacroTools.postwalk(varmatcher, subex)
        push!(body, :($(returns[i]) = $(subex)))
    end
    # Return the entries reshaped to the size of `val`
    push!(
        body,
        :(return reshape([$(returns...)], $(size(val)))::AbstractMatrix{promote_type(T, P)})
    )
    # Wrap everything in an anonymous function with the right signature
    :(function (x::AbstractVector{T}, p::AbstractVector{P}) where {T, P}
        $(body...)
    end)
end
```
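The core of `simplify_expression` is a postwalk that swaps every symbolic variable for an indexed access, so the generated function takes plain vectors `x` and `p`. The plain-Julia sketch below isolates that step on a hand-written `Expr` (no Symbolics or MacroTools); `walk` and `index_variables` are hypothetical helpers, with `walk` standing in for `MacroTools.postwalk`.

```julia
# Rebuild `ex` bottom-up, applying `f` to every node
# (a minimal stand-in for MacroTools.postwalk)
walk(f, ex::Expr) = f(Expr(ex.head, map(a -> walk(f, a), ex.args)...))
walk(f, ex) = f(ex)

# Replace input symbols by x[i] and parameter symbols by p[j]
function index_variables(ex, usyms::Vector{Symbol}, psyms::Vector{Symbol})
    return walk(ex) do node
        node isa Symbol || return node
        i = findfirst(==(node), usyms)
        i === nothing || return :(x[$i])
        j = findfirst(==(node), psyms)
        j === nothing || return :(p[$j])
        return node
    end
end

# A hand-written stand-in for one entry of a symbolic Jacobian
entry = :(tanh(w1 * u1 + w2 * u2 + b))
body = index_variables(entry, [:u1, :u2], [:w1, :w2, :b])
g = eval(:((x, p) -> $body))
```

Function names such as `tanh` and `+` pass through unchanged because they match neither list.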

```julia
# Wraps a Dense layer together with compiled functions for its Jacobians
# w.r.t. the input (`du`) and the parameters (`dp`).
struct GradientLayer{L, DU, DP} <: Lux.AbstractExplicitLayer
    layer::L
    du::DU
    dp::DP
end

function GradientLayer(layer::Lux.Dense)
    (; in_dims, out_dims, activation) = layer
    p = Lux.LuxCore.initialparameters(Random.default_rng(), layer)
    # Make symbolic parameters: weight entries (column-major), then bias
    nparams = sum(prod ∘ size, p)
    ps = Symbolics.variables(gensym(), Base.OneTo(nparams))
    us = Symbolics.variables(gensym(), Base.OneTo(in_dims))
    W = reshape(ps[1:in_dims*out_dims], out_dims, in_dims)
    b = ps[(in_dims*out_dims+1):end]
    ex = activation.(W*us+b)
    dfdu = Symbolics.jacobian(ex, us)
    dfdp = Symbolics.jacobian(ex, ps)
    # Build the Jacobian expressions w.r.t. the input and the parameters
    dfduex = simplify_expression(dfdu, us, ps)
    dfdpex = simplify_expression(dfdp, us, ps)
    # Compile them into functions
    dfdu = eval(dfduex)
    dfdp = eval(dfdpex)
    return GradientLayer{typeof(layer), typeof(dfdu), typeof(dfdp)}(
        layer, dfdu, dfdp
    )
end

Lux.LuxCore.initialparameters(rng::Random.AbstractRNG, layer::GradientLayer) = Lux.LuxCore.initialparameters(rng, layer.layer)
Lux.LuxCore.initialstates(rng::Random.AbstractRNG, layer::GradientLayer) = Lux.LuxCore.initialstates(rng, layer.layer)

# First layer: returns the triplet (y, ∂y/∂u, ∂y/∂p)
function (glayer::GradientLayer)(u::AbstractArray, ps, st::NamedTuple)
    # Flatten in the same order as the symbolic `ps` (vec(weight), then bias);
    # `reduce(vcat, ps)` fails for a (matrix, vector) NamedTuple
    pvec = vcat(vec(ps.weight), vec(ps.bias))
    ∇u_i = glayer.du(u, pvec)
    ∇p_i = glayer.dp(u, pvec)
    next, st = glayer.layer(u, ps, st)
    return (next, ∇u_i, ∇p_i), st
end

# Subsequent layers: propagate the sensitivities with the chain rule
function (glayer::GradientLayer)((u, du, dp)::Tuple, ps, st::NamedTuple)
    pvec = vcat(vec(ps.weight), vec(ps.bias))
    ∇u_i = glayer.du(u, pvec)
    ∇p_i = glayer.dp(u, pvec)
    next, st = glayer.layer(u, ps, st)
    # Note: this currently assumes a sequential chain. More advanced layers would probably need their own dispatch.
    return (next, ∇u_i * du, hcat(∇u_i*dp, ∇p_i)), st
end
```
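The chain-rule propagation in the two `GradientLayer` call methods does not itself depend on Symbolics. As a plain-Julia sketch, assuming a hypothetical `ToyDense` layer with `tanh` and hand-written Jacobians instead of the generated ones, the same `(y, ∂y/∂u, ∂y/∂p)` triplet can be propagated and checked against finite differences:

```julia
using LinearAlgebra

# Hypothetical toy Dense layer y = tanh.(W * x .+ b). Its parameters are ordered
# as [vec(W); b] (column-major), matching the ordering of the symbolic `ps`.
struct ToyDense
    W::Matrix{Float64}
    b::Vector{Float64}
end

(l::ToyDense)(x) = tanh.(l.W * x .+ l.b)

# Output of the layer plus its Jacobians w.r.t. the input and its own parameters
function layer_jacobians(l::ToyDense, x)
    y = l(x)
    s = 1 .- y .^ 2                     # tanh'(z) = 1 - tanh(z)^2
    nout, nin = size(l.W)
    Ju = s .* l.W                       # ∂y/∂x = Diagonal(s) * W
    JW = zeros(nout, nout * nin)        # ∂y/∂vec(W)
    for j in 1:nin, i in 1:nout
        # y[i] depends only on row i of W; W[i, j] sits at (j - 1) * nout + i in vec(W)
        JW[i, (j - 1) * nout + i] = s[i] * x[j]
    end
    Jp = hcat(JW, Matrix(Diagonal(s)))  # ∂y/∂[vec(W); b]
    return y, Ju, Jp
end

# Propagate (y, ∂y/∂u, ∂y/∂p) through a sequential chain with the chain rule,
# mirroring the two GradientLayer call methods
function chain_sensitivities(layers, u)
    y, dydu, dydp = layer_jacobians(first(layers), u)
    for l in layers[2:end]
        y, Ju, Jp = layer_jacobians(l, y)
        dydu = Ju * dydu
        dydp = hcat(Ju * dydp, Jp)
    end
    return y, dydu, dydp
end

# Plain forward pass, used as a finite-difference reference
chain_forward(layers, u) = foldl((x, l) -> l(x), layers; init = u)
```

For a 3 → 2 → 1 chain, `dydp` has 2·3 + 2 + 1·2 + 1 = 11 columns, with the first layer's parameters first, exactly as `hcat(∇u_i*dp, ∇p_i)` orders them.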

```julia
# Replace a Chain (or a single Dense layer) by its GradientLayer counterpart
function symbolify(chain::Lux.Chain, p, st, name)
    new_layers = map(GradientLayer, chain.layers)
    new_chain = Lux.Chain(new_layers...; name)
    (new_chain, Lux.setup(Random.default_rng(), new_chain)...)
end

function symbolify(layer::Lux.Dense, p, st, name)
    new_layer = GradientLayer(layer)
    (new_layer, Lux.setup(Random.default_rng(), new_layer)...)
end

# In general, we can just `symbolify` a Chain or add a `GradientChain` constructor here.
newmodel, newps, newst = Lux.Experimental.layer_map(symbolify, model, p0, st);
u0 = rand(7)
ret, _ = newmodel(u0, newps, newst)
using InteractiveUtils  # provides @code_warntype outside the REPL
@code_warntype newmodel(u0, newps, newst)

using Zygote

# Reference Jacobians of the original model, evaluated at the same parameters
du_zyg = only(Zygote.jacobian(u -> first(model(u, newps, newst)), u0))
dp_zyg = only(Zygote.jacobian(p -> first(model(u0, p, newst)), ComponentArray(newps)))

# Returns a triplet (y, dy/du, dy/dp)
rets, _ = newmodel(u0, newps, newst);
# Both comparisons should return true if the layer-wise sensitivities are correct
rets[2] ≈ du_zyg
rets[3] ≈ dp_zyg
```

Other implementations to know about

I don't know of any.

References

Just for completeness, the preprint for the optimal experimental design.

@ChrisRackauckas (Member)

For this to make sense, we'd need someone to implement matrix calculus in Symbolics.jl, which hasn't been done yet.
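For a single Dense layer, the kind of closed form that matrix calculus yields (instead of one scalar expression per Jacobian entry) is, with $z = Wx + b$, $y = \sigma(z)$, $D = \operatorname{diag}(\sigma'(z))$ and the identity $\operatorname{vec}(AXB) = (B^\top \otimes A)\operatorname{vec}(X)$:

```math
\begin{aligned}
\mathrm{d}y &= D\,(\mathrm{d}W\,x + W\,\mathrm{d}x + \mathrm{d}b),\\
\frac{\partial y}{\partial x} &= D\,W, \qquad
\frac{\partial y}{\partial \operatorname{vec}(W)} = x^\top \otimes D, \qquad
\frac{\partial y}{\partial b} = D.
\end{aligned}
```

By contrast, the scalarized Jacobians in the MWE have $n_\text{out} n_\text{in}$ entries for $\partial y / \partial x$ and $n_\text{out}(n_\text{out} n_\text{in} + n_\text{out})$ entries for the parameters (1100 for `Dense(10, 10)`, mostly zeros), and `simplify_expression` emits one assignment per entry.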
