Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for field name based composite type scanning #2230

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions pgtype/builtin_wrappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,25 @@ func (w *ptrStructWrapper) ScanIndex(i int) any {
return w.exportedFields[i].Addr().Interface()
}

// ptrStructWrapper implements CompositeIndexScanner for a pointer to a struct.
type ptrStructNameWrapper struct {
s any
exportedFields map[string]reflect.Value
}

func (w *ptrStructNameWrapper) ScanNull() error {
return fmt.Errorf("cannot scan NULL into %#v", w.s)
}

func (w *ptrStructNameWrapper) ScanName(n string) any {
value, ok := w.exportedFields[n]
if !ok {
return fmt.Errorf("%#v only has %d public fields - %s is not a field", w.s, len(w.exportedFields), n)
}

return value.Addr().Interface()
}

type anySliceArrayReflect struct {
slice reflect.Value
}
Expand Down
91 changes: 91 additions & 0 deletions pgtype/composite.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ type CompositeIndexScanner interface {
ScanIndex(i int) any
}

// CompositeNameScanner is a type accessed by name that can be scanned from a PostgreSQL composite.
type CompositeNameScanner interface {
// ScanNull sets the value to SQL NULL.
ScanNull() error

// ScanName returns a value usable as a scan target
ScanName(n string) any
}

type CompositeCodecField struct {
Name string
Type *Type
Expand Down Expand Up @@ -115,11 +124,15 @@ func (c *CompositeCodec) PlanScan(m *Map, oid uint32, format int16, target any)
switch target.(type) {
case CompositeIndexScanner:
return &scanPlanBinaryCompositeToCompositeIndexScanner{cc: c, m: m}
case CompositeNameScanner:
return &scanPlanBinaryCompositeToCompositeNameScanner{cc: c, m: m}
}
case TextFormatCode:
switch target.(type) {
case CompositeIndexScanner:
return &scanPlanTextCompositeToCompositeIndexScanner{cc: c, m: m}
case CompositeNameScanner:
return &scanPlanTextCompositeToCompositeNameScanner{cc: c, m: m}
}
}

Expand Down Expand Up @@ -165,6 +178,45 @@ func (plan *scanPlanBinaryCompositeToCompositeIndexScanner) Scan(src []byte, tar
return nil
}

type scanPlanBinaryCompositeToCompositeNameScanner struct {
cc *CompositeCodec
m *Map
}

func (plan *scanPlanBinaryCompositeToCompositeNameScanner) Scan(src []byte, target any) error {
targetScanner := (target).(CompositeNameScanner)

if src == nil {
return targetScanner.ScanNull()
}

scanner := NewCompositeBinaryScanner(plan.m, src)
for _, field := range plan.cc.Fields {
if scanner.Next() {
fieldTarget := targetScanner.ScanName(field.Name)
if fieldTarget != nil {
fieldPlan := plan.m.PlanScan(field.Type.OID, BinaryFormatCode, fieldTarget)
if fieldPlan == nil {
return fmt.Errorf("unable to encode %v into OID %d in binary format", field, field.Type.OID)
}

err := fieldPlan.Scan(scanner.Bytes(), fieldTarget)
if err != nil {
return err
}
}
} else {
return errors.New("read past end of composite")
}
}

if err := scanner.Err(); err != nil {
return err
}

return nil
}

type scanPlanTextCompositeToCompositeIndexScanner struct {
cc *CompositeCodec
m *Map
Expand Down Expand Up @@ -204,6 +256,45 @@ func (plan *scanPlanTextCompositeToCompositeIndexScanner) Scan(src []byte, targe
return nil
}

type scanPlanTextCompositeToCompositeNameScanner struct {
cc *CompositeCodec
m *Map
}

func (plan *scanPlanTextCompositeToCompositeNameScanner) Scan(src []byte, target any) error {
targetScanner := (target).(CompositeNameScanner)

if src == nil {
return targetScanner.ScanNull()
}

scanner := NewCompositeTextScanner(plan.m, src)
for _, field := range plan.cc.Fields {
if scanner.Next() {
fieldTarget := targetScanner.ScanName(field.Name)
if fieldTarget != nil {
fieldPlan := plan.m.PlanScan(field.Type.OID, TextFormatCode, fieldTarget)
if fieldPlan == nil {
return fmt.Errorf("unable to encode %v into OID %d in text format", field, field.Type.OID)
}

err := fieldPlan.Scan(scanner.Bytes(), fieldTarget)
if err != nil {
return err
}
}
} else {
return errors.New("read past end of composite")
}
}

if err := scanner.Err(); err != nil {
return err
}

return nil
}

func (c *CompositeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
if src == nil {
return nil, nil
Expand Down
67 changes: 67 additions & 0 deletions pgtype/pgtype.go
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,52 @@ func TryWrapStructScanPlan(target any) (plan WrappedScanPlanNextSetter, nextValu
return nil, nil, false
}

// TryWrapStructFieldNameScanPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter.
func TryWrapStructFieldNameScanPlan(target any) (plan WrappedScanPlanNextSetter, nextValue any, ok bool) {
targetValue := reflect.ValueOf(target)
if targetValue.Kind() != reflect.Ptr {
return nil, nil, false
}

var targetElemValue reflect.Value
if targetValue.IsNil() {
targetElemValue = reflect.Zero(targetValue.Type().Elem())
} else {
targetElemValue = targetValue.Elem()
}
targetElemType := targetElemValue.Type()

if targetElemType.Kind() == reflect.Struct {
exportedFields := getExportedFieldNameValues(targetElemValue)
if len(exportedFields) == 0 {
return nil, nil, false
}

w := ptrStructNameWrapper{
s: target,
exportedFields: exportedFields,
}
return &wrapAnyPtrStructFieldNameScanPlan{}, &w, true
}

return nil, nil, false
}

type wrapAnyPtrStructFieldNameScanPlan struct {
next ScanPlan
}

func (plan *wrapAnyPtrStructFieldNameScanPlan) SetNext(next ScanPlan) { plan.next = next }

func (plan *wrapAnyPtrStructFieldNameScanPlan) Scan(src []byte, target any) error {
w := ptrStructNameWrapper{
s: target,
exportedFields: getExportedFieldNameValues(reflect.ValueOf(target).Elem()),
}

return plan.next.Scan(src, &w)
}

type wrapAnyPtrStructScanPlan struct {
next ScanPlan
}
Expand Down Expand Up @@ -1842,6 +1888,27 @@ func getExportedFieldValues(structValue reflect.Value) []reflect.Value {
return exportedFields
}

func getExportedFieldNameValues(structValue reflect.Value) map[string]reflect.Value {
structType := structValue.Type()
exportedFields := make(map[string]reflect.Value, structValue.NumField())
for i := 0; i < structType.NumField(); i++ {
sf := structType.Field(i)
if sf.IsExported() {
name := sf.Name
value, ok := sf.Tag.Lookup("db")
if ok {
if value == "-" {
continue
}
name = value
}
exportedFields[name] = structValue.Field(i)
}
}

return exportedFields
}

func TryWrapSliceEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) {
if _, ok := value.(driver.Valuer); ok {
return nil, nil, false
Expand Down