diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index b39d3fa10..3f8ab3230 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -743,6 +743,25 @@ func (w *ptrStructWrapper) ScanIndex(i int) any { return w.exportedFields[i].Addr().Interface() } +// ptrStructWrapper implements CompositeIndexScanner for a pointer to a struct. +type ptrStructNameWrapper struct { + s any + exportedFields map[string]reflect.Value +} + +func (w *ptrStructNameWrapper) ScanNull() error { + return fmt.Errorf("cannot scan NULL into %#v", w.s) +} + +func (w *ptrStructNameWrapper) ScanName(n string) any { + value, ok := w.exportedFields[n] + if !ok { + return fmt.Errorf("%#v only has %d public fields - %s is not a field", w.s, len(w.exportedFields), n) + } + + return value.Addr().Interface() +} + type anySliceArrayReflect struct { slice reflect.Value } diff --git a/pgtype/composite.go b/pgtype/composite.go index fb372325b..83aacd52f 100644 --- a/pgtype/composite.go +++ b/pgtype/composite.go @@ -28,6 +28,15 @@ type CompositeIndexScanner interface { ScanIndex(i int) any } +// CompositeNameScanner is a type accessed by name that can be scanned from a PostgreSQL composite. +type CompositeNameScanner interface { + // ScanNull sets the value to SQL NULL. + ScanNull() error + + // ScanName returns a value usable as a scan target + ScanName(n string) any +} + type CompositeCodecField struct { Name string Type *Type @@ -115,11 +124,15 @@ func (c *CompositeCodec) PlanScan(m *Map, oid uint32, format int16, target any) switch target.(type) { case CompositeIndexScanner: return &scanPlanBinaryCompositeToCompositeIndexScanner{cc: c, m: m} + case CompositeNameScanner: + return &scanPlanBinaryCompositeToCompositeNameScanner{cc: c, m: m} } case TextFormatCode: switch target.(type) { case CompositeIndexScanner: return &scanPlanTextCompositeToCompositeIndexScanner{cc: c, m: m} + case CompositeNameScanner: + return &scanPlanTextCompositeToCompositeNameScanner{cc: c, m: m} } } @@ -165,6 +178,45 @@ func (plan *scanPlanBinaryCompositeToCompositeIndexScanner) Scan(src []byte, tar return nil } +type scanPlanBinaryCompositeToCompositeNameScanner struct { + cc *CompositeCodec + m *Map +} + +func (plan *scanPlanBinaryCompositeToCompositeNameScanner) Scan(src []byte, target any) error { + targetScanner := (target).(CompositeNameScanner) + + if src == nil { + return targetScanner.ScanNull() + } + + scanner := NewCompositeBinaryScanner(plan.m, src) + for _, field := range plan.cc.Fields { + if scanner.Next() { + fieldTarget := targetScanner.ScanName(field.Name) + if fieldTarget != nil { + fieldPlan := plan.m.PlanScan(field.Type.OID, BinaryFormatCode, fieldTarget) + if fieldPlan == nil { + return fmt.Errorf("unable to encode %v into OID %d in binary format", field, field.Type.OID) + } + + err := fieldPlan.Scan(scanner.Bytes(), fieldTarget) + if err != nil { + return err + } + } + } else { + return errors.New("read past end of composite") + } + } + + if err := scanner.Err(); err != nil { + return err + } + + return nil +} + type scanPlanTextCompositeToCompositeIndexScanner struct { cc *CompositeCodec m *Map @@ -204,6 +256,45 @@ func (plan *scanPlanTextCompositeToCompositeIndexScanner) Scan(src []byte, targe return nil } +type scanPlanTextCompositeToCompositeNameScanner struct { + cc *CompositeCodec + m *Map +} + +func (plan *scanPlanTextCompositeToCompositeNameScanner) Scan(src []byte, target any) error { + targetScanner := (target).(CompositeNameScanner) + + if src == nil { + return targetScanner.ScanNull() + } + + scanner := NewCompositeTextScanner(plan.m, src) + for _, field := range plan.cc.Fields { + if scanner.Next() { + fieldTarget := targetScanner.ScanName(field.Name) + if fieldTarget != nil { + fieldPlan := plan.m.PlanScan(field.Type.OID, TextFormatCode, fieldTarget) + if fieldPlan == nil { + return fmt.Errorf("unable to encode %v into OID %d in text format", field, field.Type.OID) + } + + err := fieldPlan.Scan(scanner.Bytes(), fieldTarget) + if err != nil { + return err + } + } + } else { + return errors.New("read past end of composite") + } + } + + if err := scanner.Err(); err != nil { + return err + } + + return nil +} + func (c *CompositeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 20645d694..d715e7685 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -963,6 +963,52 @@ func TryWrapStructScanPlan(target any) (plan WrappedScanPlanNextSetter, nextValu return nil, nil, false } +// TryWrapStructFieldNameScanPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter. +func TryWrapStructFieldNameScanPlan(target any) (plan WrappedScanPlanNextSetter, nextValue any, ok bool) { + targetValue := reflect.ValueOf(target) + if targetValue.Kind() != reflect.Ptr { + return nil, nil, false + } + + var targetElemValue reflect.Value + if targetValue.IsNil() { + targetElemValue = reflect.Zero(targetValue.Type().Elem()) + } else { + targetElemValue = targetValue.Elem() + } + targetElemType := targetElemValue.Type() + + if targetElemType.Kind() == reflect.Struct { + exportedFields := getExportedFieldNameValues(targetElemValue) + if len(exportedFields) == 0 { + return nil, nil, false + } + + w := ptrStructNameWrapper{ + s: target, + exportedFields: exportedFields, + } + return &wrapAnyPtrStructFieldNameScanPlan{}, &w, true + } + + return nil, nil, false +} + +type wrapAnyPtrStructFieldNameScanPlan struct { + next ScanPlan +} + +func (plan *wrapAnyPtrStructFieldNameScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapAnyPtrStructFieldNameScanPlan) Scan(src []byte, target any) error { + w := ptrStructNameWrapper{ + s: target, + exportedFields: getExportedFieldNameValues(reflect.ValueOf(target).Elem()), + } + + return plan.next.Scan(src, &w) +} + type wrapAnyPtrStructScanPlan struct { next ScanPlan } @@ -1842,6 +1888,27 @@ func getExportedFieldValues(structValue reflect.Value) []reflect.Value { return exportedFields } +func getExportedFieldNameValues(structValue reflect.Value) map[string]reflect.Value { + structType := structValue.Type() + exportedFields := make(map[string]reflect.Value, structValue.NumField()) + for i := 0; i < structType.NumField(); i++ { + sf := structType.Field(i) + if sf.IsExported() { + name := sf.Name + value, ok := sf.Tag.Lookup("db") + if ok { + if value == "-" { + continue + } + name = value + } + exportedFields[name] = structValue.Field(i) + } + } + + return exportedFields +} + func TryWrapSliceEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { if _, ok := value.(driver.Valuer); ok { return nil, nil, false