mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2024-12-29 10:36:31 +00:00
852 lines
24 KiB
Go
852 lines
24 KiB
Go
package pgx
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5/pgconn"
|
|
"github.com/jackc/pgx/v5/pgtype"
|
|
)
|
|
|
|
// Rows is the result set returned from *Conn.Query. Rows must be closed before
// the *Conn can be used again. Rows are closed by explicitly calling Close(),
// calling Next() until it returns false, or when a fatal error occurs.
//
// Once a Rows is closed the only methods that may be called are Close(), Err(),
// and CommandTag().
//
// Rows is an interface instead of a struct to allow tests to mock Query. However,
// adding a method to an interface is technically a breaking change. Because of this
// the Rows interface is partially excluded from semantic version requirements.
// Methods will not be removed or changed, but new methods may be added.
type Rows interface {
	// Close closes the rows, making the connection ready for use again. It is safe
	// to call Close after the rows are already closed.
	Close()

	// Err returns any error that occurred while reading. Err must only be called after the Rows is closed (either by
	// calling Close or by Next returning false). If it is called early it may return nil even if there was an error
	// executing the query.
	Err() error

	// CommandTag returns the command tag from this query. It is only available after Rows is closed.
	CommandTag() pgconn.CommandTag

	// FieldDescriptions returns the field descriptions of the columns. It may return nil. In particular this can occur
	// when there was an error executing the query.
	FieldDescriptions() []pgconn.FieldDescription

	// Next prepares the next row for reading. It returns true if there is another
	// row and false if no more rows are available or a fatal error has occurred.
	// It automatically closes rows when all rows are read.
	//
	// Callers should check rows.Err() after rows.Next() returns false to detect
	// whether result-set reading ended prematurely due to an error. See
	// Conn.Query for details.
	//
	// For simpler error handling, consider using the higher-level pgx v5
	// CollectRows() and ForEachRow() helpers instead.
	Next() bool

	// Scan reads the values from the current row into dest values positionally.
	// dest can include pointers to core types, values implementing the Scanner
	// interface, and nil. nil will skip the value entirely. It is an error to
	// call Scan without first calling Next() and checking that it returned true.
	Scan(dest ...any) error

	// Values returns the decoded row values. As with Scan(), it is an error to
	// call Values without first calling Next() and checking that it returned
	// true.
	Values() ([]any, error)

	// RawValues returns the unparsed bytes of the row values. The returned data is only valid until the next Next
	// call or the Rows is closed.
	RawValues() [][]byte

	// Conn returns the underlying *Conn on which the query was executed. This may return nil if Rows did not come from a
	// *Conn (e.g. if it was created by RowsFromResultReader).
	Conn() *Conn
}
|
|
|
|
// Row is a convenience wrapper over Rows that is returned by QueryRow.
//
// Row is an interface instead of a struct to allow tests to mock QueryRow. However,
// adding a method to an interface is technically a breaking change. Because of this
// the Row interface is partially excluded from semantic version requirements.
// Methods will not be removed or changed, but new methods may be added.
type Row interface {
	// Scan works the same as Rows.Scan with the following exceptions. If no
	// rows were found it returns ErrNoRows. If multiple rows are returned it
	// ignores all but the first.
	Scan(dest ...any) error
}
|
|
|
|
// RowScanner scans an entire row at a time into the RowScanner.
//
// When a RowScanner is the only destination passed to (*baseRows).Scan, Scan
// hands the whole Rows to ScanRow instead of decoding column by column.
type RowScanner interface {
	// ScanRow scans the current row of rows into the receiver.
	ScanRow(rows Rows) error
}
|
|
|
|
// connRow implements the Row interface for Conn.QueryRow. It is declared as
// baseRows so its methods can convert the receiver to *baseRows and reuse the
// Rows implementation.
type connRow baseRows
|
|
|
|
func (r *connRow) Scan(dest ...any) (err error) {
|
|
rows := (*baseRows)(r)
|
|
|
|
if rows.Err() != nil {
|
|
return rows.Err()
|
|
}
|
|
|
|
for _, d := range dest {
|
|
if _, ok := d.(*pgtype.DriverBytes); ok {
|
|
rows.Close()
|
|
return fmt.Errorf("cannot scan into *pgtype.DriverBytes from QueryRow")
|
|
}
|
|
}
|
|
|
|
if !rows.Next() {
|
|
if rows.Err() == nil {
|
|
return ErrNoRows
|
|
}
|
|
return rows.Err()
|
|
}
|
|
|
|
rows.Scan(dest...)
|
|
rows.Close()
|
|
return rows.Err()
|
|
}
|
|
|
|
// baseRows implements the Rows interface for Conn.Query.
type baseRows struct {
	typeMap      *pgtype.Map          // decodes column values (scan plans, codecs)
	resultReader *pgconn.ResultReader // source of rows; Close checks it for nil before use

	values [][]byte // raw (undecoded) values of the current row, set by Next

	commandTag pgconn.CommandTag // captured from resultReader by Close
	err        error             // first error recorded; fatal and Close never overwrite it
	closed     bool              // set by Close; later Close calls return immediately

	scanPlans []pgtype.ScanPlan // per-column scan plans cached by Scan
	scanTypes []reflect.Type    // destination type each cached plan was built for

	conn              *Conn                     // nil when built by RowsFromResultReader
	multiResultReader *pgconn.MultiResultReader // closed by Close when non-nil

	queryTracer QueryTracer     // notified by Close when batchTracer is nil
	batchTracer BatchTracer     // takes precedence over queryTracer in Close
	ctx         context.Context // passed to the tracer callbacks
	startTime   time.Time       // NOTE(review): not read in this file; presumably the query start time — confirm
	sql         string          // reported to tracers; also keys cache invalidation on error
	args        []any           // reported to the batch tracer
	rowCount    int             // incremented by Next for every row returned
}
|
|
|
|
func (rows *baseRows) FieldDescriptions() []pgconn.FieldDescription {
|
|
return rows.resultReader.FieldDescriptions()
|
|
}
|
|
|
|
// Close closes the rows, making the connection ready for use again. It is safe
// to call more than once: calls after the first return immediately.
//
// Closing captures the command tag from the result reader and records any close
// error unless an earlier error is already recorded. If the rows ended with an
// error, it invalidates this SQL in the connection's caches. Finally it notifies
// the configured tracer.
func (rows *baseRows) Close() {
	if rows.closed {
		return
	}

	rows.closed = true

	// Closing the reader yields the command tag; its error only becomes the rows'
	// error if nothing was recorded before.
	if rows.resultReader != nil {
		var closeErr error
		rows.commandTag, closeErr = rows.resultReader.Close()
		if rows.err == nil {
			rows.err = closeErr
		}
	}

	if rows.multiResultReader != nil {
		closeErr := rows.multiResultReader.Close()
		if rows.err == nil {
			rows.err = closeErr
		}
	}

	// On any error, drop this SQL from the connection's statement and description
	// caches. NOTE(review): presumably so a stale cached statement (e.g. after a
	// schema change) is prepared afresh next time — confirm.
	if rows.err != nil && rows.conn != nil && rows.sql != "" {
		if sc := rows.conn.statementCache; sc != nil {
			sc.Invalidate(rows.sql)
		}

		if sc := rows.conn.descriptionCache; sc != nil {
			sc.Invalidate(rows.sql)
		}
	}

	// Report completion to at most one tracer; the batch tracer wins when set.
	if rows.batchTracer != nil {
		rows.batchTracer.TraceBatchQuery(rows.ctx, rows.conn, TraceBatchQueryData{SQL: rows.sql, Args: rows.args, CommandTag: rows.commandTag, Err: rows.err})
	} else if rows.queryTracer != nil {
		rows.queryTracer.TraceQueryEnd(rows.ctx, rows.conn, TraceQueryEndData{rows.commandTag, rows.err})
	}
}
|
|
|
|
func (rows *baseRows) CommandTag() pgconn.CommandTag {
|
|
return rows.commandTag
|
|
}
|
|
|
|
func (rows *baseRows) Err() error {
|
|
return rows.err
|
|
}
|
|
|
|
// fatal signals an error occurred after the query was sent to the server. It
|
|
// closes the rows automatically.
|
|
func (rows *baseRows) fatal(err error) {
|
|
if rows.err != nil {
|
|
return
|
|
}
|
|
|
|
rows.err = err
|
|
rows.Close()
|
|
}
|
|
|
|
func (rows *baseRows) Next() bool {
|
|
if rows.closed {
|
|
return false
|
|
}
|
|
|
|
if rows.resultReader.NextRow() {
|
|
rows.rowCount++
|
|
rows.values = rows.resultReader.Values()
|
|
return true
|
|
} else {
|
|
rows.Close()
|
|
return false
|
|
}
|
|
}
|
|
|
|
// Scan decodes the current row's values into dest, positionally.
//
// Special cases:
//   - If dest has exactly one element and it implements RowScanner, the whole
//     row is handed to its ScanRow method instead of column-wise decoding.
//   - A nil element of dest skips that column.
//
// Scan plans are cached per column on first use and rebuilt only when the Go
// type of a destination changes between calls. Every error is fatal: it is
// recorded via rows.fatal (which closes the rows) and also returned.
func (rows *baseRows) Scan(dest ...any) error {
	m := rows.typeMap
	fieldDescriptions := rows.FieldDescriptions()
	values := rows.values

	if len(fieldDescriptions) != len(values) {
		err := fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values))
		rows.fatal(err)
		return err
	}

	// A lone RowScanner takes over decoding of the entire row.
	if len(dest) == 1 {
		if rc, ok := dest[0].(RowScanner); ok {
			err := rc.ScanRow(rows)
			if err != nil {
				rows.fatal(err)
			}
			return err
		}
	}

	if len(fieldDescriptions) != len(dest) {
		err := fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest))
		rows.fatal(err)
		return err
	}

	// First Scan on these rows: build the per-column plan cache. A nil dest[i]
	// records a nil type, so that column is re-planned once a non-nil
	// destination is seen below.
	if rows.scanPlans == nil {
		rows.scanPlans = make([]pgtype.ScanPlan, len(values))
		rows.scanTypes = make([]reflect.Type, len(values))
		for i := range dest {
			rows.scanPlans[i] = m.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i])
			rows.scanTypes[i] = reflect.TypeOf(dest[i])
		}
	}

	for i, dst := range dest {
		if dst == nil {
			continue
		}

		// The destination's type differs from the one the cached plan was built
		// for, so replace the plan.
		if rows.scanTypes[i] != reflect.TypeOf(dst) {
			rows.scanPlans[i] = m.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i])
			rows.scanTypes[i] = reflect.TypeOf(dest[i])
		}

		err := rows.scanPlans[i].Scan(values[i], dst)
		if err != nil {
			err = ScanArgError{ColumnIndex: i, Err: err}
			rows.fatal(err)
			return err
		}
	}

	return nil
}
|
|
|
|
func (rows *baseRows) Values() ([]any, error) {
|
|
if rows.closed {
|
|
return nil, errors.New("rows is closed")
|
|
}
|
|
|
|
values := make([]any, 0, len(rows.FieldDescriptions()))
|
|
|
|
for i := range rows.FieldDescriptions() {
|
|
buf := rows.values[i]
|
|
fd := &rows.FieldDescriptions()[i]
|
|
|
|
if buf == nil {
|
|
values = append(values, nil)
|
|
continue
|
|
}
|
|
|
|
if dt, ok := rows.typeMap.TypeForOID(fd.DataTypeOID); ok {
|
|
value, err := dt.Codec.DecodeValue(rows.typeMap, fd.DataTypeOID, fd.Format, buf)
|
|
if err != nil {
|
|
rows.fatal(err)
|
|
}
|
|
values = append(values, value)
|
|
} else {
|
|
switch fd.Format {
|
|
case TextFormatCode:
|
|
values = append(values, string(buf))
|
|
case BinaryFormatCode:
|
|
newBuf := make([]byte, len(buf))
|
|
copy(newBuf, buf)
|
|
values = append(values, newBuf)
|
|
default:
|
|
rows.fatal(errors.New("unknown format code"))
|
|
}
|
|
}
|
|
|
|
if rows.Err() != nil {
|
|
return nil, rows.Err()
|
|
}
|
|
}
|
|
|
|
return values, rows.Err()
|
|
}
|
|
|
|
func (rows *baseRows) RawValues() [][]byte {
|
|
return rows.values
|
|
}
|
|
|
|
func (rows *baseRows) Conn() *Conn {
|
|
return rows.conn
|
|
}
|
|
|
|
// ScanArgError reports that the value of column ColumnIndex could not be
// scanned into its destination. Err is the underlying cause.
type ScanArgError struct {
	ColumnIndex int
	Err         error
}

