package pgdialect

import (
	"context"
	"fmt"
	"sort"
	"strings"

	"github.com/uptrace/bun"
	"github.com/uptrace/bun/migrate/sqlschema"
	orderedmap "github.com/wk8/go-ordered-map/v2"
)

// Aliases for the sqlschema base types that Inspect uses to describe
// the inspected database, its tables, and their columns.
type (
	Schema = sqlschema.BaseDatabase
	Table  = sqlschema.BaseTable
	Column = sqlschema.BaseColumn
)

// NewInspector returns a sqlschema.Inspector that reads schema metadata
// through db and is configured with the given options.
func (d *Dialect) NewInspector(db *bun.DB, options ...sqlschema.InspectorOption) sqlschema.Inspector {
	return newInspector(db, options...)
}

// Inspector reads table, column, and foreign key definitions from a database.
//
// The embedded InspectorConfig supplies SchemaName and ExcludeTables,
// which Inspect uses to select the schema and to skip tables.
type Inspector struct {
	sqlschema.InspectorConfig
	db *bun.DB
}

var _ sqlschema.Inspector = (*Inspector)(nil)

// newInspector creates an Inspector for db and applies options to its config.
func newInspector(db *bun.DB, options ...sqlschema.InspectorOption) *Inspector {
	i := &Inspector{db: db}
	sqlschema.ApplyInspectorOptions(&i.InspectorConfig, options...)
	return i
}

// Inspect queries the database and returns the tables (with their columns,
// primary keys, and unique constraints) and the foreign keys found in
// in.SchemaName, skipping every table listed in in.ExcludeTables.
//
// It runs one query for the tables, one for the foreign keys, and then one
// column query per table. If any query fails, the error is returned wrapped
// with the step that failed, together with the schema assembled so far.
func (in *Inspector) Inspect(ctx context.Context) (sqlschema.Database, error) {
	dbSchema := Schema{
		Tables:      orderedmap.New[string, sqlschema.Table](),
		ForeignKeys: make(map[sqlschema.ForeignKey]string),
	}

	exclude := in.ExcludeTables
	if len(exclude) == 0 {
		// Avoid getting NOT IN (NULL) if bun.In() is called with an empty slice.
		exclude = []string{""}
	}

	var tables []*InformationSchemaTable
	if err := in.db.NewRaw(sqlInspectTables, in.SchemaName, bun.In(exclude)).Scan(ctx, &tables); err != nil {
		return dbSchema, fmt.Errorf("pgdialect: inspect tables: %w", err)
	}

	// The exclusion list is passed twice: once for the referencing table
	// and once for the referenced table.
	var fks []*ForeignKey
	if err := in.db.NewRaw(sqlInspectForeignKeys, in.SchemaName, bun.In(exclude), bun.In(exclude)).Scan(ctx, &fks); err != nil {
		return dbSchema, fmt.Errorf("pgdialect: inspect foreign keys: %w", err)
	}
	dbSchema.ForeignKeys = make(map[sqlschema.ForeignKey]string, len(fks))

	for _, table := range tables {
		var columns []*InformationSchemaColumn
		if err := in.db.NewRaw(sqlInspectColumnsQuery, table.Schema, table.Name).Scan(ctx, &columns); err != nil {
			return dbSchema, fmt.Errorf("pgdialect: inspect columns of %q.%q: %w", table.Schema, table.Name, err)
		}

		colDefs := orderedmap.New[string, sqlschema.Column]()
		// uniqueGroups maps a unique constraint name to the columns it covers.
		uniqueGroups := make(map[string][]string)

		for _, c := range columns {
			def := c.Default
			switch {
			case c.IsSerial || c.IsIdentity:
				// The default of a serial/identity column is dropped; the
				// column is described by IsAutoIncrement/IsIdentity instead.
				def = ""
			case !c.IsDefaultLiteral:
				// Non-literal defaults are lower-cased; literal values keep their case.
				def = strings.ToLower(def)
			}

			colDefs.Set(c.Name, &Column{
				Name:            c.Name,
				SQLType:         c.DataType,
				VarcharLen:      c.VarcharLen,
				DefaultValue:    def,
				IsNullable:      c.IsNullable,
				IsAutoIncrement: c.IsSerial,
				IsIdentity:      c.IsIdentity,
			})

			for _, group := range c.UniqueGroups {
				uniqueGroups[group] = append(uniqueGroups[group], c.Name)
			}
		}

		// Map iteration order is random; walk the constraint names in sorted
		// order so that repeated inspections yield identical results.
		groupNames := make([]string, 0, len(uniqueGroups))
		for name := range uniqueGroups {
			groupNames = append(groupNames, name)
		}
		sort.Strings(groupNames)

		// unique stays nil when the table has no unique constraints.
		var unique []sqlschema.Unique
		for _, name := range groupNames {
			unique = append(unique, sqlschema.Unique{
				Name:    name,
				Columns: sqlschema.NewColumns(uniqueGroups[name]...),
			})
		}

		// pk stays nil for tables without a primary key.
		var pk *sqlschema.PrimaryKey
		if len(table.PrimaryKey.Columns) > 0 {
			pk = &sqlschema.PrimaryKey{
				Name:    table.PrimaryKey.ConstraintName,
				Columns: sqlschema.NewColumns(table.PrimaryKey.Columns...),
			}
		}

		dbSchema.Tables.Set(table.Name, &Table{
			Schema:            table.Schema,
			Name:              table.Name,
			Columns:           colDefs,
			PrimaryKey:        pk,
			UniqueConstraints: unique,
		})
	}

	for _, fk := range fks {
		dbSchema.ForeignKeys[sqlschema.ForeignKey{
			From: sqlschema.NewColumnReference(fk.SourceTable, fk.SourceColumns...),
			To:   sqlschema.NewColumnReference(fk.TargetTable, fk.TargetColumns...),
		}] = fk.ConstraintName
	}
	return dbSchema, nil
}

// InformationSchemaTable is the scan target for one row of sqlInspectTables.
type InformationSchemaTable struct {
	Schema     string     `bun:"table_schema,pk"`
	Name       string     `bun:"table_name,pk"`
	PrimaryKey PrimaryKey `bun:"embed:primary_key_"`

	// Columns is not filled by Inspect, which scans the columns of each
	// table with a separate sqlInspectColumnsQuery call.
	Columns []*InformationSchemaColumn `bun:"rel:has-many,join:table_schema=table_schema,join:table_name=table_name"`
}

// InformationSchemaColumn is the scan target for one row of sqlInspectColumnsQuery.
type InformationSchemaColumn struct {
	Schema     string `bun:"table_schema"`
	Table      string `bun:"table_name"`
	Name       string `bun:"column_name"`
	DataType   string `bun:"data_type"`
	VarcharLen int    `bun:"varchar_len"`
	IsArray    bool   `bun:"is_array"`
	ArrayDims  int    `bun:"array_dims"`
	// Default is the column default; for quoted literals with a type cast
	// the query returns only the text between the quotes.
	Default string `bun:"default"`
	// IsDefaultLiteral reports whether the default is a quoted literal
	// or consists only of digits and dots.
	IsDefaultLiteral bool `bun:"default_is_literal_expr"`
	IsIdentity       bool `bun:"is_identity"`
	// IndentityType (sic) keeps its misspelled name because it is exported.
	IndentityType string `bun:"identity_type"`
	// IsSerial reports whether the default is nextval() on the
	// "<table>_<column>_seq" sequence.
	IsSerial   bool `bun:"is_serial"`
	IsNullable bool `bun:"is_nullable"`
	// UniqueGroups lists the names of the unique constraints that include this column.
	UniqueGroups []string `bun:"unique_groups,array"`
}

// ForeignKey is the scan target for one row of sqlInspectForeignKeys.
type ForeignKey struct {
	ConstraintName string   `bun:"constraint_name"`
	SourceSchema   string   `bun:"schema_name"`
	SourceTable    string   `bun:"table_name"`
	SourceColumns  []string `bun:"columns,array"`
	TargetSchema   string   `bun:"target_schema"`
	TargetTable    string   `bun:"target_table"`
	TargetColumns  []string `bun:"target_columns,array"`
}

// PrimaryKey holds the primary key columns selected by sqlInspectTables.
type PrimaryKey struct {
	ConstraintName string   `bun:"name"`
	Columns        []string `bun:"columns,array"`
}

