From b6fb8eb9a6e3e75f4d439e178e7aa01cbb233ffe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20D=C3=B6ring?= Date: Fri, 7 Mar 2025 12:35:20 +0100 Subject: [PATCH 1/3] Added failing test for backpropagating to source of scatter to loop state --- tests/test_autodiff.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/test_autodiff.py b/tests/test_autodiff.py index a9439c7e..4ff8ad82 100644 --- a/tests/test_autodiff.py +++ b/tests/test_autodiff.py @@ -2122,3 +2122,31 @@ def test137_forward_from_existing_gradient(t): a.grad = 1000 dr.forward_from(a) assert dr.allclose(b.grad, 2000) + +@pytest.test_arrays("is_diff,float,shape=(*)") +def test138_loop_state(t): + + Float = dr.float32_array_t(t) + + @dr.syntax(print_code=True) + def loop(t, x: Float, y: Float, n: int = 10) -> Float: + UInt32 = dr.uint32_array_t(t) + + i = UInt32(0) + while dr.hint(i < n, max_iterations=-1): + y += 1 + dr.scatter(x, y, i % 3) + i += 1 + + return y + + x = dr.arange(Float, 3) + dr.make_opaque(x) + + y = dr.arange(Float, 10) + dr.make_opaque(y) + dr.enable_grad(y) + + y = loop(t, x, y) + + dr.backward(y) From f682c251894e3b693e1fbec22ad64cf02a9c291c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20D=C3=B6ring?= Date: Fri, 7 Mar 2025 13:03:20 +0100 Subject: [PATCH 2/3] Removed print_syntax --- tests/test_autodiff.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_autodiff.py b/tests/test_autodiff.py index 4ff8ad82..f0e803e5 100644 --- a/tests/test_autodiff.py +++ b/tests/test_autodiff.py @@ -2128,7 +2128,7 @@ def test138_loop_state(t): Float = dr.float32_array_t(t) - @dr.syntax(print_code=True) + @dr.syntax def loop(t, x: Float, y: Float, n: int = 10) -> Float: UInt32 = dr.uint32_array_t(t) From 65cb43718ae155cbd1e0ed050c9478acf568344d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20D=C3=B6ring?= Date: Fri, 7 Mar 2025 14:17:27 +0100 Subject: [PATCH 3/3] Renamed test --- tests/test_autodiff.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_autodiff.py b/tests/test_autodiff.py index f0e803e5..5af7b3f9 100644 --- a/tests/test_autodiff.py +++ b/tests/test_autodiff.py @@ -2124,7 +2124,7 @@ def test137_forward_from_existing_gradient(t): assert dr.allclose(b.grad, 2000) @pytest.test_arrays("is_diff,float,shape=(*)") -def test138_loop_state(t): +def test138_loop_state_backprop(t): Float = dr.float32_array_t(t)