gotosocial/vendor/github.com/uptrace/bun/migrate/diff.go

412 lines
12 KiB
Go
Raw Normal View History

package migrate
import (
"github.com/uptrace/bun/migrate/sqlschema"
)
// changeset is a set of changes to the database schema definition.
// Operations are kept in the order in which they were added.
type changeset struct {
// operations holds the recorded schema operations; Add appends to it.
operations []Operation
}
// Add appends the given operations to the changeset, preserving the order
// in which they were passed. Calling Add without arguments is a no-op.
func (c *changeset) Add(op ...Operation) {
	if len(op) == 0 {
		return
	}
	c.operations = append(c.operations, op...)
}
// diff computes the operations needed to turn the current database schema (got)
// into the target state (want). Options are forwarded to newDetector.
// The resulting changeset is unordered: the caller is responsible for resolving
// dependencies between operations before applying them.
func diff(got, want sqlschema.Database, opts ...diffOption) *changeset {
	return newDetector(got, want, opts...).detectChanges()
}
// detectChanges compares the current and the target schema and records in d.changes
// the operations needed to reconcile them: table renames, creations and drops, column
// and constraint changes of tables present on both sides, and foreign key additions
// and drops. Renamed tables are deleted from currentTables and d.refMap is rewritten
// along the way. The returned pointer refers to d.changes.
func (d *detector) detectChanges() *changeset {
currentTables := d.current.GetTables()
targetTables := d.target.GetTables()
RenameCreate:
for wantName, wantTable := range targetTables.FromOldest() {
// A table with this name exists in the database. We assume that schema objects won't
// be renamed to an already existing name, nor do we support such cases.
// Simply check if the table definition has changed.
if haveTable, ok := currentTables.Get(wantName); ok {
d.detectColumnChanges(haveTable, wantTable, true)
d.detectConstraintChanges(haveTable, wantTable)
continue
}
// Find all renamed tables. We assume that renamed tables have the same signature.
// Only current tables whose name is absent from the target schema are rename candidates.
for haveName, haveTable := range currentTables.FromOldest() {
if _, exists := targetTables.Get(haveName); !exists && d.canRename(haveTable, wantTable) {
d.changes.Add(&RenameTableOp{
TableName: haveTable.GetName(),
NewName: wantName,
})
// Keep the tracked foreign keys in sync with the rename so that it does not
// show up as a foreign key difference further below.
d.refMap.RenameTable(haveTable.GetName(), wantName)
// Find renamed columns, if any, and check if constraints (PK, UNIQUE) have been updated.
// We need not check wantTable any further.
d.detectColumnChanges(haveTable, wantTable, false)
d.detectConstraintChanges(haveTable, wantTable)
// The table has been matched: remove it so it is neither matched again nor dropped.
currentTables.Delete(haveName)
continue RenameCreate
}
}
// If wantTable does not exist in the database and was not renamed
// then we need to create this table in the database.
// NOTE(review): unchecked type assertion -- this panics if the target table
// is not a *sqlschema.BunTable.
additional := wantTable.(*sqlschema.BunTable)
d.changes.Add(&CreateTableOp{
TableName: wantTable.GetName(),
Model: additional.Model,
})
}
// Drop any remaining "current" tables which do not have a model.
for name, table := range currentTables.FromOldest() {
if _, keep := targetTables.Get(name); !keep {
d.changes.Add(&DropTableOp{
TableName: table.GetName(),
})
}
}
targetFKs := d.target.GetForeignKeys()
// Deref yields the current foreign keys with the table/column renames recorded above applied.
currentFKs := d.refMap.Deref()
// Foreign keys present only in the target schema have to be added.
for fk := range targetFKs {
if _, ok := currentFKs[fk]; !ok {
d.changes.Add(&AddForeignKeyOp{
ForeignKey: fk,
ConstraintName: "", // leave empty to let each dialect apply their convention
})
}
}
// Foreign keys present only in the current schema have to be dropped.
for fk, name := range currentFKs {
if _, ok := targetFKs[fk]; !ok {
d.changes.Add(&DropForeignKeyOp{
ConstraintName: name,
ForeignKey: fk,
})
}
}
return &d.changes
}
// detectColumnChanges compares the columns of the current and the target table and
// records the operations needed to reconcile them:
//
//   - a column present on both sides yields a ChangeColumnTypeOp when checkType is true
//     and its definitions differ according to equalColumns;
//   - a target column missing from the database is treated as a rename of the first
//     current column that has an equal definition and whose name is not used by the
//     target table; otherwise it is added;
//   - current columns that are neither kept nor renamed are dropped.
//
// All operations are recorded under the target table's name. Renamed columns are
// deleted from the current table's column collection, and the rename is applied to
// d.refMap and to the current table's primary key, if it has one.
func (d *detector) detectColumnChanges(current, target sqlschema.Table, checkType bool) {
	currentColumns := current.GetColumns()
	targetColumns := target.GetColumns()

ChangeRename:
	for tName, tCol := range targetColumns.FromOldest() {
		// This column exists in the database, so it hasn't been renamed, dropped, or added.
		// Still, it must stay in currentColumns: the rename detection below relies on it
		// to avoid renaming a column to a name that already exists.
		if cCol, ok := currentColumns.Get(tName); ok {
			if checkType && !d.equalColumns(cCol, tCol) {
				d.changes.Add(&ChangeColumnTypeOp{
					TableName: target.GetName(),
					Column:    tName,
					From:      cCol,
					To:        d.makeTargetColDef(cCol, tCol),
				})
			}
			continue
		}

		// Column tName does not exist in the database -- it's been either renamed or added.
		// Find renamed columns first.
		for cName, cCol := range currentColumns.FromOldest() {
			// Cannot rename if a column with this name already exists or the definitions differ.
			if _, exists := targetColumns.Get(cName); exists || !d.equalColumns(tCol, cCol) {
				continue
			}
			d.changes.Add(&RenameColumnOp{
				TableName: target.GetName(),
				OldName:   cName,
				NewName:   tName,
			})
			d.refMap.RenameColumn(target.GetName(), cName, tName)
			currentColumns.Delete(cName) // no need to check this column again

			// Update primary key definition to avoid superficially recreating the constraint.
			// A table may have no primary key at all, in which case there is nothing to update.
			if pk := current.GetPrimaryKey(); pk != nil {
				pk.Columns.Replace(cName, tName)
			}
			continue ChangeRename
		}

		d.changes.Add(&AddColumnOp{
			TableName:  target.GetName(),
			ColumnName: tName,
			Column:     tCol,
		})
	}

	// Drop columns which do not exist in the target schema and were not renamed.
	for cName, cCol := range currentColumns.FromOldest() {
		if _, keep := targetColumns.Get(cName); !keep {
			d.changes.Add(&DropColumnOp{
				TableName:  target.GetName(),
				ColumnName: cName,
				Column:     cCol,
			})
		}
	}
}
// detectConstraintChanges records the unique-constraint and primary-key operations
// needed to make the constraints of the current table match those of the target table.
// All operations are recorded under the target table's name.
func (d *detector) detectConstraintChanges(current, target sqlschema.Table) {
	// Unique constraints required by the target but missing from the database are added.
	for _, want := range target.GetUniqueConstraints() {
		present := false
		for _, got := range current.GetUniqueConstraints() {
			if got.Equals(want) {
				present = true
				break
			}
		}
		if !present {
			d.changes.Add(&AddUniqueConstraintOp{
				TableName: target.GetName(),
				Unique:    want,
			})
		}
	}

	// Unique constraints that exist in the database but not in the target are dropped.
	for _, got := range current.GetUniqueConstraints() {
		kept := false
		for _, want := range target.GetUniqueConstraints() {
			if got.Equals(want) {
				kept = true
				break
			}
		}
		if !kept {
			d.changes.Add(&DropUniqueConstraintOp{
				TableName: target.GetName(),
				Unique:    got,
			})
		}
	}

	// Primary key: dropped, added, or changed depending on which side defines one.
	newPK := target.GetPrimaryKey()
	oldPK := current.GetPrimaryKey()

	switch {
	case newPK == nil && oldPK == nil:
		return
	case newPK == nil:
		d.changes.Add(&DropPrimaryKeyOp{
			TableName:  target.GetName(),
			PrimaryKey: *oldPK,
		})
	case oldPK == nil:
		d.changes.Add(&AddPrimaryKeyOp{
			TableName:  target.GetName(),
			PrimaryKey: *newPK,
		})
	case newPK.Columns != oldPK.Columns:
		d.changes.Add(&ChangePrimaryKeyOp{
			TableName: target.GetName(),
			Old:       *oldPK,
			New:       *newPK,
		})
	}
}
// newDetector builds a detector for the given current (got) and target (want) schemas.
// Unless an option overrides it, two columns are considered to have the same type when
// both their SQL type and their varchar length are equal. The detector's refMap is
// seeded with the foreign keys of the current schema.
func newDetector(got, want sqlschema.Database, opts ...diffOption) *detector {
	defaultCmp := func(a, b sqlschema.Column) bool {
		if a.GetSQLType() != b.GetSQLType() {
			return false
		}
		return a.GetVarcharLen() == b.GetVarcharLen()
	}

	cfg := detectorConfig{cmpType: defaultCmp}
	for i := range opts {
		opts[i](&cfg)
	}

	refs := newRefMap(got.GetForeignKeys())
	return &detector{
		cmpType: cfg.cmpType,
		refMap:  refs,
		target:  want,
		current: got,
	}
}
// diffOption adjusts the detectorConfig used to build a detector.
type diffOption func(*detectorConfig)

// withCompareTypeFunc returns an option that makes the detector use f
// to decide whether two columns have equivalent types.
func withCompareTypeFunc(f CompareTypeFunc) diffOption {
	apply := func(config *detectorConfig) {
		config.cmpType = f
	}
	return apply
}
// detectorConfig controls how differences in the model states are resolved.
// newDetector fills it with defaults and then lets every diffOption adjust it.
type detectorConfig struct {
// cmpType decides whether two columns have equivalent types.
cmpType CompareTypeFunc
}
// detector computes the changeset between two database schemas.
// It may modify the passed database schemas, so it isn't safe to re-use them.
type detector struct {
// current state represents the existing database schema.
current sqlschema.Database
// target state represents the database schema defined in bun models.
target sqlschema.Database
// changes accumulates the operations detected so far.
changes changeset
// refMap tracks the foreign keys of the current schema and is updated on every
// detected table/column rename, so that renames alone do not surface as
// foreign key differences.
refMap refMap
// cmpType determines column type equivalence.
// Default is direct comparison with '==' operator, which is inaccurate
// due to the existence of dialect-specific type aliases. The caller
// should pass a concrete InspectorDialect.EquivalentType for robust comparison.
cmpType CompareTypeFunc
}
// canRename reports whether table t1 can be treated as having been renamed to t2:
// both must belong to the same schema and have equal column signatures.
func (d detector) canRename(t1, t2 sqlschema.Table) bool {
	if t1.GetSchema() != t2.GetSchema() {
		return false
	}
	return equalSignatures(t1, t2, d.equalColumns)
}
// equalColumns reports whether two column definitions are equivalent: their types
// must match according to d.cmpType, and their default value, nullability,
// auto-increment and identity attributes must be equal. Column names are not compared.
func (d detector) equalColumns(col1, col2 sqlschema.Column) bool {
	switch {
	case !d.cmpType(col1, col2):
		return false
	case col1.GetDefaultValue() != col2.GetDefaultValue():
		return false
	case col1.GetIsNullable() != col2.GetIsNullable():
		return false
	case col1.GetIsAutoIncrement() != col2.GetIsAutoIncrement():
		return false
	default:
		return col1.GetIsIdentity() == col2.GetIsIdentity()
	}
}
// makeTargetColDef returns the column definition a type-change migration should
// migrate to. When the current and target types are equivalent according to
// d.cmpType, the result keeps the SQL type and varchar length of the current column
// and takes every other attribute from the target, which avoids unnecessary
// type-change migrations. Otherwise the target column is returned unchanged.
func (d detector) makeTargetColDef(current, target sqlschema.Column) sqlschema.Column {
	if !d.cmpType(current, target) {
		return target
	}

	return &sqlschema.BaseColumn{
		Name:            target.GetName(),
		DefaultValue:    target.GetDefaultValue(),
		IsNullable:      target.GetIsNullable(),
		IsAutoIncrement: target.GetIsAutoIncrement(),
		IsIdentity:      target.GetIsIdentity(),
		SQLType:         current.GetSQLType(),
		VarcharLen:      current.GetVarcharLen(),
	}
}
// CompareTypeFunc reports whether two columns are to be considered equivalent.
// The detector uses it to compare column types; signatures use it to group columns.
type CompareTypeFunc func(sqlschema.Column, sqlschema.Column) bool
// equalSignatures reports whether tables t1 and t2 have the same "signature",
// i.e. the same multiset of column definitions as judged by eq.
func equalSignatures(t1, t2 sqlschema.Table, eq CompareTypeFunc) bool {
	first := newSignature(t1, eq)
	return first.Equals(newSignature(t2, eq))
}
// signature is a set of column definitions, which allows "relation/name-agnostic" comparison between them;
// meaning that two columns are considered equal if their types are the same.
type signature struct {
// underlying stores the number of occurrences for each unique column type.
// It helps to account for the fact that a table might have multiple columns that have the same type.
underlying map[sqlschema.BaseColumn]int
// eq decides whether two columns count as the same entry in underlying.
eq CompareTypeFunc
}
// newSignature builds the signature of table t, using eq to decide
// which of its columns count as equivalent.
func newSignature(t sqlschema.Table, eq CompareTypeFunc) signature {
	var sig signature
	sig.eq = eq
	sig.underlying = map[sqlschema.BaseColumn]int{}
	sig.scan(t)
	return sig
}
// scan iterates over the table's columns and counts occurrences of each unique
// column definition, where uniqueness is decided by s.eq.
//
// Columns that are not backed by *sqlschema.BaseColumn are converted to a BaseColumn
// through the sqlschema.Column getters instead of triggering a type-assertion panic.
func (s *signature) scan(t sqlschema.Table) {
	for _, icol := range t.GetColumns().FromOldest() {
		var scanCol sqlschema.BaseColumn
		if base, ok := icol.(*sqlschema.BaseColumn); ok {
			scanCol = *base
		} else {
			scanCol = sqlschema.BaseColumn{
				Name:            icol.GetName(),
				SQLType:         icol.GetSQLType(),
				VarcharLen:      icol.GetVarcharLen(),
				DefaultValue:    icol.GetDefaultValue(),
				IsNullable:      icol.GetIsNullable(),
				IsAutoIncrement: icol.GetIsAutoIncrement(),
				IsIdentity:      icol.GetIsIdentity(),
			}
		}

		// This is slightly more expensive than if the columns could be compared directly
		// and we always did s.underlying[col]++, but we get type-equivalence in return.
		col, count := s.getCount(scanCol)
		if count == 0 {
			s.underlying[scanCol] = 1
		} else {
			s.underlying[col]++
		}
	}
}
// getCount looks for a stored column that s.eq considers equivalent to keyCol.
// On a match it returns the stored column together with its count; otherwise it
// returns keyCol itself and a count of 0, meaning no such column has been seen.
func (s *signature) getCount(keyCol sqlschema.BaseColumn) (key sqlschema.BaseColumn, count int) {
	key = keyCol
	for candidate, seen := range s.underlying {
		if !s.eq(&candidate, &keyCol) {
			continue
		}
		key, count = candidate, seen
		break
	}
	return key, count
}
// Equals reports whether both signatures hold the same number of distinct column
// definitions and every definition of s occurs equally often in other.
func (s *signature) Equals(other signature) bool {
	theirs := other.underlying
	if len(theirs) != len(s.underlying) {
		return false
	}

	for column, expected := range s.underlying {
		_, actual := other.getCount(column)
		if actual != expected {
			return false
		}
	}
	return true
}
// refMap is a utility for tracking superficial changes in foreign keys,
// which do not require any modification in the database.
// Modern SQL dialects automatically update foreign key constraints whenever
// a column or a table is renamed. Detector can use refMap to ignore any
// differences in foreign keys which were caused by renamed column/table.
//
// Keys are pointers so that the foreign key definitions can be updated in place;
// values are the corresponding constraint names.
type refMap map[*sqlschema.ForeignKey]string
// newRefMap creates a refMap from the given foreign keys. Every foreign key is
// copied into its own variable, so later renames do not touch the keys of fks.
func newRefMap(fks map[sqlschema.ForeignKey]string) refMap {
	refs := make(refMap, len(fks))
	for key, constraint := range fks {
		ref := new(sqlschema.ForeignKey)
		*ref = key
		refs[ref] = constraint
	}
	return refs
}
// RenameTable updates the table name in all foreign key definitions which depend on it.
//
// Both sides of every foreign key are checked independently, so a self-referencing
// foreign key (one whose From and To tables are the same) gets both sides renamed.
func (rm refMap) RenameTable(tableName string, newName string) {
	for fk := range rm {
		if fk.From.TableName == tableName {
			fk.From.TableName = newName
		}
		if fk.To.TableName == tableName {
			fk.To.TableName = newName
		}
	}
}
// RenameColumn updates the column name in all foreign key definitions which depend on it.
// Either side of a foreign key is updated when it refers to tableName.
func (rm refMap) RenameColumn(tableName string, column, newName string) {
	for ref := range rm {
		if ref.From.TableName == tableName {
			ref.From.Column.Replace(column, newName)
		}
		if ref.To.TableName == tableName {
			ref.To.Column.Replace(column, newName)
		}
	}
}
// Deref returns a new map keyed by copies of the tracked ForeignKey values,
// each mapped to its constraint name.
func (rm refMap) Deref() map[sqlschema.ForeignKey]string {
	snapshot := make(map[sqlschema.ForeignKey]string, len(rm))
	for ref, constraint := range rm {
		snapshot[*ref] = constraint
	}
	return snapshot
}