What kind of problems is it mostly used for? Please describe.
Sometimes the explicit form of a neural network's gradient is required.
A specific but good example is optimal experimental design for universal differential equations (UDEs), where explicit sensitivity equations are needed to augment the system (ideally in functional form). Right now this is hardly possible, or only for very small networks, mostly due to problems compiling the resulting gradient expressions (at least on my machine, an M3 with 18 GB RAM).
I think this repository would be a good place for this :). Otherwise, I would move it into a separate repository.
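For reference, these are the standard forward sensitivity equations such an augmentation would add, written for a state $u$ with $\dot u = f(u, p, t)$ and parameters $p$ (generic notation assumed here, not taken from the preprint):

```math
\dot u = f(u, p, t), \qquad S(t) = \frac{\partial u(t)}{\partial p}, \qquad
\dot S = \frac{\partial f}{\partial u}\, S + \frac{\partial f}{\partial p}, \qquad
S(t_0) = \frac{\partial u(t_0)}{\partial p}.
```

If a neural network $NN(u, \theta)$ appears inside $f$, then $\partial f / \partial u$ and $\partial f / \partial \theta$ contain the network's Jacobians with respect to its input and its weights, which is why those are needed in explicit form.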
Describe the algorithm you’d like
Instead of deriving the gradient of the full chain symbolically in a single sweep, a simple `Chain` consisting of `Dense` layers can easily be augmented layer by layer: the Jacobians of each layer with respect to its input and its parameters are derived symbolically and combined via the chain rule. In a sense, this provides the forward sensitivity of the full chain in a structured manner.
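Written out for a chain of dense layers $y_k = \sigma(W_k y_{k-1} + b_k)$ with $y_0 = u$ and per-layer parameters $\theta_k = (\operatorname{vec} W_k, b_k)$ (notation assumed here, with column-major $\operatorname{vec}$), the layer-wise propagation would look roughly like this, starting from $J^u_0 = I$ and an empty $J^\theta_0$:

```math
\begin{aligned}
z_k &= W_k y_{k-1} + b_k, \qquad D_k = \operatorname{diag}\big(\sigma'(z_k)\big),\\
\frac{\partial y_k}{\partial y_{k-1}} &= D_k W_k, \qquad
\frac{\partial y_k}{\partial \theta_k} = D_k \begin{bmatrix} y_{k-1}^\top \otimes I & I \end{bmatrix},\\
J^u_k &= \frac{\partial y_k}{\partial y_{k-1}}\, J^u_{k-1}, \qquad
J^\theta_k = \begin{bmatrix} \dfrac{\partial y_k}{\partial y_{k-1}}\, J^\theta_{k-1} & \dfrac{\partial y_k}{\partial \theta_k} \end{bmatrix}.
\end{aligned}
```

Each symbolic derivation then only ever involves a single layer, while the products between layers are evaluated numerically.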
An MWE I have been cooking up:
```julia
using Lux
using Random
using ComponentArrays
using Symbolics
using Symbolics.SymbolicUtils.Code
using Symbolics.SymbolicUtils

# Just register all the activations from NNlib and their gradients here.
∇swish(x::T) where {T} = (1 + exp(-x) + x * exp(-x)) / (1 + exp(-x))^2
myswish(x::T) where {T} = x / (1 + exp(-x))

# Register `myswish` so it stays an opaque call in the symbolic expressions
# and its derivative is taken from the rule below.
Symbolics.@register_symbolic myswish(x)
Symbolics.derivative(::typeof(myswish), (x,)::Tuple, ::Base.Val{1}) = ∇swish(x)

model = Lux.Chain(
    Dense(7, 10, myswish),
    Dense(10, 10, myswish),
    Dense(10, 3, myswish)
)
p0, st = Lux.setup(Random.default_rng(), model)
p0 = ComponentArray(p0)

using MacroTools

# Maybe this is not needed; I've added it to speed up the compilation.
# Builds the expression of a function `(x, p) -> matrix` from the symbolic matrix `val`,
# replacing the input symbols `us` by `x[i]` and the parameter symbols `ps` by `p[j]`.
function simplify_expression(val, us, ps)
    varmatcher = let psyms = toexpr.(ps), usyms = toexpr.(us)
        x -> begin
            if x ∈ usyms
                id = findfirst(==(x), usyms)
                return :(getindex(x, $(id)))
            end
            if x ∈ psyms
                id = findfirst(==(x), psyms)
                return :(getindex(p, $(id)))
            end
            return x
        end
    end
    returns = [gensym() for i in eachindex(val)]
    body = Expr[]
    # Assign every entry of `val` (column-major) to its own local variable
    for i in eachindex(returns)
        subex = toexpr(val[i])
        subex = MacroTools.postwalk(varmatcher, subex)
        push!(body, :($(returns[i]) = $(subex)))
    end
    # Return the right shape
    push!(body, :(return reshape([$(returns...)], $(size(val)))::AbstractMatrix{promote_type(T, P)}))
    # Make the right signature
    return :(function (x::AbstractVector{T}, p::AbstractVector{P}) where {T, P}
        $(body...)
    end)
end

# Lux ≥ 1.0 renamed the abstract type to `Lux.AbstractLuxLayer`
struct GradientLayer{L, DU, DP} <: Lux.AbstractExplicitLayer
    layer::L  # the wrapped Dense layer
    du::DU    # compiled (u, p) -> ∂y/∂u
    dp::DP    # compiled (u, p) -> ∂y/∂p
end

function GradientLayer(layer::Lux.Dense)
    (; in_dims, out_dims, activation) = layer
    p = Lux.LuxCore.initialparameters(Random.default_rng(), layer)
    # Make symbolic parameters: weight (column-major) first, then bias
    nparams = sum(prod ∘ size, p)
    ps = Symbolics.variables(gensym(), Base.OneTo(nparams))
    us = Symbolics.variables(gensym(), Base.OneTo(in_dims))
    W = reshape(ps[1:in_dims*out_dims], out_dims, in_dims)
    b = ps[(in_dims*out_dims+1):end]
    ex = activation.(W * us + b)
    dfdu = Symbolics.jacobian(ex, us)
    dfdp = Symbolics.jacobian(ex, ps)
    # Build the gradient w.r.t. the input and the parameters
    dfduex = simplify_expression(dfdu, us, ps)
    dfdpex = simplify_expression(dfdp, us, ps)
    # Build the functions
    du = eval(dfduex)
    dp = eval(dfdpex)
    return GradientLayer{typeof(layer), typeof(du), typeof(dp)}(layer, du, dp)
end

Lux.LuxCore.initialparameters(rng::Random.AbstractRNG, layer::GradientLayer) =
    Lux.LuxCore.initialparameters(rng, layer.layer)
Lux.LuxCore.initialstates(rng::Random.AbstractRNG, layer::GradientLayer) =
    Lux.LuxCore.initialstates(rng, layer.layer)

# First layer: start the propagation of (y, dy/du, dy/dp)
function (glayer::GradientLayer)(u::AbstractArray, ps, st::NamedTuple)
    # Flatten in the same order as the symbolic parameters: vec(weight), then bias
    pvec = vcat(vec(ps.weight), vec(ps.bias))
    ∇u_i = glayer.du(u, pvec)
    ∇p_i = glayer.dp(u, pvec)
    next, st = glayer.layer(u, ps, st)
    return (next, ∇u_i, ∇p_i), st
end

# Subsequent layers: apply the chain rule to the incoming sensitivities
function (glayer::GradientLayer)((u, du, dp)::Tuple, ps, st::NamedTuple)
    pvec = vcat(vec(ps.weight), vec(ps.bias))
    ∇u_i = glayer.du(u, pvec)
    ∇p_i = glayer.dp(u, pvec)
    next, st = glayer.layer(u, ps, st)
    # Note: This currently assumes a sequential chain. More advanced layers would probably need a dispatch.
    return (next, ∇u_i * du, hcat(∇u_i * dp, ∇p_i)), st
end

function symbolify(chain::Lux.Chain, p, st, name)
    new_layers = map(GradientLayer, chain.layers)
    new_chain = Lux.Chain(new_layers...; name)
    return (new_chain, Lux.setup(Random.default_rng(), new_chain)...)
end

function symbolify(layer::Lux.Dense, p, st, name)
    new_layer = GradientLayer(layer)
    return (new_layer, Lux.setup(Random.default_rng(), new_layer)...)
end

# In general, we can just `symbolify` a Chain or add a `GradientChain` constructor here.
newmodel, newps, newst = Lux.Experimental.layer_map(symbolify, model, p0, st);

u0 = rand(7)
# Returns a triplet (y, dy/du, dy/dp)
rets, _ = newmodel(u0, newps, newst);

using InteractiveUtils  # provides @code_warntype outside the REPL
@code_warntype newmodel(u0, newps, newst)

# Reference Jacobians from Zygote, using the original model with the same parameters
using Zygote
du_zyg = only(Zygote.jacobian(u -> first(model(u, newps, newst)), u0))
dp_zyg = only(Zygote.jacobian(p -> first(model(u0, p, newst)), newps))
rets[2] ≈ du_zyg, rets[3] ≈ dp_zyg
```
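As a dependency-free cross-check of the same layer-wise chain rule, here is a minimal sketch in plain Julia that uses hand-derived Jacobians instead of Symbolics. It assumes dense layers with a swish activation and the same parameter ordering as the MWE above (column-major weights, then bias, layer by layer). `DenseSketch`, `layer_jacobians`, and `chain_sensitivities` are hypothetical names for illustration, not part of Lux.

```julia
using LinearAlgebra
using Random

# Swish and its derivative (same formulas as `myswish` / `∇swish` in the MWE)
swish(x) = x / (1 + exp(-x))
dswish(x) = (1 + exp(-x) + x * exp(-x)) / (1 + exp(-x))^2

# Hypothetical stand-in for a Lux `Dense` layer: y = swish.(W * u .+ b)
struct DenseSketch
    W::Matrix{Float64}
    b::Vector{Float64}
end

# Closed-form Jacobians of one layer, assuming θ = [vec(W); b] (column-major W):
#   ∂y/∂u = Diagonal(σ'(z)) * W
#   ∂y/∂θ = Diagonal(σ'(z)) * [kron(u', I) I]
function layer_jacobians(l::DenseSketch, u::AbstractVector)
    z = l.W * u .+ l.b
    D = Diagonal(dswish.(z))
    Id = Matrix{Float64}(I, length(l.b), length(l.b))
    Ju = D * l.W
    Jθ = D * hcat(kron(reshape(u, 1, :), Id), Id)
    return swish.(z), Ju, Jθ
end

# Propagates (y, ∂y/∂u0, ∂y/∂θ) through a sequential chain; θ is ordered layer by layer
function chain_sensitivities(layers, u0)
    y, Ju, Jθ = layer_jacobians(first(layers), u0)
    for l in layers[2:end]
        y, Ju_i, Jθ_i = layer_jacobians(l, y)
        Jθ = hcat(Ju_i * Jθ, Jθ_i)  # earlier layers' parameters first
        Ju = Ju_i * Ju
    end
    return y, Ju, Jθ
end

# Usage: a 3 -> 4 -> 2 chain with random weights
rng = MersenneTwister(42)
layers = [DenseSketch(randn(rng, 4, 3), randn(rng, 4)),
          DenseSketch(randn(rng, 2, 4), randn(rng, 2))]
u0 = randn(rng, 3)
y, Ju, Jθ = chain_sensitivities(layers, u0)
```

The update `hcat(Ju_i * Jθ, Jθ_i)` is the numeric counterpart of `hcat(∇u_i * dp, ∇p_i)` in the MWE, so comparing both on the same weights (after flattening the Lux parameters in the same order) could help catch ordering mistakes.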
Other implementations to know about
I don't know of any.
Just for completeness, the preprint on the optimal experimental design.