Skip to content

Commit

Permalink
Keep track of field count in the corpus_type for protobuf domains.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 715013770
  • Loading branch information
hadi88 authored and copybara-github committed Jan 13, 2025
1 parent 174c1ff commit 2135807
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 9 deletions.
2 changes: 1 addition & 1 deletion fuzztest/internal/domains/container_of_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ class SequenceContainerOfImplBase
return val;
}

uint64_t CountNumberOfFields(const corpus_type& val) {
uint64_t CountNumberOfFields(corpus_type& val) {
uint64_t total_weight = 0;
for (auto& i : val) {
total_weight += this->inner_.CountNumberOfFields(i);
Expand Down
5 changes: 3 additions & 2 deletions fuzztest/internal/domains/domain.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,12 @@ class Domain {

// Return the field counts of `corpus_value` if `corpus_value` is
// a `ProtobufDomainImpl::corpus_type`. Otherwise propagate it
// to inner domains and returns the sum of inner results.
// to inner domains and returns the sum of inner results. The corpus value is
// taken as mutable reference to allow memoization.
//
// TODO(b/303324603): Using an extension mechanism, expose this method in
// the interface only for user value types `T` for which it makes sense.
uint64_t CountNumberOfFields(const corpus_type& corpus_value) {
uint64_t CountNumberOfFields(corpus_type& corpus_value) {
return inner_->UntypedCountNumberOfFields(corpus_value);
}

Expand Down
5 changes: 2 additions & 3 deletions fuzztest/internal/domains/domain_type_erasure.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ class UntypedDomainConcept {
const GenericDomainCorpusType& corpus_value) const = 0;
virtual IRObject UntypedSerializeCorpus(
const GenericDomainCorpusType& v) const = 0;
virtual uint64_t UntypedCountNumberOfFields(
const GenericDomainCorpusType&) = 0;
virtual uint64_t UntypedCountNumberOfFields(GenericDomainCorpusType&) = 0;
virtual uint64_t UntypedMutateSelectedField(
GenericDomainCorpusType&, absl::BitGenRef,
const domain_implementor::MutationMetadata&, bool, uint64_t) = 0;
Expand Down Expand Up @@ -188,7 +187,7 @@ class DomainModel final : public TypedDomainConcept<value_type_t<D>> {
return domain_.ValidateCorpusValue(corpus_value.GetAs<CorpusType>());
}

uint64_t UntypedCountNumberOfFields(const GenericDomainCorpusType& v) final {
uint64_t UntypedCountNumberOfFields(GenericDomainCorpusType& v) final {
return domain_.CountNumberOfFields(v.GetAs<CorpusType>());
}

Expand Down
2 changes: 1 addition & 1 deletion fuzztest/internal/domains/optional_of_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class OptionalOfImpl
return *this;
}

uint64_t CountNumberOfFields(const corpus_type& val) {
uint64_t CountNumberOfFields(corpus_type& val) {
if (val.index() == 1) {
return inner_.CountNumberOfFields(std::get<1>(val));
}
Expand Down
22 changes: 20 additions & 2 deletions fuzztest/internal/domains/protobuf_domain_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@ class ProtobufDomainUntypedImpl
value_type out(prototype_.Get()->New());

for (auto& [number, data] : value) {
if (IsMetadataEntry(number)) continue;
auto* field = GetField(number);
VisitProtobufField(field, GetValueVisitor{*out, *this, data});
}
Expand Down Expand Up @@ -573,6 +574,7 @@ class ProtobufDomainUntypedImpl
auto& subs = out.MutableSubs();
subs.reserve(v.size());
for (auto& [number, inner] : v) {
if (IsMetadataEntry(number)) continue;
auto* field = GetField(number);
FUZZTEST_INTERNAL_CHECK(field, "Field not found by number: ", number);
IRObject& pair = subs.emplace_back();
Expand All @@ -585,7 +587,10 @@ class ProtobufDomainUntypedImpl
return out;
}

uint64_t CountNumberOfFields(const corpus_type& val) {
uint64_t CountNumberOfFields(corpus_type& val) {
if (auto it = val.find(kFieldCountIndex); it != val.end()) {
return it->second.template GetAs<uint64_t>();
}
uint64_t total_weight = 0;
auto descriptor = prototype_.Get()->GetDescriptor();
if (GetFieldCount(descriptor) == 0) return total_weight;
Expand All @@ -611,6 +616,8 @@ class ProtobufDomainUntypedImpl
}
}
}
val[kFieldCountIndex] =
GenericDomainCorpusType(std::in_place_type<uint64_t>, total_weight);
return total_weight;
}

Expand All @@ -621,6 +628,9 @@ class ProtobufDomainUntypedImpl
uint64_t field_counter = 0;
auto descriptor = prototype_.Get()->GetDescriptor();
if (GetFieldCount(descriptor) == 0) return field_counter;
int64_t fields_count = CountNumberOfFields(val);
if (fields_count < selected_field_index) return fields_count;
val.erase(kFieldCountIndex); // Mutation invalidates the cache value.

for (const FieldDescriptor* field : GetProtobufFields(descriptor)) {
if (field->containing_oneof() &&
Expand Down Expand Up @@ -1695,6 +1705,14 @@ class ProtobufDomainUntypedImpl
return result;
}

// corpus_type is a map from field number to values. number -1 is reserved for
// storing the field count.
static constexpr int64_t kFieldCountIndex = -1;

static bool IsMetadataEntry(int64_t index) {
return index == kFieldCountIndex;
}

bool IsOneofRecursive(const OneofDescriptor* oneof,
absl::flat_hash_set<const Descriptor*>& parents,
bool consider_non_terminating_recursions) const {
Expand Down Expand Up @@ -1828,7 +1846,7 @@ class ProtobufDomainImpl
return inner_.Init(prng);
}

uint64_t CountNumberOfFields(const corpus_type& val) {
uint64_t CountNumberOfFields(corpus_type& val) {
return inner_.CountNumberOfFields(val);
}

Expand Down

0 comments on commit 2135807

Please sign in to comment.