// Error formats the failing destination index together with the cause.
func (e ScanArgError) Error() string {
	const format = "can't scan into dest[%d]: %v"
	return fmt.Sprintf(format, e.ColumnIndex, e.Err)
}

// Unwrap exposes the underlying cause to errors.Is and errors.As.
func (e ScanArgError) Unwrap() error {
	return e.Err
}
|
|
|
|
// ScanRow decodes raw row data into dest. It can be used to scan rows read from the lower level pgconn interface.
|
|
//
|
|
// typeMap - OID to Go type mapping.
|
|
// fieldDescriptions - OID and format of values
|
|
// values - the raw data as returned from the PostgreSQL server
|
|
// dest - the destination that values will be decoded into
|
|
func ScanRow(typeMap *pgtype.Map, fieldDescriptions []pgconn.FieldDescription, values [][]byte, dest ...any) error {
|
|
if len(fieldDescriptions) != len(values) {
|
|
return fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values))
|
|
}
|
|
if len(fieldDescriptions) != len(dest) {
|
|
return fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest))
|
|
}
|
|
|
|
for i, d := range dest {
|
|
if d == nil {
|
|
continue
|
|
}
|
|
|
|
err := typeMap.Scan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], d)
|
|
if err != nil {
|
|
return ScanArgError{ColumnIndex: i, Err: err}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// RowsFromResultReader returns a Rows that will read from values resultReader and decode with typeMap. It can be used
|
|
// to read from the lower level pgconn interface.
|
|
func RowsFromResultReader(typeMap *pgtype.Map, resultReader *pgconn.ResultReader) Rows {
|
|
return &baseRows{
|
|
typeMap: typeMap,
|
|
resultReader: resultReader,
|
|
}
|
|
}
|
|
|
|
// ForEachRow iterates through rows. For each row it scans into the elements of scans and calls fn. If any row
|
|
// fails to scan or fn returns an error the query will be aborted and the error will be returned. Rows will be closed
|
|
// when ForEachRow returns.
|
|
func ForEachRow(rows Rows, scans []any, fn func() error) (pgconn.CommandTag, error) {
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
err := rows.Scan(scans...)
|
|
if err != nil {
|
|
return pgconn.CommandTag{}, err
|
|
}
|
|
|
|
err = fn()
|
|
if err != nil {
|
|
return pgconn.CommandTag{}, err
|
|
}
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return pgconn.CommandTag{}, err
|
|
}
|
|
|
|
return rows.CommandTag(), nil
|
|
}
|
|
|
|
// CollectableRow is the subset of Rows methods that a RowToFunc is allowed to call.
// Every Rows satisfies it, which is how the Collect/Append helpers pass their
// Rows straight to fn.
type CollectableRow interface {
	FieldDescriptions() []pgconn.FieldDescription
	Scan(dest ...any) error
	Values() ([]any, error)
	RawValues() [][]byte
}

