From 21358077ccdcce9dfee1a6aeb25f967d2b5705cf Mon Sep 17 00:00:00 2001 From: Hadi Ravanbakhsh Date: Mon, 13 Jan 2025 10:07:13 -0800 Subject: [PATCH] Keep track of field count in the corpus_type for protobuf domains. PiperOrigin-RevId: 715013770 --- fuzztest/internal/domains/container_of_impl.h | 2 +- fuzztest/internal/domains/domain.h | 5 +++-- .../internal/domains/domain_type_erasure.h | 5 ++--- fuzztest/internal/domains/optional_of_impl.h | 2 +- .../internal/domains/protobuf_domain_impl.h | 22 +++++++++++++++++-- 5 files changed, 27 insertions(+), 9 deletions(-) diff --git a/fuzztest/internal/domains/container_of_impl.h b/fuzztest/internal/domains/container_of_impl.h index 3ef62f8b..02708d90 100644 --- a/fuzztest/internal/domains/container_of_impl.h +++ b/fuzztest/internal/domains/container_of_impl.h @@ -547,7 +547,7 @@ class SequenceContainerOfImplBase return val; } - uint64_t CountNumberOfFields(const corpus_type& val) { + uint64_t CountNumberOfFields(corpus_type& val) { uint64_t total_weight = 0; for (auto& i : val) { total_weight += this->inner_.CountNumberOfFields(i); diff --git a/fuzztest/internal/domains/domain.h b/fuzztest/internal/domains/domain.h index 6faa9fc8..43492471 100644 --- a/fuzztest/internal/domains/domain.h +++ b/fuzztest/internal/domains/domain.h @@ -251,11 +251,12 @@ class Domain { // Return the field counts of `corpus_value` if `corpus_value` is // a `ProtobufDomainImpl::corpus_type`. Otherwise propagate it - // to inner domains and returns the sum of inner results. + // to inner domains and returns the sum of inner results. The corpus value is + // taken as mutable reference to allow memoization. // // TODO(b/303324603): Using an extension mechanism, expose this method in // the interface only for user value types `T` for which it makes sense. - uint64_t CountNumberOfFields(const corpus_type& corpus_value) { + uint64_t CountNumberOfFields(corpus_type& corpus_value) { return inner_->UntypedCountNumberOfFields(corpus_value); } diff --git a/fuzztest/internal/domains/domain_type_erasure.h b/fuzztest/internal/domains/domain_type_erasure.h index c2dfb83c..2c38c3f9 100644 --- a/fuzztest/internal/domains/domain_type_erasure.h +++ b/fuzztest/internal/domains/domain_type_erasure.h @@ -70,8 +70,7 @@ class UntypedDomainConcept { const GenericDomainCorpusType& corpus_value) const = 0; virtual IRObject UntypedSerializeCorpus( const GenericDomainCorpusType& v) const = 0; - virtual uint64_t UntypedCountNumberOfFields( - const GenericDomainCorpusType&) = 0; + virtual uint64_t UntypedCountNumberOfFields(GenericDomainCorpusType&) = 0; virtual uint64_t UntypedMutateSelectedField( GenericDomainCorpusType&, absl::BitGenRef, const domain_implementor::MutationMetadata&, bool, uint64_t) = 0; @@ -188,7 +187,7 @@ class DomainModel final : public TypedDomainConcept> { return domain_.ValidateCorpusValue(corpus_value.GetAs()); } - uint64_t UntypedCountNumberOfFields(const GenericDomainCorpusType& v) final { + uint64_t UntypedCountNumberOfFields(GenericDomainCorpusType& v) final { return domain_.CountNumberOfFields(v.GetAs()); } diff --git a/fuzztest/internal/domains/optional_of_impl.h b/fuzztest/internal/domains/optional_of_impl.h index f5dde1eb..e80b93d9 100644 --- a/fuzztest/internal/domains/optional_of_impl.h +++ b/fuzztest/internal/domains/optional_of_impl.h @@ -150,7 +150,7 @@ class OptionalOfImpl return *this; } - uint64_t CountNumberOfFields(const corpus_type& val) { + uint64_t CountNumberOfFields(corpus_type& val) { if (val.index() == 1) { return inner_.CountNumberOfFields(std::get<1>(val)); } diff --git a/fuzztest/internal/domains/protobuf_domain_impl.h b/fuzztest/internal/domains/protobuf_domain_impl.h index e4ad00f0..acd1ef12 100644 --- a/fuzztest/internal/domains/protobuf_domain_impl.h +++ b/fuzztest/internal/domains/protobuf_domain_impl.h @@ -504,6 +504,7 @@ class ProtobufDomainUntypedImpl value_type out(prototype_.Get()->New()); for (auto& [number, data] : value) { + if (IsMetadataEntry(number)) continue; auto* field = GetField(number); VisitProtobufField(field, GetValueVisitor{*out, *this, data}); } @@ -573,6 +574,7 @@ class ProtobufDomainUntypedImpl auto& subs = out.MutableSubs(); subs.reserve(v.size()); for (auto& [number, inner] : v) { + if (IsMetadataEntry(number)) continue; auto* field = GetField(number); FUZZTEST_INTERNAL_CHECK(field, "Field not found by number: ", number); IRObject& pair = subs.emplace_back(); @@ -585,7 +587,10 @@ class ProtobufDomainUntypedImpl return out; } - uint64_t CountNumberOfFields(const corpus_type& val) { + uint64_t CountNumberOfFields(corpus_type& val) { + if (auto it = val.find(kFieldCountIndex); it != val.end()) { + return it->second.template GetAs(); + } uint64_t total_weight = 0; auto descriptor = prototype_.Get()->GetDescriptor(); if (GetFieldCount(descriptor) == 0) return total_weight; @@ -611,6 +616,8 @@ class ProtobufDomainUntypedImpl } } } + val[kFieldCountIndex] = + GenericDomainCorpusType(std::in_place_type, total_weight); return total_weight; } @@ -621,6 +628,9 @@ class ProtobufDomainUntypedImpl uint64_t field_counter = 0; auto descriptor = prototype_.Get()->GetDescriptor(); if (GetFieldCount(descriptor) == 0) return field_counter; + int64_t fields_count = CountNumberOfFields(val); + if (fields_count < selected_field_index) return fields_count; + val.erase(kFieldCountIndex); // Mutation invalidates the cache value. for (const FieldDescriptor* field : GetProtobufFields(descriptor)) { if (field->containing_oneof() && @@ -1695,6 +1705,14 @@ class ProtobufDomainUntypedImpl return result; } + // corpus_type is a map from field number to values. number -1 is reserved for + // storing the field count. + static constexpr int64_t kFieldCountIndex = -1; + + static bool IsMetadataEntry(int64_t index) { + return index == kFieldCountIndex; + } + bool IsOneofRecursive(const OneofDescriptor* oneof, absl::flat_hash_set& parents, bool consider_non_terminating_recursions) const { @@ -1828,7 +1846,7 @@ class ProtobufDomainImpl return inner_.Init(prng); } - uint64_t CountNumberOfFields(const corpus_type& val) { + uint64_t CountNumberOfFields(corpus_type& val) { return inner_.CountNumberOfFields(val); }