Support functions that splat namedtuples as keyword arguments #1059

Merged: 3 commits, Sep 10, 2021

@oxinabox (Member) commented Sep 2, 2021

@lsindoni ran into this.
The problem occurs for both Dicts and NamedTuples, with an error about the methods of Base.setindex.

The Dict in the pairs pullback in this case ends up containing Symbols rather than Integers, so we branch to support that.
(Unfortunately it is a Dict{Any,Any}, so we can't branch earlier, and it ends up as dynamic dispatch.)

The chain to the rule for pairs is enough to make NamedTuples work.
If we also want to support Dicts, we need the rule for merge as well.

I am surprised no one has run into this before.

g(; kwargs...) = kwargs[:x] * kwargs[:z]
h(somedata) = g(; somedata...)
@test gradient(h, (; x=3.0, y=4.0, z=2.3)) == ((x = 2.3, y = 0.0, z = 3.0),)
@test gradient(h, Dict(:x=>3.0, :y=>4.0, :z=>2.3)) == ((y = 0.0, z = 3.0, x = 2.3),)
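
For illustration, here is a plain-Julia sketch (no Zygote involved) of why both pairs and merge show up on this code path. It assumes that keyword splatting lowers to roughly `merge(NamedTuple(), somedata)` and that `kwargs` inside `g` behaves like `pairs` of the resulting NamedTuple. The names `kw_nt` and `kw_d` are hypothetical and exist only for this example.

```julia
# Plain Julia, no Zygote: a look at what `g(; somedata...)` passes along.
g(; kwargs...) = kwargs[:x] * kwargs[:z]

nt = (; x=3.0, y=4.0, z=2.3)
d = Dict(:x => 3.0, :y => 4.0, :z => 2.3)

# Assumption: splatting builds the keyword NamedTuple with `merge`.
kw_nt = merge(NamedTuple(), nt)
kw_d = merge(NamedTuple(), d)   # field order follows the Dict's iteration order

# Merging into an empty NamedTuple gives back the same NamedTuple.
@assert kw_nt === nt
# A Dict with Symbol keys is converted into a NamedTuple by `merge`.
@assert kw_d isa NamedTuple
# `pairs` of a NamedTuple is indexed by Symbol, not by Integer.
@assert pairs(kw_d)[:x] == 3.0
@assert all(k isa Symbol for k in keys(pairs(kw_nt)))
# Both kinds of input reach `g` with the same values.
@assert g(; nt...) == g(; d...) == 3.0 * 2.3
```

Because the Dict is turned into a NamedTuple in iteration order, the field order of a gradient returned for a Dict input can differ from the order in which the keys were written.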
@oxinabox (Member Author)

I am not sure what type Zygote wants to use to represent Dict.
accum isn't defined for Dict AFAICT.
So NamedTuple seemed reasonable.

Member

The Dict story is not consistent at the moment, and many features are missing:
https://github.com/FluxML/Zygote.jl/issues?q=is%3Aissue+is%3Aopen+label%3Adictionary
I guess the returned gradient should generally be a dictionary, but for Dicts with Symbol keys a NamedTuple may be good enough for the time being.

@DhairyaLGandhi (Member) left a comment

Thanks for the follow-up!

if VERSION >= v"1.6"
@adjoint merge(nt::NamedTuple, dict::Dict) = pullback(merge, nt, NamedTuple(dict))
else
@adjoint merge(nt::NamedTuple, dict::Dict) = pullback(merge, nt, (;dict...))
Member

Just keep this?

@oxinabox (Member Author)

What do you mean?
You mean use it for both 1.6 and pre-1.6?

Member

Yeah, exactly.

@oxinabox (Member Author)

ok

@DhairyaLGandhi (Member)

The CUDA failures seem related to CUDA.jl v3.4. We should get some kind of fix into Flux/Zygote or CUDA, since cu seems to have changed behaviour in a couple of places. @vchuravy, would you know what needs to be done here?

@oxinabox (Member Author)

Bump

@oxinabox (Member Author)

CI failures are unrelated.
