
NotImplementedError when using apply_optimizer #560

Closed · ordabayevy opened this issue Sep 30, 2021 · 8 comments
Labels: enhancement (New feature or request)

@ordabayevy (Member) commented:
This term below comes up in test_pyroapi_funsor:

from funsor.cnf import Contraction
from funsor.tensor import Tensor
import torch
import funsor.ops as ops
from funsor import Bint, Real
from funsor.terms import Unary, Binary, Variable, Number, lazy, to_data
from funsor.constant import Constant
from funsor.delta import Delta
import funsor
funsor.set_backend("torch")

with lazy:
    term = Contraction(ops.add, ops.mul,
     frozenset({Variable('plate_inner__BOUND_78', Bint[2]), Variable('x__BOUND_77', Real), Variable('plate_outer__BOUND_79', Bint[3])}),  # noqa
     (Unary(ops.exp,
       Contraction(ops.null, ops.add,
        frozenset(),
        (Delta(
          (('x__BOUND_77',
            (Tensor(
              torch.tensor([-1.5227684840130555, 0.38168390241818895, -1.0276085758313882], dtype=torch.float64),  # noqa
              (('plate_outer__BOUND_79',
                Bint[3],),),
              'real'),
             Number(0.0),),),)),
         Constant(
          (('plate_inner__BOUND_78',
            Bint[2],),),
          Tensor(
           torch.tensor([0.0, 0.0, 0.0], dtype=torch.float64),
           (('plate_outer__BOUND_79',
             Bint[3],),),
           'real')),))),
      Unary(ops.all,
       Binary(ops.eq,
        Tensor(
         torch.tensor([-1.5227684840130555, 0.38168390241818895, -1.0276085758313882], dtype=torch.float64),  # noqa
         (('plate_outer__BOUND_79',
           Bint[3],),),
         'real'),
        Tensor(
         torch.tensor([-1.5227684840130555, 0.38168390241818895, -1.0276085758313882], dtype=torch.float64),  # noqa
         (('plate_outer__BOUND_79',
           Bint[3],),),
         'real'))),
      Tensor(
       torch.tensor([[-3.304130052277938, -0.9255234395261538, -1.5122103473560844], [-3.217490519312117, -1.2663745664889694, -1.2900109994655682]], dtype=torch.float64),  # noqa
       (('plate_inner__BOUND_78',
         Bint[2],),
        ('plate_outer__BOUND_79',
         Bint[3],),),
       'real'),))

x = to_data(funsor.optimizer.apply_optimizer(term))

Error message:

Traceback (most recent call last):
  File "hehe.py", line 54, in <module>
    x = to_data(funsor.optimizer.apply_optimizer(term))
  File "/home/ordabayev/repos/funsor/funsor/optimizer.py", line 169, in apply_optimizer
    return interpreter.reinterpret(expr)
  File "/home/ordabayev/repos/funsor/funsor/interpreter.py", line 255, in reinterpret
    return recursion_reinterpret(x)
  File "/home/ordabayev/repos/funsor/funsor/interpreter.py", line 236, in recursion_reinterpret
    return _STACK[-1].interpret(type(x), *map(recursion_reinterpret, children(x)))
  File "/home/ordabayev/repos/funsor/funsor/interpretations.py", line 196,
in interpret
    result = s.interpret(cls, *args)
  File "/home/ordabayev/repos/funsor/funsor/interpretations.py", line 155,
in interpret
    return self.dispatch(cls, *args)(*args)
  File "/home/ordabayev/repos/funsor/funsor/optimizer.py", line 87, in optimize_contraction_variadic
    return optimize.interpret(Contraction, r, b, v, tuple(ts))
  File "/home/ordabayev/repos/funsor/funsor/interpretations.py", line 196,
in interpret
    result = s.interpret(cls, *args)
  File "/home/ordabayev/repos/funsor/funsor/interpretations.py", line 155,
in interpret
    return self.dispatch(cls, *args)(*args)
  File "/home/ordabayev/repos/funsor/funsor/optimizer.py", line 145, in optimize_contract_finitary_funsor
    path_end = Contraction(
  File "/home/ordabayev/repos/funsor/funsor/terms.py", line 210, in __call__
    return interpret(cls, *args)
  File "/home/ordabayev/repos/funsor/funsor/interpretations.py", line 196,
in interpret
    result = s.interpret(cls, *args)
  File "/home/ordabayev/repos/funsor/funsor/interpretations.py", line 155,
in interpret
    return self.dispatch(cls, *args)(*args)
  File "/home/ordabayev/repos/funsor/funsor/cnf.py", line 323, in eager_contraction_tensor
    raise NotImplementedError("TODO")
NotImplementedError: TODO
ordabayevy self-assigned this Sep 30, 2021
fritzo added the enhancement (New feature or request) label Oct 1, 2021
@ordabayevy (Member, Author) commented:

After optimize_contract_finitary_funsor is applied, the funsor term above boils down to the following contraction:

# Imports as in the snippet above:
import torch
import funsor
import funsor.ops as ops
from funsor import Bint
from funsor.cnf import Contraction
from funsor.tensor import Tensor
from funsor.terms import Variable

funsor.set_backend("torch")

a = Tensor(
    torch.tensor([True, True, True], dtype=torch.bool),
    (
        (
            "plate_outer__BOUND_79",
            Bint[3],
        ),
    ),
    2,
)
b = Tensor(
    torch.tensor(
        [-6.521620571590056, -2.191898006015123, -2.8022213468216526],
        dtype=torch.float64,
    ),  # noqa
    (
        (
            "plate_outer__BOUND_79",
            Bint[3],
        ),
    ),
    "real",
)
result = Contraction(
    ops.add, ops.mul, frozenset({Variable("plate_outer__BOUND_79", Bint[3])}), (a, b)
)

which is not implemented for a boolean dtype (a has Funsor dtype 2, i.e. Bint[2], over a bool tensor). Manually converting a to a float type fixes the problem, but I don't know what the proper way to fix this in general would be @eb8680 @fritzo.
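For concreteness, a minimal sketch of that manual workaround, reusing a and b from the snippet above (assuming Tensor accepts the existing inputs mapping unchanged):

# Workaround sketch: cast the boolean Tensor to float and mark its dtype "real",
# so the contraction no longer sees a Bint[2] input.
a_float = Tensor(a.data.to(torch.float64), a.inputs, "real")
result = Contraction(
    ops.add,
    ops.mul,
    frozenset({Variable("plate_outer__BOUND_79", Bint[3])}),
    (a_float, b),
)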

@fritzo (Member) commented Oct 1, 2021:

That's weird and unexpected: we shouldn't be adding or multiplying bool tensors. Can you determine where the bool tensor comes from? Maybe there's some bad low-level math in one of the eager patterns?

@ordabayevy (Member, Author) commented:

@fritzo I forgot to mention that I'm getting this result on a funsor branch where some lines are commented out (https://github.com/pyro-ppl/funsor/compare/normalize-logaddexp). I opened a related issue (#561); maybe fixing #561 first will also fix this one.

@ordabayevy (Member, Author) commented:

The bool tensor results from (value == point).all().log() in Delta.eager_subs, which is later exponentiated: (value == point).all().log().exp(). Evaluating this term eagerly yields a "real" dtype. However, applying normalize first cancels out .log().exp() and leaves just (value == point).all(), which upon eager evaluation has a bool dtype (Bint[2]):

import torch
import funsor
from funsor.tensor import Tensor
from funsor.terms import eager, lazy

funsor.set_backend("torch")

with lazy:
    value = Tensor(torch.tensor([-1.5, 0.38, -1.02]))["plate"]
    term = (value == value).all().log().exp()

with funsor.interpretations.normalize:
    norm_term = funsor.interpreter.reinterpret(term)

with eager:
    # a.dtype == "real"
    a = funsor.interpreter.reinterpret(term)
    # b.dtype == 2, i.e. Bint[2]
    b = funsor.interpreter.reinterpret(norm_term)

@fritzo (Member) commented Oct 3, 2021:

Good diagnosis @ordabayevy. I wonder if it's enough to add an ops.astype(-, value.dtype)?

- term = (value == value).all().log().exp()
+ term = ops.astype((value == value).all(), value.dtype).log().exp()
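Applied to the diagnosis snippet above, the suggestion would read roughly as follows (a sketch only; whether ops.astype can take value.dtype in this position is exactly what the following comments sort out):

import funsor.ops as ops

with lazy:
    value = Tensor(torch.tensor([-1.5, 0.38, -1.02]))["plate"]
    # Cast the boolean mask back to value's dtype before log/exp, so that
    # normalize's log/exp cancellation can no longer change the term's dtype:
    term = ops.astype((value == value).all(), value.dtype).log().exp()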

@ordabayevy (Member, Author) commented:

@fritzo currently ops.astype doesn't change the Funsor.dtype, only the dtype of the underlying Tensor.data. Does it seem reasonable to modify ops.astype to also update Funsor.dtype? E.g., bool -> Bint[2]; float, double -> Real. Integer types might be a bit more complicated -- maybe just preserve the dtype: int -> Bint[n] if Funsor.dtype == Bint[n], and int -> Real if Funsor.dtype == Real.
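The proposed mapping, written out as a standalone sketch (a hypothetical helper for illustration, not funsor's actual API; the real fix would live in funsor's domain logic):

# Hypothetical sketch of the dtype mapping proposed above (not funsor's API).
# input_dtype is the current Funsor dtype: "real" or an int n meaning Bint[n].
def proposed_astype_dtype(input_dtype, target):
    if target == "bool":
        return 2  # bool -> Bint[2]
    if target in ("float32", "float64", "float", "double"):
        return "real"  # float, double -> Real
    if target.startswith(("int", "uint")):
        # Preserve Bint[n] if the input was Bint[n]; int -> Real if it was Real.
        return input_dtype if isinstance(input_dtype, int) else "real"
    raise NotImplementedError(f"unsupported target dtype: {target}")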

@fritzo (Member) commented Oct 4, 2021:

In principle it seems fine to add a find_domain() rule for ops.astype, but we should think carefully about what kinds of things are allowed as dtype. Basically I'd like to keep the .log().exp() optimization but have it produce ops.astype(-, "real"). But this is complex; I leave it to you 😄

@ordabayevy (Member, Author) commented:

This was addressed in #562.
