Skip to content

Commit

Permalink
don't limit the number of records in a record batch when decoding
Browse files Browse the repository at this point in the history
created `getArrayLengthNoLimit` as record batches can contain more
than `2*math.MaxUint16 records`. The packet decoder will make sure
that the array length isn't greater than the number of bytes remaining
to be decoding in the packet.

Also added a test for large record counts.

Fixes: #3119
Signed-off-by: Ryan Belgrave <[email protected]>
  • Loading branch information
rmb938 committed Mar 5, 2025
1 parent c7ca87e commit bc2f7cc
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 1 deletion.
1 change: 1 addition & 0 deletions packet_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type packetDecoder interface {
getUVarint() (uint64, error)
getFloat64() (float64, error)
getArrayLength() (int, error)
getArrayLengthNoLimit() (int, error)
getCompactArrayLength() (int, error)
getBool() (bool, error)
getEmptyTaggedFieldArray() (int, error)
Expand Down
14 changes: 14 additions & 0 deletions real_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,20 @@ func (rd *realDecoder) getArrayLength() (int, error) {
return tmp, nil
}

func (rd *realDecoder) getArrayLengthNoLimit() (int, error) {
if rd.remaining() < 4 {
rd.off = len(rd.raw)
return -1, ErrInsufficientData
}
tmp := int(int32(binary.BigEndian.Uint32(rd.raw[rd.off:])))
rd.off += 4
if tmp > rd.remaining() {
rd.off = len(rd.raw)
return -1, ErrInsufficientData
}
return tmp, nil
}

func (rd *realDecoder) getCompactArrayLength() (int, error) {
n, err := rd.getUVarint()
if err != nil {
Expand Down
6 changes: 5 additions & 1 deletion record_batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,11 @@ func (b *RecordBatch) decode(pd packetDecoder) (err error) {
return err
}

numRecs, err := pd.getArrayLength()
// Using NoLimit because a single record batch could contain
// more then 2*math.MaxUint16 records. The packet decoder will
// check to make sure the array is not greater than the
// remaining bytes.
numRecs, err := pd.getArrayLengthNoLimit()
if err != nil {
return err
}
Expand Down
38 changes: 38 additions & 0 deletions record_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package sarama

import (
"fmt"
"reflect"
"testing"
"time"
Expand Down Expand Up @@ -254,3 +255,40 @@ func TestRecordBatchDecoding(t *testing.T) {
}
}
}

func TestRecordBatchInvalidNumRecords(t *testing.T) {
encodedBatch := []byte{
0, 0, 0, 0, 0, 0, 0, 0, // First Offset
0, 0, 0, 70, // Length
0, 0, 0, 0, // Partition Leader Epoch
2, // Version
91, 48, 202, 99, // CRC
0, 0, // Attributes
0, 0, 0, 0, // Last Offset Delta
0, 0, 1, 88, 141, 205, 89, 56, // First Timestamp
0, 0, 0, 0, 0, 0, 0, 0, // Max Timestamp
0, 0, 0, 0, 0, 0, 0, 0, // Producer ID
0, 0, // Producer Epoch
0, 0, 0, 0, // First Sequence
0, 1, 255, 255, // Number of Records - 1 + 2*math.MaxUint16
40, // Record Length
0, // Attributes
10, // Timestamp Delta
0, // Offset Delta
8, // Key Length
1, 2, 3, 4,
6, // Value Length
5, 6, 7,
2, // Number of Headers
6, // Header Key Length
8, 9, 10, // Header Key
4, // Header Value Length
11, 12, // Header Value
}

batch := RecordBatch{}
err := decode(encodedBatch, &batch, nil)
if err != ErrInsufficientData {
t.Fatal(fmt.Errorf("was suppose to get ErrInsufficientData, instead got: %w", err))
}
}

0 comments on commit bc2f7cc

Please sign in to comment.