|
13 | 13 | #include <mitsuba/core/util.h>
|
14 | 14 | #include <drjit/dynamic.h>
|
15 | 15 | #include <array>
|
| 16 | +#include <drjit/traversable_base.h> |
16 | 17 |
|
17 | 18 | NAMESPACE_BEGIN(mitsuba)
|
18 | 19 |
|
@@ -72,7 +73,7 @@ NAMESPACE_BEGIN(mitsuba)
|
72 | 73 | */
|
73 | 74 |
|
74 | 75 | template <typename Float_, size_t Dimension_ = 0>
|
75 |
| -class DiscreteDistribution2D { |
| 76 | +class DiscreteDistribution2D : drjit::TraversableBase{ |
76 | 77 | public:
|
77 | 78 | using Float = Float_;
|
78 | 79 | using UInt32 = dr::uint32_array_t<Float>;
|
@@ -201,10 +202,14 @@ class DiscreteDistribution2D {
|
201 | 202 |
|
202 | 203 | Float m_inv_normalization;
|
203 | 204 | Float m_normalization;
|
| 205 | + |
| 206 | + DR_TRAVERSE_CB(drjit::TraversableBase, m_data, m_marg_cdf, m_cond_cdf, |
| 207 | + m_inv_normalization, m_normalization) |
204 | 208 | };
|
205 | 209 |
|
206 | 210 | /// Base class of Hierarchical2D and Marginal2D with common functionality
|
207 |
| -template <typename Float_, size_t Dimension_ = 0> class Distribution2D { |
| 211 | +template <typename Float_, size_t Dimension_ = 0> |
| 212 | +class Distribution2D : drjit::TraversableBase { |
208 | 213 | public:
|
209 | 214 | static constexpr size_t Dimension = Dimension_;
|
210 | 215 | using Float = Float_;
|
@@ -308,6 +313,28 @@ template <typename Float_, size_t Dimension_ = 0> class Distribution2D {
|
308 | 313 |
|
309 | 314 | /// Total number of slices (in case Dimension > 1)
|
310 | 315 | uint32_t m_slices;
|
| 316 | + |
| 317 | +public: |
| 318 | + void |
| 319 | + traverse_1_cb_ro(void *payload, |
| 320 | + drjit::detail::traverse_callback_ro fn) const override { |
| 321 | + if constexpr (!std ::is_same_v<drjit ::TraversableBase, |
| 322 | + drjit ::TraversableBase>) |
| 323 | + drjit ::TraversableBase ::traverse_1_cb_ro(payload, fn); |
| 324 | + for (const auto ¶m_value : m_param_values) { |
| 325 | + drjit ::traverse_1_fn_ro(param_value, payload, fn); |
| 326 | + } |
| 327 | + } |
| 328 | + void traverse_1_cb_rw(void *payload, |
| 329 | + drjit::detail::traverse_callback_rw fn) override { |
| 330 | + if constexpr (!std ::is_same_v<drjit ::TraversableBase, |
| 331 | + drjit ::TraversableBase>) |
| 332 | + drjit ::TraversableBase ::traverse_1_cb_rw(payload, fn); |
| 333 | + |
| 334 | + for (auto ¶m_value : m_param_values) { |
| 335 | + drjit ::traverse_1_fn_rw(param_value, payload, fn); |
| 336 | + } |
| 337 | + } |
311 | 338 | };
|
312 | 339 |
|
313 | 340 | /**
|
@@ -788,13 +815,17 @@ class Hierarchical2D : public Distribution2D<Float_, Dimension_> {
|
788 | 815 | return dr::gather<Float>(data, i0, active);
|
789 | 816 | }
|
790 | 817 | }
|
| 818 | + |
| 819 | + DRJIT_STRUCT_NODEF(Level, data) |
791 | 820 | };
|
792 | 821 |
|
793 | 822 | /// MIP hierarchy over linearly interpolated patches
|
794 | 823 | std::vector<Level> m_levels;
|
795 | 824 |
|
796 | 825 | /// Number of bilinear patches in the X/Y dimension - 1
|
797 | 826 | ScalarVector2u m_max_patch_index;
|
| 827 | + |
| 828 | + DR_TRAVERSE_CB(Base, m_levels) |
798 | 829 | };
|
799 | 830 |
|
800 | 831 | /**
|
@@ -1454,6 +1485,8 @@ class Marginal2D : public Distribution2D<Float_, Dimension_> {
|
1454 | 1485 |
|
1455 | 1486 | /// Are the probability values normalized?
|
1456 | 1487 | bool m_normalized;
|
| 1488 | + |
| 1489 | + DR_TRAVERSE_CB(Base, m_data, m_marg_cdf, m_cond_cdf) |
1457 | 1490 | };
|
1458 | 1491 |
|
1459 | 1492 | //! @}
|
|
0 commit comments