From 1751cf7708a4bd28d5e1b1b047a6b65849bfa693 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20P=C3=A9rez=20=C3=81lvarez?= <72174660+Tomperez98@users.noreply.github.com> Date: Thu, 30 Jan 2025 15:43:20 -0500 Subject: [PATCH] Fix rfi from string in same node (#139) * Fix rfi from string in same node * Remove unused script --- src/resonate/record.py | 41 ++++++++++++++++------------- src/resonate/scheduler/scheduler.py | 2 +- tests/test_functionality.py | 4 +-- 3 files changed, 26 insertions(+), 21 deletions(-) diff --git a/src/resonate/record.py b/src/resonate/record.py index c7a16a99..20e41d05 100644 --- a/src/resonate/record.py +++ b/src/resonate/record.py @@ -21,6 +21,24 @@ T = TypeVar("T") +def _set_retry_policy(invocation: LFI | RFI) -> retry_policy.RetryPolicy | None: + if not isinstance(invocation.unit, Invocation): + retry_policy = None + elif isinstance(invocation.unit.fn, str): + retry_policy = invocation.opts.retry_policy or None + elif isgeneratorfunction(invocation.unit.fn): + retry_policy = invocation.opts.retry_policy or never() + else: + assert iscoroutinefunction(invocation.unit.fn) or isfunction(invocation.unit.fn) + retry_policy = invocation.opts.retry_policy or exponential( + base_delay=1, + factor=2, + max_retries=-1, + max_delay=30, + ) + return retry_policy + + @final class Record(Generic[T]): def __init__( @@ -38,24 +56,7 @@ def __init__( self._result: Result[T, Exception] | None = None self.children: list[Record[Any]] = [] self.invocation: LFI | RFI = invocation - self.retry_policy: retry_policy.RetryPolicy | None - if not isinstance(invocation.unit, Invocation): - self.retry_policy = None - elif isinstance(invocation.unit.fn, str): - self.retry_policy = invocation.opts.retry_policy or None - elif isgeneratorfunction(invocation.unit.fn): - self.retry_policy = invocation.opts.retry_policy or never() - else: - assert iscoroutinefunction(invocation.unit.fn) or isfunction( - invocation.unit.fn - ) - self.retry_policy = invocation.opts.retry_policy or exponential( - base_delay=1, - factor=2, - max_retries=-1, - max_delay=30, - ) - + self.retry_policy = _set_retry_policy(invocation=invocation) self._attempt: int = 1 self.durable_promise: DurablePromiseRecord | None = None self._task: TaskRecord | None = None @@ -70,6 +71,10 @@ def __init__( self.parent.id if self.parent else None, ) + def overwrite_invocation(self, invocation: RFI) -> None: + self.invocation = invocation + self.retry_policy = _set_retry_policy(invocation) + def get_coro(self) -> ResonateCoro[T]: assert self._coro return self._coro diff --git a/src/resonate/scheduler/scheduler.py b/src/resonate/scheduler/scheduler.py index c6c22404..0ff7bd3b 100644 --- a/src/resonate/scheduler/scheduler.py +++ b/src/resonate/scheduler/scheduler.py @@ -422,7 +422,7 @@ def _process_invoke_msg( assert record.is_root assert isinstance(record.invocation, RFI) assert record.durable_promise is not None - record.invocation = rfi + record.overwrite_invocation(rfi) else: record = Record[Any]( id=invoke_msg.root_durable_promise.id, diff --git a/tests/test_functionality.py b/tests/test_functionality.py index 3b17707a..3d17c447 100644 --- a/tests/test_functionality.py +++ b/tests/test_functionality.py @@ -246,7 +246,7 @@ def test_golden_device_rfi() -> None: group = "test-golden-device-rfi" def foo_golden_device_rfi(ctx: Context, n: str) -> Generator[Yieldable, Any, str]: - p: Promise[str] = yield ctx.rfi(bar_golden_device_rfi, n).options( + p: Promise[str] = yield ctx.rfi("bar_golden_device_rfi", n).options( id="bar", send_to=poll(group) ) assert isinstance(p, Promise) @@ -281,7 +281,7 @@ def exec_id(n: int) -> str: def factorial_rfi(ctx: Context, n: int) -> Generator[Yieldable, Any, int]: if n == 0: return 1 - p = yield ctx.rfi(factorial_rfi, n - 1).options( + p = yield ctx.rfi("factorial_rfi", n - 1).options( id=exec_id(n - 1), send_to=poll(group) ) return n * (yield p)