From bc2f7cc3a8f0d8fbaea1d49930c390b77b38e968 Mon Sep 17 00:00:00 2001 From: Ryan Belgrave Date: Wed, 5 Mar 2025 15:35:12 -0600 Subject: [PATCH] don't limit the number of records in a record batch when decoding created `getArrayLengthNoLimit` as record batches can contain more than `2*math.MaxUint16` records. The packet decoder will make sure that the array length isn't greater than the number of bytes remaining to be decoded in the packet. Also added a test for large record counts. Fixes: #3119 Signed-off-by: Ryan Belgrave --- packet_decoder.go | 1 + real_decoder.go | 14 ++++++++++++++ record_batch.go | 6 +++++- record_test.go | 38 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 58 insertions(+), 1 deletion(-) diff --git a/packet_decoder.go b/packet_decoder.go index 526e0f42f..a459bc79a 100644 --- a/packet_decoder.go +++ b/packet_decoder.go @@ -15,6 +15,7 @@ type packetDecoder interface { getUVarint() (uint64, error) getFloat64() (float64, error) getArrayLength() (int, error) + getArrayLengthNoLimit() (int, error) getCompactArrayLength() (int, error) getBool() (bool, error) getEmptyTaggedFieldArray() (int, error) diff --git a/real_decoder.go b/real_decoder.go index 7e37641f9..12f0d6e74 100644 --- a/real_decoder.go +++ b/real_decoder.go @@ -121,6 +121,20 @@ func (rd *realDecoder) getArrayLength() (int, error) { return tmp, nil } +func (rd *realDecoder) getArrayLengthNoLimit() (int, error) { + if rd.remaining() < 4 { + rd.off = len(rd.raw) + return -1, ErrInsufficientData + } + tmp := int(int32(binary.BigEndian.Uint32(rd.raw[rd.off:]))) + rd.off += 4 + if tmp > rd.remaining() { + rd.off = len(rd.raw) + return -1, ErrInsufficientData + } + return tmp, nil +} + func (rd *realDecoder) getCompactArrayLength() (int, error) { n, err := rd.getUVarint() if err != nil { diff --git a/record_batch.go b/record_batch.go index c422c5c2f..0100e746a 100644 --- a/record_batch.go +++ b/record_batch.go @@ -157,7 +157,11 @@ func (b *RecordBatch) decode(pd packetDecoder) (err
error) { return err } - numRecs, err := pd.getArrayLength() + // Using NoLimit because a single record batch could contain + // more than 2*math.MaxUint16 records. The packet decoder will + // check to make sure the array is not greater than the + // remaining bytes. + numRecs, err := pd.getArrayLengthNoLimit() if err != nil { return err } diff --git a/record_test.go b/record_test.go index b8c48fcab..2afc265d7 100644 --- a/record_test.go +++ b/record_test.go @@ -3,6 +3,7 @@ package sarama import ( + "fmt" "reflect" "testing" "time" @@ -254,3 +255,40 @@ func TestRecordBatchDecoding(t *testing.T) { } } } + +func TestRecordBatchInvalidNumRecords(t *testing.T) { + encodedBatch := []byte{ + 0, 0, 0, 0, 0, 0, 0, 0, // First Offset + 0, 0, 0, 70, // Length + 0, 0, 0, 0, // Partition Leader Epoch + 2, // Version + 91, 48, 202, 99, // CRC + 0, 0, // Attributes + 0, 0, 0, 0, // Last Offset Delta + 0, 0, 1, 88, 141, 205, 89, 56, // First Timestamp + 0, 0, 0, 0, 0, 0, 0, 0, // Max Timestamp + 0, 0, 0, 0, 0, 0, 0, 0, // Producer ID + 0, 0, // Producer Epoch + 0, 0, 0, 0, // First Sequence + 0, 1, 255, 255, // Number of Records - 1 + 2*math.MaxUint16 + 40, // Record Length + 0, // Attributes + 10, // Timestamp Delta + 0, // Offset Delta + 8, // Key Length + 1, 2, 3, 4, + 6, // Value Length + 5, 6, 7, + 2, // Number of Headers + 6, // Header Key Length + 8, 9, 10, // Header Key + 4, // Header Value Length + 11, 12, // Header Value + } + + batch := RecordBatch{} + err := decode(encodedBatch, &batch, nil) + if err != ErrInsufficientData { + t.Fatal(fmt.Errorf("was supposed to get ErrInsufficientData, instead got: %w", err)) + } +}