NotImplementedError when using apply_optimizer #560
The funsor term above, after the application of `apply_optimizer`, becomes:

```python
a = Tensor(
    torch.tensor([True, True, True], dtype=torch.bool),
    (("plate_outer__BOUND_79", Bint[3]),),
    2,
)
b = Tensor(
    torch.tensor(
        [-6.521620571590056, -2.191898006015123, -2.8022213468216526],
        dtype=torch.float64,
    ),  # noqa
    (("plate_outer__BOUND_79", Bint[3]),),
    "real",
)
result = Contraction(
    ops.add, ops.mul, frozenset({Variable("plate_outer__BOUND_79", Bint[3])}), (a, b)
)
```

which is not implemented for a boolean type. Manually converting
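As a minimal sketch outside of funsor, the sum-product reduction that the `Contraction` above represents can be written in plain torch; the assumption here is that the boolean tensor has to be cast to `b`'s dtype before the `ops.mul`/`ops.add` reduction can run:

```python
import torch

# The two operands from the Contraction above.
a = torch.tensor([True, True, True])
b = torch.tensor(
    [-6.521620571590056, -2.191898006015123, -2.8022213468216526],
    dtype=torch.float64,
)

# Sum-product over the "plate_outer__BOUND_79" dimension:
# multiply elementwise, then reduce with addition.
# Cast the bool mask to float first; this is the conversion funsor is missing.
result = (a.to(b.dtype) * b).sum()
```

Since `a` is all `True`, the result is simply the sum of `b`'s entries.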
That's weird and unexpected: we shouldn't be adding or multiplying bool tensors. Can you determine where the bool tensor comes from? Maybe there's some bad low-level math in one of the eager patterns?
@fritzo I forgot to mention that I'm getting this result with a funsor branch where some lines are commented out (https://github.com/pyro-ppl/funsor/compare/normalize-logaddexp). I opened a related issue (#561). Maybe fixing #561 first will also fix this issue.
The bool tensor results from:

```python
import torch

import funsor
import funsor.ops as ops
from funsor import Bint, Real
from funsor.tensor import Tensor
from funsor.terms import eager, lazy

funsor.set_backend("torch")

with lazy:
    value = Tensor(torch.tensor([-1.5, 0.38, -1.02]))["plate"]
    term = (value == value).all().log().exp()

with funsor.interpretations.normalize:
    norm_term = funsor.interpreter.reinterpret(term)

with eager:
    # a.dtype == "real"
    a = funsor.interpreter.reinterpret(term)
    # b.dtype == 2
    b = funsor.interpreter.reinterpret(norm_term)
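The underlying torch behavior that this reproduction relies on can be sketched without funsor: an equality comparison followed by `.all()` yields a boolean tensor, which is where the `Bint[2]` (dtype `2`) term above originates when the normalize interpretation skips the implicit cast:

```python
import torch

value = torch.tensor([-1.5, 0.38, -1.02])

# Elementwise equality produces a bool tensor, and .all() reduces it
# to a bool scalar -- funsor models this as a Bint[2]-typed term.
eq = (value == value).all()
```

Here `eq.dtype` is `torch.bool`, and any downstream op that lacks a bool kernel will fail unless the term is cast back to a real dtype first.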
Good diagnosis @ordabayevy. I wonder if it's enough to add an

```diff
- term = (value == value).all().log().exp()
+ term = ops.astype((value == value).all(), value.dtype).log().exp()
```
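The suggested fix has a direct plain-torch analogue; this is a sketch (not the funsor implementation) showing that casting the bool mask to the value's dtype before `log`/`exp`, as `ops.astype` does in the diff above, keeps the whole pipeline in a real dtype:

```python
import torch

value = torch.tensor([-1.5, 0.38, -1.02])

# Bool scalar from the equality reduction.
mask = (value == value).all()

# Cast to value's dtype before log/exp, mirroring the ops.astype fix.
term = mask.to(value.dtype).log().exp()
```

Since the mask is `True`, the cast gives `1.0`, and `exp(log(1.0))` is `1.0`, so the term stays a real-typed scalar throughout.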
@fritzo currently
In principle it seems fine to add a
Was addressed in #562.
This term comes up in `test_pyroapi_funsor`. Error message: