mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-01-11 09:00:12 +00:00
346 lines
7 KiB
Go
346 lines
7 KiB
Go
|
package bun
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"database/sql"
|
||
|
"fmt"
|
||
|
"reflect"
|
||
|
"strings"
|
||
|
|
||
|
"github.com/uptrace/bun/dialect"
|
||
|
"github.com/uptrace/bun/schema"
|
||
|
)
|
||
|
|
||
|
type structTableModel struct {
|
||
|
db *DB
|
||
|
table *schema.Table
|
||
|
|
||
|
rel *schema.Relation
|
||
|
joins []join
|
||
|
|
||
|
dest interface{}
|
||
|
root reflect.Value
|
||
|
index []int
|
||
|
|
||
|
strct reflect.Value
|
||
|
structInited bool
|
||
|
structInitErr error
|
||
|
|
||
|
columns []string
|
||
|
scanIndex int
|
||
|
}
|
||
|
|
||
|
var _ tableModel = (*structTableModel)(nil)
|
||
|
|
||
|
func newStructTableModel(db *DB, dest interface{}, table *schema.Table) *structTableModel {
|
||
|
return &structTableModel{
|
||
|
db: db,
|
||
|
table: table,
|
||
|
dest: dest,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func newStructTableModelValue(db *DB, dest interface{}, v reflect.Value) *structTableModel {
|
||
|
return &structTableModel{
|
||
|
db: db,
|
||
|
table: db.Table(v.Type()),
|
||
|
dest: dest,
|
||
|
root: v,
|
||
|
strct: v,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (m *structTableModel) Value() interface{} {
|
||
|
return m.dest
|
||
|
}
|
||
|
|
||
|
func (m *structTableModel) Table() *schema.Table {
|
||
|
return m.table
|
||
|
}
|
||
|
|
||
|
func (m *structTableModel) Relation() *schema.Relation {
|
||
|
return m.rel
|
||
|
}
|
||
|
|
||
|
func (m *structTableModel) Root() reflect.Value {
|
||
|
return m.root
|
||
|
}
|
||
|
|
||
|
func (m *structTableModel) Index() []int {
|
||
|
return m.index
|
||
|
}
|
||
|
|
||
|
func (m *structTableModel) ParentIndex() []int {
|
||
|
return m.index[:len(m.index)-len(m.rel.Field.Index)]
|
||
|
}
|
||
|
|
||
|
func (m *structTableModel) Mount(host reflect.Value) {
|
||
|
m.strct = host.FieldByIndex(m.rel.Field.Index)
|
||
|
m.structInited = false
|
||
|
}
|
||
|
|
||
|
func (m *structTableModel) initStruct() error {
|
||
|
if m.structInited {
|
||
|
return m.structInitErr
|
||
|
}
|
||
|
m.structInited = true
|
||
|
|
||
|
switch m.strct.Kind() {
|
||
|
case reflect.Invalid:
|
||
|
m.structInitErr = errNilModel
|
||
|
return m.structInitErr
|
||
|
case reflect.Interface:
|
||
|
m.strct = m.strct.Elem()
|
||
|
}
|
||
|
|
||
|
if m.strct.Kind() == reflect.Ptr {
|
||
|
if m.strct.IsNil() {
|
||
|
m.strct.Set(reflect.New(m.strct.Type().Elem()))
|
||
|
m.strct = m.strct.Elem()
|
||
|
} else {
|
||
|
m.strct = m.strct.Elem()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
m.mountJoins()
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (m *structTableModel) mountJoins() {
|
||
|
for i := range m.joins {
|
||
|
j := &m.joins[i]
|
||
|
switch j.Relation.Type {
|
||
|
case schema.HasOneRelation, schema.BelongsToRelation:
|
||
|
j.JoinModel.Mount(m.strct)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
var _ schema.BeforeScanHook = (*structTableModel)(nil)
|
||
|
|
||
|
func (m *structTableModel) BeforeScan(ctx context.Context) error {
|
||
|
if !m.table.HasBeforeScanHook() {
|
||
|
return nil
|
||
|
}
|
||
|
return callBeforeScanHook(ctx, m.strct.Addr())
|
||
|
}
|
||
|
|
||
|
var _ schema.AfterScanHook = (*structTableModel)(nil)
|
||
|
|
||
|
func (m *structTableModel) AfterScan(ctx context.Context) error {
|
||
|
if !m.table.HasAfterScanHook() || !m.structInited {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
var firstErr error
|
||
|
|
||
|
if err := callAfterScanHook(ctx, m.strct.Addr()); err != nil && firstErr == nil {
|
||
|
firstErr = err
|
||
|
}
|
||
|
|
||
|
for _, j := range m.joins {
|
||
|
switch j.Relation.Type {
|
||
|
case schema.HasOneRelation, schema.BelongsToRelation:
|
||
|
if err := j.JoinModel.AfterScan(ctx); err != nil && firstErr == nil {
|
||
|
firstErr = err
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return firstErr
|
||
|
}
|
||
|
|
||
|
func (m *structTableModel) GetJoin(name string) *join {
|
||
|
for i := range m.joins {
|
||
|
j := &m.joins[i]
|
||
|
if j.Relation.Field.Name == name || j.Relation.Field.GoName == name {
|
||
|
return j
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (m *structTableModel) GetJoins() []join {
|
||
|
return m.joins
|
||
|
}
|
||
|
|
||
|
func (m *structTableModel) AddJoin(j join) *join {
|
||
|
m.joins = append(m.joins, j)
|
||
|
return &m.joins[len(m.joins)-1]
|
||
|
}
|
||
|
|
||
|
func (m *structTableModel) Join(name string, apply func(*SelectQuery) *SelectQuery) *join {
|
||
|
return m.join(m.strct, name, apply)
|
||
|
}
|
||
|
|
||
|
func (m *structTableModel) join(
|
||
|
bind reflect.Value, name string, apply func(*SelectQuery) *SelectQuery,
|
||
|
) *join {
|
||
|
path := strings.Split(name, ".")
|
||
|
index := make([]int, 0, len(path))
|
||
|
|
||
|
currJoin := join{
|
||
|
BaseModel: m,
|
||
|
JoinModel: m,
|
||
|
}
|
||
|
var lastJoin *join
|
||
|
|
||
|
for _, name := range path {
|
||
|
relation, ok := currJoin.JoinModel.Table().Relations[name]
|
||
|
if !ok {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
currJoin.Relation = relation
|
||
|
index = append(index, relation.Field.Index...)
|
||
|
|
||
|
if j := currJoin.JoinModel.GetJoin(name); j != nil {
|
||
|
currJoin.BaseModel = j.BaseModel
|
||
|
currJoin.JoinModel = j.JoinModel
|
||
|
|
||
|
lastJoin = j
|
||
|
} else {
|
||
|
model, err := newTableModelIndex(m.db, m.table, bind, index, relation)
|
||
|
if err != nil {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
currJoin.Parent = lastJoin
|
||
|
currJoin.BaseModel = currJoin.JoinModel
|
||
|
currJoin.JoinModel = model
|
||
|
|
||
|
lastJoin = currJoin.BaseModel.AddJoin(currJoin)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// No joins with such name.
|
||
|
if lastJoin == nil {
|
||
|
return nil
|
||
|
}
|
||
|
if apply != nil {
|
||
|
lastJoin.ApplyQueryFunc = apply
|
||
|
}
|
||
|
|
||
|
return lastJoin
|
||
|
}
|
||
|
|
||
|
func (m *structTableModel) updateSoftDeleteField() error {
|
||
|
fv := m.table.SoftDeleteField.Value(m.strct)
|
||
|
return m.table.UpdateSoftDeleteField(fv)
|
||
|
}
|
||
|
|
||
|
func (m *structTableModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) {
|
||
|
if !rows.Next() {
|
||
|
return 0, rows.Err()
|
||
|
}
|
||
|
|
||
|
if err := m.ScanRow(ctx, rows); err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
|
||
|
// For inserts, SQLite3 can return a row like it was inserted sucessfully and then
|
||
|
// an actual error for the next row. See issues/100.
|
||
|
if m.db.dialect.Name() == dialect.SQLite {
|
||
|
_ = rows.Next()
|
||
|
if err := rows.Err(); err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return 1, nil
|
||
|
}
|
||
|
|
||
|
func (m *structTableModel) ScanRow(ctx context.Context, rows *sql.Rows) error {
|
||
|
columns, err := rows.Columns()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
m.columns = columns
|
||
|
dest := makeDest(m, len(columns))
|
||
|
|
||
|
return m.scanRow(ctx, rows, dest)
|
||
|
}
|
||
|
|
||
|
func (m *structTableModel) scanRow(ctx context.Context, rows *sql.Rows, dest []interface{}) error {
|
||
|
if err := m.BeforeScan(ctx); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
m.scanIndex = 0
|
||
|
if err := rows.Scan(dest...); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if err := m.AfterScan(ctx); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (m *structTableModel) Scan(src interface{}) error {
|
||
|
column := m.columns[m.scanIndex]
|
||
|
m.scanIndex++
|
||
|
|
||
|
return m.ScanColumn(unquote(column), src)
|
||
|
}
|
||
|
|
||
|
func (m *structTableModel) ScanColumn(column string, src interface{}) error {
|
||
|
if ok, err := m.scanColumn(column, src); ok {
|
||
|
return err
|
||
|
}
|
||
|
if column == "" || column[0] == '_' || m.db.flags.Has(discardUnknownColumns) {
|
||
|
return nil
|
||
|
}
|
||
|
return fmt.Errorf("bun: %s does not have column %q", m.table.TypeName, column)
|
||
|
}
|
||
|
|
||
|
func (m *structTableModel) scanColumn(column string, src interface{}) (bool, error) {
|
||
|
if src != nil {
|
||
|
if err := m.initStruct(); err != nil {
|
||
|
return true, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if field, ok := m.table.FieldMap[column]; ok {
|
||
|
return true, field.ScanValue(m.strct, src)
|
||
|
}
|
||
|
|
||
|
if joinName, column := splitColumn(column); joinName != "" {
|
||
|
if join := m.GetJoin(joinName); join != nil {
|
||
|
return true, join.JoinModel.ScanColumn(column, src)
|
||
|
}
|
||
|
if m.table.ModelName == joinName {
|
||
|
return true, m.ScanColumn(column, src)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return false, nil
|
||
|
}
|
||
|
|
||
|
func (m *structTableModel) AppendNamedArg(
|
||
|
fmter schema.Formatter, b []byte, name string,
|
||
|
) ([]byte, bool) {
|
||
|
return m.table.AppendNamedArg(fmter, b, name, m.strct)
|
||
|
}
|
||
|
|
||
|
// sqlite3 sometimes does not unquote columns.
|
||
|
func unquote(s string) string {
|
||
|
if s == "" {
|
||
|
return s
|
||
|
}
|
||
|
if s[0] == '"' && s[len(s)-1] == '"' {
|
||
|
return s[1 : len(s)-1]
|
||
|
}
|
||
|
return s
|
||
|
}
|
||
|
|
||
|
func splitColumn(s string) (string, string) {
|
||
|
if i := strings.Index(s, "__"); i >= 0 {
|
||
|
return s[:i], s[i+2:]
|
||
|
}
|
||
|
return "", s
|
||
|
}
|