Support functions that splat namedtuples as keyword arguments #1059
g(; kwargs...) = kwargs[:x] * kwargs[:z]
h(somedata) = g(; somedata...)
@test gradient(h, (; x=3.0, y=4.0, z=2.3)) == ((x = 2.3, y = 0.0, z = 3.0),)
@test gradient(h, Dict(:x=>3.0, :y=>4.0, :z=>2.3)) == ((y = 0.0, z = 3.0, x = 2.3),)
I am not sure what type Zygote wants to use to represent `Dict`. `accum` isn't defined for `Dict`, AFAICT, so `NamedTuple` seemed reasonable.
The `Dict` story is not consistent at the moment and there are many missing features:
https://github.com/FluxML/Zygote.jl/issues?q=is%3Aissue+is%3Aopen+label%3Adictionary
I guess generally the returned gradient should be a dictionary, but for dicts with symbol keys maybe a namedtuple is good enough for the time being.
Thanks for the follow up!
if VERSION >= v"1.6"
    @adjoint merge(nt::NamedTuple, dict::Dict) = pullback(merge, nt, NamedTuple(dict))
else
    @adjoint merge(nt::NamedTuple, dict::Dict) = pullback(merge, nt, (;dict...))
Just keep this?
What do you mean?
You mean use it for both 1.6 and pre-1.6?
Yeah, exactly.
ok
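For reference, the two branches discussed in this thread should build the same value for a `Dict` with `Symbol` keys. A minimal Base-only sketch, not taken from the PR: the names `dict`, `nt`, `splatted`, `merged_direct`, and `merged_via_nt` are illustrative, and it assumes Julia 1.1 or later so that `merge` accepts an iterable of pairs.

```julia
# Illustrative values only; not taken from the Zygote test suite.
dict = Dict(:y => 4.0, :z => 2.3)
nt = (; x = 3.0)

# Splatting a Symbol-keyed Dict into a NamedTuple literal does not need
# the NamedTuple(dict) constructor.
splatted = (; dict...)

# Merging the NamedTuple with the Dict directly gives the same result as
# merging it with the splatted NamedTuple.
merged_direct = merge(nt, dict)
merged_via_nt = merge(nt, splatted)
@assert merged_direct == merged_via_nt

# The NamedTuple(dict) constructor is only relied on for Julia 1.6+,
# matching the version check in the diff.
if VERSION >= v"1.6"
    @assert NamedTuple(dict) == splatted
end
```

Since `(; dict...)` is valid on both sides of the version check, using it unconditionally avoids the branch.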
The CUDA failures seem related to CUDA.jl v3.4 -- We should get some kind of fix in Flux/Zygote or CUDA since
Bump
CI failures are unrelated.
@lsindoni ran into this.
This problem occurs for `Dict`s and `NamedTuple`s, with an error about the methods of `Base.setindex`.
The `Dict` in the `pairs` pullback in this case ends up containing `Symbol`s, not `Integer`s. So we branch to support that (unfortunately it is a `Dict{Any,Any}`, so we can't branch earlier, and so it is dynamic dispatch).
The chain to the rule for `pairs` is enough to make `NamedTuple`s work. If we also want to support `Dict`s, we need the rule for `merge` as well.
I am surprised no one has run into this before.
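To see why both `pairs` and `merge` come into play, it can help to look at what a keyword-splatting call does in plain Julia, without Zygote. The sketch below is Base-only and reuses `g` and `h` from the tests. `show_kwargs` is a hypothetical helper, and the claim that a `Dict` reaches keyword form through `merge(NamedTuple(), ...)` is my understanding of how Julia lowers keyword splats, which can be checked with `Meta.@lower`. It assumes Julia 1.1 or later.

```julia
g(; kwargs...) = kwargs[:x] * kwargs[:z]
h(somedata) = g(; somedata...)

# Hypothetical helper that just returns the collected keyword arguments.
show_kwargs(; kwargs...) = kwargs

nt = (; x = 3.0, y = 4.0, z = 2.3)
dict = Dict(:x => 3.0, :y => 4.0, :z => 2.3)

# Both containers can be splatted as keyword arguments.
@assert h(nt) == 3.0 * 2.3
@assert h(dict) == 3.0 * 2.3

# Inside the callee, `kwargs` wraps a NamedTuple keyed by Symbols,
# whichever container was splatted at the call site.
kw = show_kwargs(; dict...)
@assert values(kw) isa NamedTuple
@assert all(k -> k isa Symbol, keys(kw))

# A Dict can be turned into that NamedTuple form by `merge` with an
# iterable of pairs.
@assert merge(NamedTuple(), dict) == values(kw)
```

So the keys seen when iterating `kwargs` are `Symbol`s rather than integer indices, which is consistent with the `Symbol`-keyed `Dict` described above.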