Skip to content

Commit 25c82f2

Browse files
Added traverse_callback functions for frozen functions support
1 parent 733f8af commit 25c82f2

File tree

119 files changed

+1809
-209
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

119 files changed

+1809
-209
lines changed

include/mitsuba/core/bitmap.h

+2
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,8 @@ class MI_EXPORT_LIB Bitmap : public Object {
652652
bool m_premultiplied_alpha;
653653
bool m_owns_data;
654654
Properties m_metadata;
655+
656+
DR_TRAVERSE_CB(Object, m_size);
655657
};
656658

657659

include/mitsuba/core/bsphere.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
NAMESPACE_BEGIN(mitsuba)
66

77
/// Generic n-dimensional bounding sphere data structure
8-
template <typename Point_> struct BoundingSphere {
8+
template <typename Point_> struct BoundingSphere: drjit::TraversableBase {
99
static constexpr size_t Size = Point_::Size;
1010
using Point = Point_;
1111
using Float = dr::value_t<Point>;
@@ -74,6 +74,8 @@ template <typename Point_> struct BoundingSphere {
7474
dr::squared_norm(o) - dr::square(radius)
7575
);
7676
}
77+
78+
DR_TRAVERSE_CB(drjit::TraversableBase, center, radius);
7779
};
7880

7981
/// Print a string representation of the bounding sphere

include/mitsuba/core/class.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ NAMESPACE_END(detail)
247247