// RowToFunc is a function that scans or otherwise converts row to a T. The
// AppendRows/Collect* helpers call it once for each row they consume.
type RowToFunc[T any] func(row CollectableRow) (T, error)
|
|
|
|
// AppendRows iterates through rows, calling fn for each row, and appending the results into a slice of T.
|
|
//
|
|
// This function closes the rows automatically on return.
|
|
func AppendRows[T any, S ~[]T](slice S, rows Rows, fn RowToFunc[T]) (S, error) {
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
value, err := fn(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
slice = append(slice, value)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return slice, nil
|
|
}
|
|
|
|
// CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T.
|
|
//
|
|
// This function closes the rows automatically on return.
|
|
func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) {
|
|
return AppendRows([]T{}, rows, fn)
|
|
}
|
|
|
|
// CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true.
|
|
// CollectOneRow is to CollectRows as QueryRow is to Query.
|
|
//
|
|
// This function closes the rows automatically on return.
|
|
func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
|
|
defer rows.Close()
|
|
|
|
var value T
|
|
var err error
|
|
|
|
if !rows.Next() {
|
|
if err = rows.Err(); err != nil {
|
|
return value, err
|
|
}
|
|
return value, ErrNoRows
|
|
}
|
|
|
|
value, err = fn(rows)
|
|
if err != nil {
|
|
return value, err
|
|
}
|
|
|
|
rows.Close()
|
|
return value, rows.Err()
|
|
}
|
|
|
|
// CollectExactlyOneRow calls fn for the first row in rows and returns the result.
|
|
// - If no rows are found returns an error where errors.Is(ErrNoRows) is true.
|
|
// - If more than 1 row is found returns an error where errors.Is(ErrTooManyRows) is true.
|
|
//
|
|
// This function closes the rows automatically on return.
|
|
func CollectExactlyOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
|
|
defer rows.Close()
|
|
|
|
var (
|
|
err error
|
|
value T
|
|
)
|
|
|
|
if !rows.Next() {
|
|
if err = rows.Err(); err != nil {
|
|
return value, err
|
|
}
|
|
|
|
return value, ErrNoRows
|
|
}
|
|
|
|
value, err = fn(rows)
|
|
if err != nil {
|
|
return value, err
|
|
}
|
|
|
|
if rows.Next() {
|
|
var zero T
|
|
|
|
return zero, ErrTooManyRows
|
|
}
|
|
|
|
return value, rows.Err()
|
|
}
|
|
|
|
// RowTo returns a T scanned from row.
|
|
func RowTo[T any](row CollectableRow) (T, error) {
|
|
var value T
|
|
err := row.Scan(&value)
|
|
return value, err
|
|
}
|
|
|
|
// RowTo returns a the address of a T scanned from row.
|
|
func RowToAddrOf[T any](row CollectableRow) (*T, error) {
|
|
var value T
|
|
err := row.Scan(&value)
|
|
return &value, err
|
|
}
|
|
|
|
// RowToMap returns a map scanned from row.
|
|
func RowToMap(row CollectableRow) (map[string]any, error) {
|
|
var value map[string]any
|
|
err := row.Scan((*mapRowScanner)(&value))
|
|
return value, err
|
|
}
|
|
|
|
type mapRowScanner map[string]any
|
|
|
|
func (rs *mapRowScanner) ScanRow(rows Rows) error {
|
|
values, err := rows.Values()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
*rs = make(mapRowScanner, len(values))
|
|
|
|
for i := range values {
|
|
(*rs)[string(rows.FieldDescriptions()[i].Name)] = values[i]
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number a public fields as row
|
|
// has fields. The row and T fields will be matched by position. If the "db" struct tag is "-" then the field will be
|
|
// ignored.
|
|
func RowToStructByPos[T any](row CollectableRow) (T, error) {
|
|
var value T
|
|
err := (&positionalStructRowScanner{ptrToStruct: &value}).ScanRow(row)
|
|
return value, err
|
|
}
|
|
|
|
// RowToAddrOfStructByPos returns the address of a T scanned from row. T must be a struct. T must have the same number a
|
|
// public fields as row has fields. The row and T fields will be matched by position. If the "db" struct tag is "-" then
|
|
// the field will be ignored.
|
|
func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) {
|
|
var value T
|
|
err := (&positionalStructRowScanner{ptrToStruct: &value}).ScanRow(row)
|
|
return &value, err
|
|
}
|
|
|
|
type positionalStructRowScanner struct {
|
|
ptrToStruct any
|
|
}
|
|
|
|
func (rs *positionalStructRowScanner) ScanRow(rows CollectableRow) error {
|
|
typ := reflect.TypeOf(rs.ptrToStruct).Elem()
|
|
fields := lookupStructFields(typ)
|
|
if len(rows.RawValues()) > len(fields) {
|
|
return fmt.Errorf(
|
|
"got %d values, but dst struct has only %d fields",
|
|
len(rows.RawValues()),
|
|
len(fields),
|
|
)
|
|
}
|
|
scanTargets := setupStructScanTargets(rs.ptrToStruct, fields)
|
|
return rows.Scan(scanTargets...)
|
|
}
|
|
|
|
// Map from reflect.Type -> []structRowField
|
|
var positionalStructFieldMap sync.Map
|
|
|
|
func lookupStructFields(t reflect.Type) []structRowField {
|
|
if cached, ok := positionalStructFieldMap.Load(t); ok {
|
|
return cached.([]structRowField)
|
|
}
|
|
|
|
fieldStack := make([]int, 0, 1)
|
|
fields := computeStructFields(t, make([]structRowField, 0, t.NumField()), &fieldStack)
|
|
fieldsIface, _ := positionalStructFieldMap.LoadOrStore(t, fields)
|
|
return fieldsIface.([]structRowField)
|
|
}
|
|
|
|
// computeStructFields appends one structRowField to fields for every exported
// field of struct type t, in declaration order, and returns the extended slice.
// Embedded (anonymous) struct fields are flattened by recursing into them;
// embedded pointers are not followed. Fields whose `db` tag is exactly "-" are
// skipped. NOTE(review): unlike computeNamedStructFields, the tag is not split
// on "," here — confirm this difference is intended.
//
// fieldStack holds the index path from the outermost struct down to t and is
// used as scratch space. One slot is pushed for t's field index and popped
// before returning. Each recorded path is a copy of the stack at that moment.
func computeStructFields(
	t reflect.Type,
	fields []structRowField,
	fieldStack *[]int,
) []structRowField {
	tail := len(*fieldStack)
	*fieldStack = append(*fieldStack, 0)
	for i := 0; i < t.NumField(); i++ {
		sf := t.Field(i)
		// The slot at tail tracks the index of the field being visited in t.
		(*fieldStack)[tail] = i
		// Handle anonymous struct embedding, but do not try to handle embedded pointers.
		if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
			fields = computeStructFields(sf.Type, fields, fieldStack)
		} else if sf.PkgPath == "" {
			// An empty PkgPath means the field is exported.
			dbTag, _ := sf.Tag.Lookup(structTagKey)
			if dbTag == "-" {
				// Field is ignored, skip it.
				continue
			}
			fields = append(fields, structRowField{
				// Copy, because fieldStack is mutated as the walk continues.
				path: append([]int(nil), *fieldStack...),
			})
		}
	}
	*fieldStack = (*fieldStack)[:tail]
	return fields
}
|
|
|
|
// RowToStructByName returns a T scanned from row. T must be a struct. T must have the same number of named public
|
|
// fields as row has fields. The row and T fields will be matched by name. The match is case-insensitive. The database
|
|
// column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored.
|
|
func RowToStructByName[T any](row CollectableRow) (T, error) {
|
|
var value T
|
|
err := (&namedStructRowScanner{ptrToStruct: &value}).ScanRow(row)
|
|
return value, err
|
|
}
|
|
|
|
// RowToAddrOfStructByName returns the address of a T scanned from row. T must be a struct. T must have the same number
|
|
// of named public fields as row has fields. The row and T fields will be matched by name. The match is
|
|
// case-insensitive. The database column name can be overridden with a "db" struct tag. If the "db" struct tag is "-"
|
|
// then the field will be ignored.
|
|
func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) {
|
|
var value T
|
|
err := (&namedStructRowScanner{ptrToStruct: &value}).ScanRow(row)
|
|
return &value, err
|
|
}
|
|
|
|
// RowToStructByNameLax returns a T scanned from row. T must be a struct. T must have greater than or equal number of named public
|
|
// fields as row has fields. The row and T fields will be matched by name. The match is case-insensitive. The database
|
|
// column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored.
|
|
func RowToStructByNameLax[T any](row CollectableRow) (T, error) {
|
|
var value T
|
|
err := (&namedStructRowScanner{ptrToStruct: &value, lax: true}).ScanRow(row)
|
|
return value, err
|
|
}
|
|
|
|
// RowToAddrOfStructByNameLax returns the address of a T scanned from row. T must be a struct. T must have greater than or
|
|
// equal number of named public fields as row has fields. The row and T fields will be matched by name. The match is
|
|
// case-insensitive. The database column name can be overridden with a "db" struct tag. If the "db" struct tag is "-"
|
|
// then the field will be ignored.
|
|
func RowToAddrOfStructByNameLax[T any](row CollectableRow) (*T, error) {
|
|
var value T
|
|
err := (&namedStructRowScanner{ptrToStruct: &value, lax: true}).ScanRow(row)
|
|
return &value, err
|
|
}
|
|
|
|
type namedStructRowScanner struct {
|
|
ptrToStruct any
|
|
lax bool
|
|
}
|
|
|
|
func (rs *namedStructRowScanner) ScanRow(rows CollectableRow) error {
|
|
typ := reflect.TypeOf(rs.ptrToStruct).Elem()
|
|
fldDescs := rows.FieldDescriptions()
|
|
namedStructFields, err := lookupNamedStructFields(typ, fldDescs)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !rs.lax && namedStructFields.missingField != "" {
|
|
return fmt.Errorf("cannot find field %s in returned row", namedStructFields.missingField)
|
|
}
|
|
fields := namedStructFields.fields
|
|
scanTargets := setupStructScanTargets(rs.ptrToStruct, fields)
|
|
return rows.Scan(scanTargets...)
|
|
}
|
|
|
|
// namedStructFieldMap caches lookupNamedStructFields results:
// namedStructFieldsKey -> *namedStructFields.
var namedStructFieldMap sync.Map

