Skip to content

Commit

Permalink
codec: fix parsing of optional values, add a special type codec for…
Browse files Browse the repository at this point in the history
… OnRampAddress (#1109)

* codec: Add the ability to customize parsing for specific types

* codec: Properly parse option types
  • Loading branch information
archseer authored Feb 27, 2025
1 parent a21ba8f commit 8fcdaa6
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 4 deletions.
2 changes: 2 additions & 0 deletions pkg/solana/codec/anchoridl.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ type IdlTypeDefTyKind string
const (
IdlTypeDefTyKindStruct IdlTypeDefTyKind = "struct"
IdlTypeDefTyKindEnum IdlTypeDefTyKind = "enum"
IdlTypeDefTyKindCustom IdlTypeDefTyKind = "custom"
)

type IdlTypeDefTyStruct struct {
Expand All @@ -380,6 +381,7 @@ type IdlTypeDefTy struct {

Fields *IdlTypeDefStruct `json:"fields,omitempty"`
Variants IdlEnumVariantSlice `json:"variants,omitempty"`
Codec string `json:"codec,omitempty"`
}

type IdlEnumVariantSlice []IdlEnumVariant
Expand Down
72 changes: 72 additions & 0 deletions pkg/solana/codec/onramp_address.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package codec

import (
"fmt"
"reflect"

"github.com/smartcontractkit/chainlink-common/pkg/codec/encodings"
"github.com/smartcontractkit/chainlink-common/pkg/types"
)

func NewOnRampAddress(builder encodings.Builder) encodings.TypeCodec {
return &onRampAddress{
intEncoder: builder.Uint32(),
}
}

type onRampAddress struct {
intEncoder encodings.TypeCodec
}

var _ encodings.TypeCodec = &onRampAddress{}

func (d *onRampAddress) Encode(value any, into []byte) ([]byte, error) {
bi, ok := value.([]byte)
if !ok {
return nil, fmt.Errorf("%w: expected []byte, got %T", types.ErrInvalidType, value)
}

length := len(bi)
if length > 64 {
return nil, fmt.Errorf("%w: expected []byte to be 64 bytes or less, got %v", types.ErrInvalidType, length)
}
// assert 64 bytes or less
var buf [64]byte
copy(buf[:], bi)

// 64 bytes, padded, then len u32
into = append(into, buf[:]...)
return d.intEncoder.Encode(uint32(length), into)
}

func (d *onRampAddress) Decode(encoded []byte) (any, []byte, error) {
buf := encoded[0:64]
encoded = encoded[64:]

// decode uint32 len
l, bytes, err := d.intEncoder.Decode(encoded)
if err != nil {
return nil, bytes, err
}

length, ok := l.(uint32)
if !ok {
return nil, bytes, fmt.Errorf("expected uint32, got %T", l)
}

return buf[:length], bytes, nil
}

func (d *onRampAddress) GetType() reflect.Type {
return reflect.TypeOf([]byte{})
}

func (d *onRampAddress) Size(val int) (int, error) {
// 64 bytes + uint32
return 64 + 4, nil
}

func (d *onRampAddress) FixedSize() (int, error) {
// 64 bytes + uint32
return 64 + 4, nil
}
59 changes: 59 additions & 0 deletions pkg/solana/codec/option.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package codec

import (
"fmt"
"reflect"

"github.com/smartcontractkit/chainlink-common/pkg/codec/encodings"
commonencodings "github.com/smartcontractkit/chainlink-common/pkg/codec/encodings"
)

func NewOption(codec commonencodings.TypeCodec) encodings.TypeCodec {
return &option{
codec,
}
}

type option struct {
codec encodings.TypeCodec
}

var _ encodings.TypeCodec = &option{}

func (d *option) Encode(value any, into []byte) ([]byte, error) {
// encoding is either 0 for None, or 1, bytes... for Some(val)
if value == nil {
return append(into, 0), nil
}

into = append(into, 1)
return d.codec.Encode(value, into)
}

func (d *option) Decode(encoded []byte) (any, []byte, error) {
prefix := encoded[0]
bytes := encoded[1:]

// encoding is either 0 for None, or 1, bytes... for Some(val)
if prefix == 0 {
return reflect.Zero(d.codec.GetType()).Interface(), encoded[1:], nil
}

if prefix != 1 {
return nil, encoded, fmt.Errorf("expected either 0 or 1, got %v", prefix)
}

return d.codec.Decode(bytes)
}

func (d *option) GetType() reflect.Type {
return d.codec.GetType()
}

func (d *option) Size(val int) (int, error) {
return d.codec.Size(val)
}

func (d *option) FixedSize() (int, error) {
return d.codec.FixedSize()
}
11 changes: 9 additions & 2 deletions pkg/solana/codec/solana.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,13 @@ func createCodecType(
return name, nil, fmt.Errorf("%w: variants are not supported", commontypes.ErrInvalidConfig)
}
return name, refs.builder.Uint8(), nil
case IdlTypeDefTyKindCustom:
switch def.Type.Codec {
case "onramp_address":
return name, NewOnRampAddress(refs.builder), nil
default:
return name, nil, fmt.Errorf(unknownIDLFormat, commontypes.ErrInvalidConfig, def.Type.Codec)
}
default:
return name, nil, fmt.Errorf(unknownIDLFormat, commontypes.ErrInvalidConfig, def.Type.Kind)
}
Expand Down Expand Up @@ -301,8 +308,8 @@ func processFieldType(parentTypeName string, idlType IdlType, refs *codecRefs) (
return getCodecByStringType(idlType.GetString(), refs.builder)
case idlType.IsIdlTypeOption():
// Go doesn't have an `Option` type; use pointer to type instead
// this should be automatic in the codec
return processFieldType(parentTypeName, idlType.GetIdlTypeOption().Option, refs)
inner, err := processFieldType(parentTypeName, idlType.GetIdlTypeOption().Option, refs)
return NewOption(inner), err
case idlType.IsIdlTypeDefined():
return asDefined(parentTypeName, idlType.GetIdlTypeDefined(), refs)
case idlType.IsArray():
Expand Down
4 changes: 2 additions & 2 deletions pkg/solana/codec/solana_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestNewIDLAccountCodec(t *testing.T) {
bts, err := entry.Encode(ctx, expected, testutils.TestStructWithNestedStruct)

// length of fields + discriminator
require.Equal(t, 262, len(bts))
require.Equal(t, 263, len(bts))
require.NoError(t, err)

var decoded testutils.StructWithNestedStruct
Expand Down Expand Up @@ -73,7 +73,7 @@ func TestNewIDLDefinedTypesCodecCodec(t *testing.T) {
bts, err := entry.Encode(ctx, expected, testutils.TestStructWithNestedStructType)

// length of fields without a discriminator
require.Equal(t, 254, len(bts))
require.Equal(t, 255, len(bts))

require.NoError(t, err)

Expand Down

0 comments on commit 8fcdaa6

Please sign in to comment.