248248
#define MI_REGISTRY_PUT(name, ptr) \
249249
if constexpr (dr::is_jit_v<Float>) { \
250-
jit_registry_put(::mitsuba::detail::get_variant<Float, Spectrum>(), \
250+
drjit::registry_put(::mitsuba::detail::get_variant<Float, Spectrum>(), \
251251
"mitsuba::" name, ptr); \
252252
}
253253

include/mitsuba/core/distr_1d.h

+14-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <mitsuba/core/vector.h>
55
#include <mitsuba/core/math.h>
66
#include <drjit/dynamic.h>
7+
#include <drjit/traversable_base.h>
78

89
NAMESPACE_BEGIN(mitsuba)
910

@@ -17,7 +18,7 @@ NAMESPACE_BEGIN(mitsuba)
1718
* initialization. The associated scale factor can be retrieved using the
1819
* function \ref normalization().
1920
*/
20-
template <typename Value> struct DiscreteDistribution {
21+
template <typename Value> struct DiscreteDistribution: drjit::TraversableBase {
2122
using Float = std::conditional_t<dr::is_static_array_v<Value>,
2223
dr::value_t<Value>, Value>;
2324
using FloatStorage = DynamicBuffer<Float>;
@@ -269,6 +270,9 @@ template <typename Value> struct DiscreteDistribution {
269270
Float m_sum = 0.f;
270271
Float m_normalization = 0.f;
271272
Vector2u m_valid;
273+
274+
DR_TRAVERSE_CB(drjit::TraversableBase, m_pmf, m_cdf, m_sum, m_normalization,
275+
m_valid);
272276
};
273277

274278
/**
@@ -283,7 +287,7 @@ template <typename Value> struct DiscreteDistribution {
283287
* initialization. The associated scale factor can be retrieved using the
284288
* function \ref normalization().
285289
*/
286-
template <typename Value> struct ContinuousDistribution {
290+
template <typename Value> struct ContinuousDistribution: drjit::TraversableBase {
287291
using Float = std::conditional_t<dr::is_static_array_v<Value>,
288292
dr::value_t<Value>, Value>;
289293
using FloatStorage = DynamicBuffer<Float>;
@@ -601,6 +605,10 @@ template <typename Value> struct ContinuousDistribution {
601605
ScalarVector2f m_range { 0.f, 0.f };
602606
Vector2u m_valid;
603607
ScalarFloat m_max = 0.f;
608+
609+
DR_TRAVERSE_CB(drjit::TraversableBase, m_pdf, m_cdf, m_integral,
610+
m_normalization, m_interval_size, m_inv_interval_size,
611+
m_valid);
604612
};
605613

606614
/**
@@ -615,7 +623,7 @@ template <typename Value> struct ContinuousDistribution {
615623
* initialization. The associated scale factor can be retrieved using the
616624
* function \ref normalization().
617625
*/
618-
template <typename Value> struct IrregularContinuousDistribution {
626+
template <typename Value> struct IrregularContinuousDistribution : public drjit::TraversableBase{
619627
using Float = std::conditional_t<dr::is_static_array_v<Value>,
620628
dr::value_t<Value>, Value>;
621629
using FloatStorage = DynamicBuffer<Float>;
@@ -973,6 +981,9 @@ template <typename Value> struct IrregularContinuousDistribution {
973981
Vector2u m_valid;
974982
ScalarFloat m_interval_size = 0.f;
975983
ScalarFloat m_max = 0.f;
984+
985+
DR_TRAVERSE_CB(drjit::TraversableBase, m_nodes, m_pdf, m_cdf, m_integral,
986+
m_normalization, m_valid);
976987
};
977988

978989
template <typename Value>

include/mitsuba/core/distr_2d.h

+35-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <mitsuba/core/util.h>
1414
#include <drjit/dynamic.h>
1515
#include <array>
16+
#include <drjit/traversable_base.h>
1617

1718
NAMESPACE_BEGIN(mitsuba)
1819

@@ -72,7 +73,7 @@ NAMESPACE_BEGIN(mitsuba)
7273
*/
7374

7475
template <typename Float_, size_t Dimension_ = 0>
75-
class DiscreteDistribution2D {
76+
class DiscreteDistribution2D : drjit::TraversableBase{
7677
public:
7778
using Float = Float_;
7879
using UInt32 = dr::uint32_array_t<Float>;
@@ -201,10 +202,14 @@ class DiscreteDistribution2D {
201202

202203
Float m_inv_normalization;
203204
Float m_normalization;
205+
206+
DR_TRAVERSE_CB(drjit::TraversableBase, m_data, m_marg_cdf, m_cond_cdf,
207+
m_inv_normalization, m_normalization)
204208
};
205209

206210
/// Base class of Hierarchical2D and Marginal2D with common functionality
207-
template <typename Float_, size_t Dimension_ = 0> class Distribution2D {
211+
template <typename Float_, size_t Dimension_ = 0>
212+
class Distribution2D : drjit::TraversableBase {
208213
public:
209214
static constexpr size_t Dimension = Dimension_;
210215
using Float = Float_;
@@ -308,6 +313,28 @@ template <typename Float_, size_t Dimension_ = 0> class Distribution2D {
308313

309314
/// Total number of slices (in case Dimension > 1)
310315
uint32_t m_slices;
316+
317+
public:
318+
void
319+
traverse_1_cb_ro(void *payload,
320+
drjit::detail::traverse_callback_ro fn) const override {
321+
if constexpr (!std ::is_same_v<drjit ::TraversableBase,
322+
drjit ::TraversableBase>)
323+
drjit ::TraversableBase ::traverse_1_cb_ro(payload, fn);
324+
for (const auto &param_value : m_param_values) {
325+
drjit ::traverse_1_fn_ro(param_value, payload, fn);
326+
}
327+
}
328+
void traverse_1_cb_rw(void *payload,
329+
drjit::detail::traverse_callback_rw fn) override {
330+
if constexpr (!std ::is_same_v<drjit ::TraversableBase,
331+
drjit ::TraversableBase>)
332+
drjit ::TraversableBase ::traverse_1_cb_rw(payload, fn);
333+
334+
for (auto &param_value : m_param_values) {
335+
drjit ::traverse_1_fn_rw(param_value, payload, fn);
336+
}
337+
}
311338
};
312339

313340
/**
@@ -788,13 +815,17 @@ class Hierarchical2D : public Distribution2D<Float_, Dimension_> {
788815
return dr::gather<Float>(data, i0, active);
789816
}
790817
}
818+
819+
DRJIT_STRUCT_NODEF(Level, data)
791820
};
792821

793822
/// MIP hierarchy over linearly interpolated patches
794823
std::vector<Level> m_levels;
795824

796825
/// Number of bilinear patches in the X/Y dimension - 1
797826
ScalarVector2u m_max_patch_index;
827+
828+
DR_TRAVERSE_CB(Base, m_levels)
798829
};
799830

800831
/**
@@ -1454,6 +1485,8 @@ class Marginal2D : public Distribution2D<Float_, Dimension_> {
14541485

14551486
/// Are the probability values normalized?
14561487
bool m_normalized;
1488+
1489+
DR_TRAVERSE_CB(Base, m_data, m_marg_cdf, m_cond_cdf)
14571490
};
14581491

14591492
//! @}

include/mitsuba/core/field.h

+18
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <type_traits>
66

77
#include <drjit/array.h>
8+
#include <drjit/array_traverse.h>
89
namespace dr = drjit;
910

1011
NAMESPACE_BEGIN(mitsuba)
@@ -62,6 +63,12 @@ struct field<DeviceType, HostType,
6263
}
6364
private:
6465
DeviceType m_scalar;
66+
67+
public:
68+
void traverse_1_cb_ro(void * /*payload*/,
69+
drjit::detail::traverse_callback_ro) const {}
70+
void traverse_1_cb_rw(void * /*payload*/,
71+
drjit::detail::traverse_callback_rw) {}
6572
};
6673

6774
template <typename DeviceType, typename HostType>
@@ -105,6 +112,17 @@ struct field<DeviceType, HostType,
105112
private:
106113
DeviceType m_value;
107114
HostType m_scalar;
115+
116+
public:
117+
void traverse_1_cb_ro(void *payload,
118+
drjit::detail::traverse_callback_ro fn) const {
119+
120+
drjit ::traverse_1_fn_ro(m_value, payload, fn);
121+
}
122+
void traverse_1_cb_rw(void *payload,
123+
drjit::detail::traverse_callback_rw fn) {
124+
drjit ::traverse_1_fn_rw(m_value, payload, fn);
125+
}
108126
};
109127

110128
/// Prints the canonical string representation of a field

include/mitsuba/core/fwd.h

+24
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,30 @@ extern "C" {
368368
})
369369
#endif
370370

371+
#define MI_DECLARE_TRAVERSE_CB() \
372+
public: \
373+
void traverse_1_cb_ro(void *payload, \
374+
drjit::detail::traverse_callback_ro fn) \
375+
const override; \
376+
void traverse_1_cb_rw( \
377+
void *payload, drjit::detail::traverse_callback_rw fn) override;
378+
379+
#define MI_IMPLEMENT_TRAVERSE_CB(Type, Base, ...) \
380+
MI_VARIANT \
381+
void Type<Float, Spectrum>::traverse_1_cb_ro( \
382+
void *payload, drjit::detail::traverse_callback_ro fn) const { \
383+
if constexpr (!std ::is_same_v<Base, drjit ::TraversableBase>) \
384+
Base ::traverse_1_cb_ro(payload, fn); \
385+
DRJIT_MAP(DR_TRAVERSE_MEMBER_RO, __VA_ARGS__) \
386+
} \
387+
MI_VARIANT \
388+
void Type<Float, Spectrum>::traverse_1_cb_rw( \
389+
void *payload, drjit::detail::traverse_callback_rw fn) { \
390+
if constexpr (!std ::is_same_v<Base, drjit ::TraversableBase>) \
391+
Base ::traverse_1_cb_rw(payload, fn); \
392+
DRJIT_MAP(DR_TRAVERSE_MEMBER_RW, __VA_ARGS__) \
393+
}
394+
371395
//! @}
372396
// =============================================================
373397

include/mitsuba/core/object.h

+4-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <atomic>
44
#include <stdexcept>
55
#include <mitsuba/core/class.h>
6+
#include <drjit/traversable_base.h>
67

78
NAMESPACE_BEGIN(mitsuba)
89

@@ -29,8 +30,10 @@ NAMESPACE_BEGIN(mitsuba)
2930
* Python, this counter is shared with Python such that the ownerhsip and
3031
* lifetime of any ``Object`` instance across C++ and Python is managed by it.
3132
*/
32-
class MI_EXPORT_LIB Object : public nanobind::intrusive_base {
33+
class MI_EXPORT_LIB Object : public drjit::TraversableBase {
3334
public:
35+
DR_TRAVERSE_CB(drjit::TraversableBase)
36+
3437
/// Default constructor
3538
Object() { }
3639

include/mitsuba/render/bsdf.h

+2
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,8 @@ class MI_EXPORT_LIB BSDF : public Object {
609609

610610
/// Identifier (if available)
611611
std::string m_id;
612+
613+
DR_TRAVERSE_CB(Object);
612614
};
613615

614616
// -----------------------------------------------------------------------

include/mitsuba/render/emitter.h

+2
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ class MI_EXPORT_LIB Emitter : public Endpoint<Float, Spectrum> {
9494

9595
/// True if the emitters's parameters have changed
9696
bool m_dirty = false;
97+
98+
DR_TRAVERSE_CB(Base);
9799
};
98100

99101
MI_EXTERN_CLASS(Emitter)

include/mitsuba/render/endpoint.h

+2
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,8 @@ class MI_EXPORT_LIB Endpoint : public Object {
397397
bool m_needs_sample_2 = true;
398398
bool m_needs_sample_3 = true;
399399
std::string m_id;
400+
401+
MI_DECLARE_TRAVERSE_CB()
400402
};
401403

402404
MI_EXTERN_CLASS(Endpoint)

include/mitsuba/render/film.h

+2
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,8 @@ class MI_EXPORT_LIB Film : public Object {
224224
bool m_sample_border;
225225
ref<ReconstructionFilter> m_filter;
226226
ref<Texture> m_srf;
227+
228+
MI_DECLARE_TRAVERSE_CB()
227229
};
228230

229231
MI_EXTERN_CLASS(Film)

include/mitsuba/render/imageblock.h

+2
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,8 @@ class MI_EXPORT_LIB ImageBlock : public Object {
346346
bool m_compensate;
347347
bool m_warn_negative;
348348
bool m_warn_invalid;
349+
350+
DR_TRAVERSE_CB(Object, m_tensor, m_tensor_compensation)
349351
};
350352

351353
MI_EXTERN_CLASS(ImageBlock)

0 commit comments

Comments
 (0)