// namedStructFieldsKey identifies a struct type together with the exact,
// ordered list of result column names (as produced by joinFieldNames).
type namedStructFieldsKey struct {
	t        reflect.Type
	colNames string
}

// namedStructFields is the cached column-to-field mapping for one key.
type namedStructFields struct {
	// fields[i] is the struct field that column i scans into.
	fields []structRowField
	// missingField is the first field from the struct without a corresponding row field.
	// This is used to construct the correct error message for non-lax queries.
	missingField string
}
|
|
|
|
func lookupNamedStructFields(
|
|
t reflect.Type,
|
|
fldDescs []pgconn.FieldDescription,
|
|
) (*namedStructFields, error) {
|
|
key := namedStructFieldsKey{
|
|
t: t,
|
|
colNames: joinFieldNames(fldDescs),
|
|
}
|
|
if cached, ok := namedStructFieldMap.Load(key); ok {
|
|
return cached.(*namedStructFields), nil
|
|
}
|
|
|
|
// We could probably do two-levels of caching, where we compute the key -> fields mapping
|
|
// for a type only once, cache it by type, then use that to compute the column -> fields
|
|
// mapping for a given set of columns.
|
|
fieldStack := make([]int, 0, 1)
|
|
fields, missingField := computeNamedStructFields(
|
|
fldDescs,
|
|
t,
|
|
make([]structRowField, len(fldDescs)),
|
|
&fieldStack,
|
|
)
|
|
for i, f := range fields {
|
|
if f.path == nil {
|
|
return nil, fmt.Errorf(
|
|
"struct doesn't have corresponding row field %s",
|
|
fldDescs[i].Name,
|
|
)
|
|
}
|
|
}
|
|
|
|
fieldsIface, _ := namedStructFieldMap.LoadOrStore(
|
|
key,
|
|
&namedStructFields{fields: fields, missingField: missingField},
|
|
)
|
|
return fieldsIface.(*namedStructFields), nil
|
|
}
|
|
|
|
func joinFieldNames(fldDescs []pgconn.FieldDescription) string {
|
|
switch len(fldDescs) {
|
|
case 0:
|
|
return ""
|
|
case 1:
|
|
return fldDescs[0].Name
|
|
}
|
|
|
|
totalSize := len(fldDescs) - 1 // Space for separator bytes.
|
|
for _, d := range fldDescs {
|
|
totalSize += len(d.Name)
|
|
}
|
|
var b strings.Builder
|
|
b.Grow(totalSize)
|
|
b.WriteString(fldDescs[0].Name)
|
|
for _, d := range fldDescs[1:] {
|
|
b.WriteByte(0) // Join with NUL byte as it's (presumably) not a valid column character.
|
|
b.WriteString(d.Name)
|
|
}
|
|
return b.String()
|
|
}
|
|
|
|
// computeNamedStructFields matches the fields of struct type t, including
// fields promoted from embedded structs, to the columns in fldDescs by name. It
// stores each match at the column's index in fields. fields is expected to have
// len(fldDescs) entries; a column no struct field claims keeps a nil path. If
// two struct fields map to the same column, the later one overwrites the earlier.
//
// A field's column name is its `db` tag (text after the first "," is dropped)
// or, with no tag, the Go field name. Matching is done by fieldPosByName. Fields
// tagged `db:"-"` and unexported non-embedded fields are skipped.
// NOTE(review): an unexported embedded field that is not a struct (e.g. an
// embedded unexported pointer) is not skipped and is treated as a named field —
// confirm this is intended.
//
// fieldStack is scratch space holding the index path from the outermost struct
// to t; it is restored to its original length before returning.
//
// The second result is the first considered struct column name (depth-first,
// declaration order) that has no matching column, or "" if all matched.
func computeNamedStructFields(
	fldDescs []pgconn.FieldDescription,
	t reflect.Type,
	fields []structRowField,
	fieldStack *[]int,
) ([]structRowField, string) {
	var missingField string
	tail := len(*fieldStack)
	*fieldStack = append(*fieldStack, 0)
	for i := 0; i < t.NumField(); i++ {
		sf := t.Field(i)
		// The slot at tail tracks the index of the field being visited in t.
		(*fieldStack)[tail] = i
		if sf.PkgPath != "" && !sf.Anonymous {
			// Field is unexported, skip it.
			continue
		}
		// Handle anonymous struct embedding, but do not try to handle embedded pointers.
		if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
			var missingSubField string
			fields, missingSubField = computeNamedStructFields(
				fldDescs,
				sf.Type,
				fields,
				fieldStack,
			)
			// Keep only the first missing field encountered overall.
			if missingField == "" {
				missingField = missingSubField
			}
		} else {
			dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey)
			if dbTagPresent {
				// Drop tag options such as `db:"name,option"`.
				dbTag, _, _ = strings.Cut(dbTag, ",")
			}
			if dbTag == "-" {
				// Field is ignored, skip it.
				continue
			}
			colName := dbTag
			if !dbTagPresent {
				colName = sf.Name
			}
			fpos := fieldPosByName(fldDescs, colName)
			if fpos == -1 {
				if missingField == "" {
					missingField = colName
				}
				continue
			}
			fields[fpos] = structRowField{
				// Copy, because fieldStack is mutated as the walk continues.
				path: append([]int(nil), *fieldStack...),
			}
		}
	}
	*fieldStack = (*fieldStack)[:tail]

	return fields, missingField
}
|
|
|
|
const structTagKey = "db"
|
|
|
|
func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) {
|
|
i = -1
|
|
for i, desc := range fldDescs {
|
|
|
|
// Snake case support.
|
|
field = strings.ReplaceAll(field, "_", "")
|
|
descName := strings.ReplaceAll(desc.Name, "_", "")
|
|
|
|
if strings.EqualFold(descName, field) {
|
|
return i
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// structRowField describes a field of a struct.
//
// path is the index sequence, starting at the outermost struct, that
// reflect.Value.FieldByIndex follows to reach the field (stepping through any
// embedded structs on the way).
//
// TODO: It would be a bit more efficient to track the path using the pointer
// offset within the (outermost) struct and use unsafe.Pointer arithmetic to
// construct references when scanning rows. However, it's not clear it's worth
// using unsafe for this.
type structRowField struct {
	path []int
}
|
|
|
|
func setupStructScanTargets(receiver any, fields []structRowField) []any {
|
|
scanTargets := make([]any, len(fields))
|
|
v := reflect.ValueOf(receiver).Elem()
|
|
for i, f := range fields {
|
|
scanTargets[i] = v.FieldByIndex(f.path).Addr().Interface()
|
|
}
|
|
return scanTargets
|
|
}
|