const (
	// sqlInspectTables retrieves all user-defined tables in the selected schema.
	// Pass bun.In([]string{...}) to exclude tables from this inspection or bun.In([]string{''}) to include all results.
	//
	// The table is resolved to a regclass via format('%I.%I', ...), which quotes
	// the schema and table names when needed, so that identifiers containing
	// upper-case or special characters resolve to the right relation.
	sqlInspectTables = `
SELECT
	"t".table_schema,
	"t".table_name,
	pk.name AS primary_key_name,
	pk.columns AS primary_key_columns
FROM information_schema.tables "t"
	LEFT JOIN (
		SELECT i.indrelid, "idx".relname AS "name", ARRAY_AGG("a".attname) AS "columns"
		FROM pg_index i
		JOIN pg_attribute "a"
			ON "a".attrelid = i.indrelid
			AND "a".attnum = ANY("i".indkey)
			AND i.indisprimary
		JOIN pg_class "idx" ON i.indexrelid = "idx".oid
		GROUP BY 1, 2
	) pk
	ON format('%I.%I', "t".table_schema, "t".table_name)::regclass = pk.indrelid
WHERE table_type = 'BASE TABLE'
	AND "t".table_schema = ?
	AND "t".table_schema NOT LIKE 'pg_%'
	AND "table_name" NOT IN (?)
ORDER BY "t".table_schema, "t".table_name
`

	// sqlInspectColumnsQuery retrieves column definitions for the specified table.
	// Unlike sqlInspectTables and sqlInspectForeignKeys, it takes no exclusion list:
	// pass it to bun.NewRaw with args for table_schema and table_name.
	sqlInspectColumnsQuery = `
SELECT
	"c".table_schema,
	"c".table_name,
	"c".column_name,
	"c".data_type,
	"c".character_maximum_length::integer AS varchar_len,
	"c".data_type = 'ARRAY' AS is_array,
	COALESCE("c".array_dims, 0) AS array_dims,
	CASE
		WHEN "c".column_default ~ '^''.*''::.*$' THEN substring("c".column_default FROM '^''(.*)''::.*$')
		ELSE "c".column_default
	END AS "default",
	"c".column_default ~ '^''.*''::.*$' OR "c".column_default ~ '^[0-9\.]+$' AS default_is_literal_expr,
	"c".is_identity = 'YES' AS is_identity,
	"c".column_default = format('nextval(''%s_%s_seq''::regclass)', "c".table_name, "c".column_name) AS is_serial,
	COALESCE("c".identity_type, '') AS identity_type,
	"c".is_nullable = 'YES' AS is_nullable,
	"c"."unique_groups" AS unique_groups
FROM (
	SELECT
		"table_schema",
		"table_name",
		"column_name",
		"c".data_type,
		"c".character_maximum_length,
		"c".column_default,
		"c".is_identity,
		"c".is_nullable,
		att.array_dims,
		att.identity_type,
		att."unique_groups",
		att."constraint_type"
	FROM information_schema.columns "c"
		LEFT JOIN (
			SELECT
				s.nspname AS "table_schema",
				"t".relname AS "table_name",
				"c".attname AS "column_name",
				"c".attndims AS array_dims,
				"c".attidentity AS identity_type,
				ARRAY_AGG(con.conname) FILTER (WHERE con.contype = 'u') AS "unique_groups",
				ARRAY_AGG(con.contype) AS "constraint_type"
			FROM (
				SELECT
					conname,
					contype,
					connamespace,
					conrelid,
					conrelid AS attrelid,
					UNNEST(conkey) AS attnum
				FROM pg_constraint
			) con
				LEFT JOIN pg_attribute "c" USING (attrelid, attnum)
				LEFT JOIN pg_namespace s ON s.oid = con.connamespace
				LEFT JOIN pg_class "t" ON "t".oid = con.conrelid
			GROUP BY 1, 2, 3, 4, 5
		) att USING ("table_schema", "table_name", "column_name")
	) "c"
WHERE "table_schema" = ? AND "table_name" = ?
ORDER BY "table_schema", "table_name", "column_name"
`

	// sqlInspectForeignKeys retrieves FK definitions for user-defined tables.
	// Pass bun.In([]string{...}) to exclude tables from this inspection or bun.In([]string{''}) to include all results.
	// The exclusion list is bound twice: first for the referencing table, then for the referenced one.
	sqlInspectForeignKeys = `
WITH
	"schemas" AS (
		SELECT oid, nspname
		FROM pg_namespace
	),
	"tables" AS (
		SELECT oid, relnamespace, relname, relkind
		FROM pg_class
	),
	"columns" AS (
		SELECT attrelid, attname, attnum
		FROM pg_attribute
		WHERE attisdropped = false
	)
SELECT DISTINCT
	co.conname AS "constraint_name",
	ss.nspname AS schema_name,
	s.relname AS "table_name",
	ARRAY_AGG(sc.attname) AS "columns",
	ts.nspname AS target_schema,
	"t".relname AS target_table,
	ARRAY_AGG(tc.attname) AS target_columns
FROM pg_constraint co
	LEFT JOIN "tables" s ON s.oid = co.conrelid
	LEFT JOIN "schemas" ss ON ss.oid = s.relnamespace
	LEFT JOIN "columns" sc ON sc.attrelid = s.oid AND sc.attnum = ANY(co.conkey)
	LEFT JOIN "tables" t ON t.oid = co.confrelid
	LEFT JOIN "schemas" ts ON ts.oid = "t".relnamespace
	LEFT JOIN "columns" tc ON tc.attrelid = "t".oid AND tc.attnum = ANY(co.confkey)
WHERE co.contype = 'f'
	AND co.conrelid IN (SELECT oid FROM pg_class WHERE relkind = 'r')
	AND ARRAY_POSITION(co.conkey, sc.attnum) = ARRAY_POSITION(co.confkey, tc.attnum)
	AND ss.nspname = ?
	AND s.relname NOT IN (?)
	AND "t".relname NOT IN (?)
GROUP BY "constraint_name", "schema_name", "table_name", target_schema, target_table
`
)