-
Notifications
You must be signed in to change notification settings - Fork 21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support IndependentConstraint in Distribution._infer_param_domain #432
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fix, Eli!
@fritzo I guess we can add
if isinstance(base_constraint, _IndependentConstraint):
batch_ndims = base_constraint.batch_ndims + batch_ndims
base_constraint = base_constraint.base_constraint
to simplify the while
loop here. What do you think?
Note the constraint names will probably be made public in the future so we should do something like support_name = type(support).__name__.lstrip("_") and reference names like |
funsor/distribution.py
Outdated
@@ -249,23 +249,30 @@ def _infer_param_domain(cls, name, raw_shape): | |||
# `infer_param_domain` methods. | |||
# Because NumPyro and Pyro have the same pattern, we use name check for simplicity. | |||
support_name = type(support).__name__ | |||
|
|||
event_dim = 0 | |||
while support_name == "_IndependentConstraint": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you change this to
support_name = type(support).__name__.lstrip("_")
event_dim = 0
while support_name == "IndependentConstraint":
...
so as to be compatible with all backends? Note in Pyro the class is named IndependentConstraint
; in PyTorch the class will initially be named _IndependentConstraint
; and in a future PyTorch the class will be renamed to IndependentConstraint
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could also duplicate _infer_param_domain
in each backend so that dispatch happens on constraint types rather than names.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
and in a future PyTorch the class will be renamed to IndependentConstraint
Does it apply to other constraints as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, see plans in pytorch/pytorch#50616
Addresses #386. Blocking all other PRs.
This PR updates the logic in
Distribution._infer_param_domain
to handle the new_IndependentConstraint
in NumPyro added in pyro-ppl/numpyro#876, which should fix the CI failures in #430 and other PRs.Tested:
test_distributions_generic::test_generic_sample
forCategoricalLogits