diff --git a/database/cassandra/cassandra.go b/database/cassandra/cassandra.go index 74eecc98e..50f141a33 100644 --- a/database/cassandra/cassandra.go +++ b/database/cassandra/cassandra.go @@ -42,6 +42,15 @@ type Config struct { KeyspaceName string MultiStatementEnabled bool MultiStatementMaxSize int + + Triggers map[string]func(response interface{}) error +} + +type TriggerResponse struct { + Driver *Cassandra + Config *Config + Trigger string + Detail interface{} } type Cassandra struct { @@ -198,6 +207,27 @@ func (c *Cassandra) Close() error { return nil } +func (c *Cassandra) AddTriggers(t map[string]func(response interface{}) error) { + c.config.Triggers = t +} + +func (c *Cassandra) Trigger(name string, detail interface{}) error { + if c.config.Triggers == nil { + return nil + } + + if trigger, ok := c.config.Triggers[name]; ok { + return trigger(TriggerResponse{ + Driver: c, + Config: c.config, + Trigger: name, + Detail: detail, + }) + } + + return nil +} + func (c *Cassandra) Lock() error { if !c.isLocked.CAS(false, true) { return database.ErrLocked @@ -220,10 +250,22 @@ func (c *Cassandra) Run(migration io.Reader) error { if tq == "" { return true } + if e := c.Trigger(database.TrigRunPre, struct { + Query string + }{Query: tq}); e != nil { + err = database.Error{OrigErr: e, Err: "failed to trigger RunPre"} + return false + } if e := c.session.Query(tq).Exec(); e != nil { err = database.Error{OrigErr: e, Err: "migration failed", Query: m} return false } + if e := c.Trigger(database.TrigRunPost, struct { + Query string + }{Query: tq}); e != nil { + err = database.Error{OrigErr: e, Err: "failed to trigger RunPost"} + return false + } return true }); e != nil { return e @@ -235,15 +277,32 @@ func (c *Cassandra) Run(migration io.Reader) error { if err != nil { return err } + if err := c.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(migr)}); err != nil { + return database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } // run migration if err := c.session.Query(string(migr)).Exec(); err != nil { // TODO: cast to Cassandra error and get line number return database.Error{OrigErr: err, Err: "migration failed", Query: migr} } + if err := c.Trigger(database.TrigRunPost, struct { + Query string + }{Query: string(migr)}); err != nil { + return database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } return nil } func (c *Cassandra) SetVersion(version int, dirty bool) error { + if err := c.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + // DELETE instead of TRUNCATE because AWS Keyspaces does not support it // see: https://docs.aws.amazon.com/keyspaces/latest/devguide/cassandra-apis.html squery := `SELECT version FROM "` + c.config.MigrationsTable + `"` @@ -269,6 +328,13 @@ func (c *Cassandra) SetVersion(version int, dirty bool) error { } } + if err := c.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + return nil } @@ -324,6 +390,10 @@ func (c *Cassandra) ensureVersionTable() (err error) { } }() + if err := c.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + err = c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec() if err != nil { return err @@ -331,6 +401,11 @@ func (c *Cassandra) ensureVersionTable() (err error) { if _, _, err = c.Version(); err != nil { return err } + + if err := c.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } + return nil } diff --git a/database/clickhouse/clickhouse.go b/database/clickhouse/clickhouse.go index d2b65c0ce..58bf8bf70 100644 --- a/database/clickhouse/clickhouse.go +++ b/database/clickhouse/clickhouse.go @@ -34,6 +34,8 @@ type Config struct { MigrationsTableEngine string MultiStatementEnabled bool MultiStatementMaxSize int + + Triggers map[string]func(response interface{}) error } func init() { @@ -67,6 +69,13 @@ type ClickHouse struct { isLocked atomic.Bool } +type TriggerResponse struct { + Driver *ClickHouse + Config *Config + Trigger string + Detail interface{} +} + func (ch *ClickHouse) Open(dsn string) (database.Driver, error) { purl, err := url.Parse(dsn) if err != nil { @@ -141,10 +150,22 @@ func (ch *ClickHouse) Run(r io.Reader) error { if tq == "" { return true } + if e := ch.Trigger(database.TrigRunPre, struct { + Query string + }{Query: tq}); e != nil { + err = database.Error{OrigErr: e, Err: "failed to trigger RunPre"} + return false + } if _, e := ch.conn.Exec(string(m)); e != nil { err = database.Error{OrigErr: e, Err: "migration failed", Query: m} return false } + if e := ch.Trigger(database.TrigRunPost, struct { + Query string + }{Query: tq}); e != nil { + err = database.Error{OrigErr: e, Err: "failed to trigger RunPost"} + return false + } return true }); e != nil { return e @@ -157,10 +178,22 @@ func (ch *ClickHouse) Run(r io.Reader) error { return err } + if err := ch.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(migration)}); err != nil { + return database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } + if _, err := ch.conn.Exec(string(migration)); err != nil { return database.Error{OrigErr: err, Err: "migration failed", Query: migration} } + if err := ch.Trigger(database.TrigRunPost, struct { + Query string + }{Query: string(migration)}); err != nil { + return database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } + return nil } func (ch *ClickHouse) Version() (int, bool, error) { @@ -180,7 +213,7 @@ func (ch *ClickHouse) Version() (int, bool, error) { func (ch *ClickHouse) SetVersion(version int, dirty bool) error { var ( - bool = func(v bool) uint8 { + booln = func(v bool) uint8 { if v { return 1 } @@ -192,11 +225,25 @@ func (ch *ClickHouse) SetVersion(version int, dirty bool) error { return err } + if err := ch.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + query := "INSERT INTO " + ch.config.MigrationsTable + " (version, dirty, sequence) VALUES (?, ?, ?)" - if _, err := tx.Exec(query, version, bool(dirty), time.Now().UnixNano()); err != nil { + if _, err := tx.Exec(query, version, booln(dirty), time.Now().UnixNano()); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } + if err := ch.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + return tx.Commit() } @@ -228,9 +275,16 @@ func (ch *ClickHouse) ensureVersionTable() (err error) { return &database.Error{OrigErr: err, Query: []byte(query)} } } else { + if err := ch.Trigger(database.TrigVersionTableExists, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTableExists"} + } return nil } + if err := ch.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + // if not, create the empty migration table if len(ch.config.ClusterName) > 0 { query = fmt.Sprintf(` @@ -255,6 +309,11 @@ func (ch *ClickHouse) ensureVersionTable() (err error) { if _, err := ch.conn.Exec(query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } + + if err := ch.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } + return nil } @@ -306,6 +365,27 @@ func (ch *ClickHouse) Unlock() error { } func (ch *ClickHouse) Close() error { return ch.conn.Close() } +func (ch *ClickHouse) AddTriggers(t map[string]func(response interface{}) error) { + ch.config.Triggers = t +} + +func (ch *ClickHouse) Trigger(name string, detail interface{}) error { + if ch.config.Triggers == nil { + return nil + } + + if trigger, ok := ch.config.Triggers[name]; ok { + return trigger(TriggerResponse{ + Driver: ch, + Config: ch.config, + Trigger: name, + Detail: detail, + }) + } + + return nil +} + // Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611 func quoteIdentifier(name string) string { end := strings.IndexRune(name, 0) diff --git a/database/cockroachdb/cockroachdb.go b/database/cockroachdb/cockroachdb.go index 699b3facd..0d0d680da 100644 --- a/database/cockroachdb/cockroachdb.go +++ b/database/cockroachdb/cockroachdb.go @@ -37,6 +37,8 @@ type Config struct { LockTable string ForceLock bool DatabaseName string + + Triggers map[string]func(response interface{}) error } type CockroachDb struct { @@ -47,6 +49,13 @@ type CockroachDb struct { config *Config } +type TriggerResponse struct { + Driver *CockroachDb + Config *Config + Trigger string + Detail interface{} +} + func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -144,6 +153,27 @@ func (c *CockroachDb) Close() error { return c.db.Close() } +func (c *CockroachDb) AddTriggers(t map[string]func(response interface{}) error) { + c.config.Triggers = t +} + +func (c *CockroachDb) Trigger(name string, detail interface{}) error { + if c.config.Triggers == nil { + return nil + } + + if trigger, ok := c.config.Triggers[name]; ok { + return trigger(TriggerResponse{ + Driver: c, + Config: c.config, + Trigger: name, + Detail: detail, + }) + } + + return nil +} + // Locking is done manually with a separate lock table. Implementing advisory locks in CRDB is being discussed // See: https://github.com/cockroachdb/cockroach/issues/13546 func (c *CockroachDb) Lock() error { @@ -218,15 +248,32 @@ func (c *CockroachDb) Run(migration io.Reader) error { // run migration query := string(migr[:]) + if err := c.Trigger(database.TrigRunPre, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } if _, err := c.db.Exec(query); err != nil { return database.Error{OrigErr: err, Err: "migration failed", Query: migr} } + if err := c.Trigger(database.TrigRunPost, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } return nil } func (c *CockroachDb) SetVersion(version int, dirty bool) error { return crdb.ExecuteTx(context.Background(), c.db, nil, func(tx *sql.Tx) error { + if err := c.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return err + } + if _, err := tx.Exec(`DELETE FROM "` + c.config.MigrationsTable + `"`); err != nil { return err } @@ -240,6 +287,13 @@ func (c *CockroachDb) SetVersion(version int, dirty bool) error { } } + if err := c.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return err + } + return nil }) } @@ -333,18 +387,31 @@ func (c *CockroachDb) ensureVersionTable() (err error) { return &database.Error{OrigErr: err, Query: []byte(query)} } if count == 1 { + if err := c.Trigger(database.TrigVersionTableExists, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTableExists"} + } return nil } + if err := c.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + // if not, create the empty migration table query = `CREATE TABLE "` + c.config.MigrationsTable + `" (version INT NOT NULL PRIMARY KEY, dirty BOOL NOT NULL)` if _, err := c.db.Exec(query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } + + if err := c.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } + return nil } func (c *CockroachDb) ensureLockTable() error { + // check if lock table exists var count int query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` diff --git a/database/driver.go b/database/driver.go index 11268e6b9..8d756aef6 100644 --- a/database/driver.go +++ b/database/driver.go @@ -19,6 +19,14 @@ var ( const NilVersion int = -1 +const TrigRunPre string = "RunPre" +const TrigRunPost string = "RunPost" +const TrigSetVersionPre string = "SetVersionPre" +const TrigSetVersionPost string = "SetVersionPost" +const TrigVersionTableExists string = "VersionTableExists" +const TrigVersionTablePre string = "VersionTablePre" +const TrigVersionTablePost string = "VersionTablePost" + var driversMu sync.RWMutex var drivers = make(map[string]Driver) @@ -52,6 +60,13 @@ type Driver interface { // Migrate will call this function only once per instance. Close() error + // AddTriggers adds triggers to the database. The map key is the trigger name + AddTriggers(t map[string]func(response interface{}) error) + + // Trigger is called when a trigger is fired. The name is the trigger name + // and detail is the trigger detail. + Trigger(name string, detail interface{}) error + // Lock should acquire a database lock so that only one migration process // can run at a time. Migrate will call this function before Run is called. // If the implementation can't provide this functionality, return nil. diff --git a/database/driver_test.go b/database/driver_test.go index 7880f3208..1b157cde4 100644 --- a/database/driver_test.go +++ b/database/driver_test.go @@ -28,6 +28,12 @@ func (m *mockDriver) Close() error { return nil } +func (m *mockDriver) AddTriggers(t map[string]func(detail interface{}) error) {} + +func (m *mockDriver) Trigger(name string, detail interface{}) error { + return nil +} + func (m *mockDriver) Lock() error { return nil } diff --git a/database/firebird/firebird.go b/database/firebird/firebird.go index e15ea96b8..78973a4a2 100644 --- a/database/firebird/firebird.go +++ b/database/firebird/firebird.go @@ -31,6 +31,8 @@ var ( type Config struct { DatabaseName string MigrationsTable string + + Triggers map[string]func(response interface{}) error } type Firebird struct { @@ -43,6 +45,13 @@ type Firebird struct { config *Config } +type TriggerResponse struct { + Driver *Firebird + Config *Config + Trigger string + Detail interface{} +} + func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -106,6 +115,27 @@ func (f *Firebird) Close() error { return nil } +func (f *Firebird) AddTriggers(t map[string]func(response interface{}) error) { + f.config.Triggers = t +} + +func (f *Firebird) Trigger(name string, detail interface{}) error { + if f.config.Triggers == nil { + return nil + } + + if trigger, ok := f.config.Triggers[name]; ok { + return trigger(TriggerResponse{ + Driver: f, + Config: f.config, + Trigger: name, + Detail: detail, + }) + } + + return nil +} + func (f *Firebird) Lock() error { if !f.isLocked.CAS(false, true) { return database.ErrLocked @@ -128,9 +158,19 @@ func (f *Firebird) Run(migration io.Reader) error { // run migration query := string(migr[:]) + if err := f.Trigger(database.TrigRunPre, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } if _, err := f.conn.ExecContext(context.Background(), query); err != nil { return database.Error{OrigErr: err, Err: "migration failed", Query: migr} } + if err := f.Trigger(database.TrigRunPost, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } return nil } @@ -140,6 +180,13 @@ func (f *Firebird) SetVersion(version int, dirty bool) error { // for failed down migration on the first migration // See: https://github.com/golang-migrate/migrate/issues/330 + if err := f.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + // TODO: parameterize this SQL statement // https://firebirdsql.org/refdocs/langrefupd20-execblock.html // VALUES (?, ?) doesn't work @@ -153,6 +200,13 @@ func (f *Firebird) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Query: []byte(query)} } + if err := f.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + return nil } @@ -231,6 +285,10 @@ func (f *Firebird) ensureVersionTable() (err error) { } }() + if err := f.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN if (not exists(select 1 from rdb$relations where rdb$relation_name = '%v')) then execute statement 'create table "%v" (version bigint not null primary key, dirty smallint not null)'; @@ -241,6 +299,10 @@ func (f *Firebird) ensureVersionTable() (err error) { return &database.Error{OrigErr: err, Query: []byte(query)} } + if err := f.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } + return nil } diff --git a/database/mongodb/mongodb.go b/database/mongodb/mongodb.go index 3a9a6be9e..ede3b3c6f 100644 --- a/database/mongodb/mongodb.go +++ b/database/mongodb/mongodb.go @@ -59,6 +59,8 @@ type Config struct { MigrationsCollection string TransactionMode bool Locking Locking + + Triggers map[string]func(response interface{}) error } type versionInfo struct { Version int `bson:"version"` @@ -75,6 +77,13 @@ type findFilter struct { Key int `bson:"locking_key"` } +type TriggerResponse struct { + Driver *Mongo + Config *Config + Trigger string + Detail interface{} +} + func WithInstance(instance *mongo.Client, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -216,6 +225,13 @@ func parseInt(urlParam string, defaultValue int) (int, error) { return defaultValue, nil } func (m *Mongo) SetVersion(version int, dirty bool) error { + if err := m.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + migrationsCollection := m.db.Collection(m.config.MigrationsCollection) if err := migrationsCollection.Drop(context.TODO()); err != nil { return &database.Error{OrigErr: err, Err: "drop migrations collection failed"} @@ -224,6 +240,14 @@ func (m *Mongo) SetVersion(version int, dirty bool) error { if err != nil { return &database.Error{OrigErr: err, Err: "save version failed"} } + + if err := m.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + return nil } @@ -250,6 +274,11 @@ func (m *Mongo) Run(migration io.Reader) error { if err != nil { return fmt.Errorf("unmarshaling json error: %s", err) } + if err := m.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(migr)}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } if m.config.TransactionMode { if err := m.executeCommandsWithTransaction(context.TODO(), cmds); err != nil { return err @@ -259,6 +288,11 @@ func (m *Mongo) Run(migration io.Reader) error { return err } } + if err := m.Trigger(database.TrigRunPost, struct { + Query string + }{Query: string(migr)}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } return nil } @@ -297,6 +331,27 @@ func (m *Mongo) Close() error { return m.client.Disconnect(context.TODO()) } +func (m *Mongo) AddTriggers(t map[string]func(response interface{}) error) { + m.config.Triggers = t +} + +func (m *Mongo) Trigger(name string, detail interface{}) error { + if m.config.Triggers == nil { + return nil + } + + if trigger, ok := m.config.Triggers[name]; ok { + return trigger(TriggerResponse{ + Driver: m, + Config: m.config, + Trigger: name, + Detail: detail, + }) + } + + return nil +} + func (m *Mongo) Drop() error { return m.db.Drop(context.TODO()) } @@ -336,9 +391,15 @@ func (m *Mongo) ensureVersionTable() (err error) { if err != nil { return err } + if err := m.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } if _, _, err = m.Version(); err != nil { return err } + if err := m.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } return nil } diff --git a/database/mysql/examples/triggers/main.go b/database/mysql/examples/triggers/main.go new file mode 100644 index 000000000..7c69ff42c --- /dev/null +++ b/database/mysql/examples/triggers/main.go @@ -0,0 +1,120 @@ +package main + +import ( + "context" + "database/sql" + "fmt" + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database" + "github.com/golang-migrate/migrate/v4/database/mysql" + _ "github.com/golang-migrate/migrate/v4/source/file" + "os" +) + +type App struct { + Connection *sql.Conn + MigrationTable string + HistoryID *int64 +} + +func main() { + db, err := sql.Open("mysql", "root:root@tcp(localhost:3306)/db?multiStatements=true") + if err != nil { + fmt.Printf("%v\n", err) + os.Exit(1) + } + ctx := context.Background() + conn, err := db.Conn(ctx) + if err != nil { + fmt.Printf("%v\n", err) + os.Exit(1) + } + defer conn.Close() + defer db.Close() + + app := &App{ + Connection: conn, + MigrationTable: mysql.DefaultMigrationsTable, + } + + databaseDrv, err := mysql.WithConnection(ctx, conn, &mysql.Config{ + DatabaseName: "db", + Triggers: map[string]func(response interface{}) error{ + database.TrigVersionTableExists: app.MigrationHistoryTable, + database.TrigVersionTablePost: app.MigrationHistoryTable, + database.TrigRunPost: app.DatabaseRunPost, + }, + }) + if err != nil { + fmt.Printf("%v\n", err) + os.Exit(1) + } + m, err := migrate.NewFromOptions(migrate.Options{ + DatabaseInstance: databaseDrv, + DatabaseName: "db", + SourceURL: "file://migrations", + MigrateTriggers: map[string]func(response migrate.TriggerResponse) error{ + migrate.TrigRunMigrationVersionPre: app.RunMigrationVersionPre, + migrate.TrigRunMigrationVersionPost: app.RunMigrationVersionPost, + }, + }) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + err = m.Up() + //err = m.Down() + if err != nil { + fmt.Printf("%v\n", err) + return + } +} + +func (app *App) MigrationHistoryTable(response interface{}) error { + r, _ := response.(mysql.TriggerResponse) + fmt.Printf("Executing database trigger %s\n, %v\n", r.Trigger, r) + query := "CREATE TABLE IF NOT EXISTS `" + app.MigrationTable + "_history` (`id` bigint not null primary key auto_increment, `version` bigint not null, `target` bigint, identifier varchar(255), `dirty` tinyint not null, migration text, `timestamp` datetime not null)" + _, err := app.Connection.ExecContext(context.TODO(), query) + return err +} + +func (app *App) DatabaseRunPost(response interface{}) error { + r, _ := response.(mysql.TriggerResponse) + detail := r.Detail.(struct{ Query string }) + fmt.Printf("Executing database trigger %s\n, %v\n", r.Trigger, r.Detail) + query := "UPDATE `" + app.MigrationTable + "_history` SET `migration` = ? WHERE `id` = ?" + _, err := app.Connection.ExecContext(context.TODO(), query, detail.Query, app.HistoryID) + if err != nil { + fmt.Printf("Error updating migration history: %v\n", err) + return err + } + + return nil +} + +func (app *App) RunMigrationVersionPre(r migrate.TriggerResponse) error { + detail := r.Detail.(struct{ Migration *migrate.Migration }) + fmt.Printf("Executing migration trigger %s\n, %v\n", r.Trigger, r.Detail) + query := "INSERT INTO `" + app.MigrationTable + "_history` (`version`, `identifier`, `target`, `dirty`, `timestamp`) VALUES (?, ?, ?, 1, NOW())" + _, err := app.Connection.ExecContext(context.TODO(), query, detail.Migration.Version, detail.Migration.Identifier, detail.Migration.TargetVersion) + if err != nil { + fmt.Printf("Error inserting migration history: %v\n", err) + return err + } + query = "SELECT LAST_INSERT_ID()" + row := app.Connection.QueryRowContext(context.TODO(), query) + return row.Scan(&app.HistoryID) +} + +func (app *App) RunMigrationVersionPost(r migrate.TriggerResponse) error { + fmt.Printf("Executing migration trigger %s\n, %v\n", r.Trigger, r.Detail) + query := "UPDATE `" + app.MigrationTable + "_history` SET `dirty` = 0, `timestamp` = NOW() WHERE `id` = ?" + _, err := app.Connection.ExecContext(context.TODO(), query, app.HistoryID) + if err != nil { + fmt.Printf("Error updating migration history: %v\n", err) + return err + } + app.HistoryID = nil + + return nil +} diff --git a/database/mysql/examples/triggers/migrations/20250101_tusers.down.sql b/database/mysql/examples/triggers/migrations/20250101_tusers.down.sql new file mode 100644 index 000000000..3ad31301a --- /dev/null +++ b/database/mysql/examples/triggers/migrations/20250101_tusers.down.sql @@ -0,0 +1 @@ +DROP TABLE tusers; \ No newline at end of file diff --git a/database/mysql/examples/triggers/migrations/20250101_tusers.up.sql b/database/mysql/examples/triggers/migrations/20250101_tusers.up.sql new file mode 100644 index 000000000..0f9c910b4 --- /dev/null +++ b/database/mysql/examples/triggers/migrations/20250101_tusers.up.sql @@ -0,0 +1,6 @@ +CREATE TABLE tusers ( + id SERIAL PRIMARY KEY, + name VARCHAR(100) NOT NULL, + email VARCHAR(100) UNIQUE NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); \ No newline at end of file diff --git a/database/mysql/examples/triggers/migrations/20250102_tusers_alter.down.sql b/database/mysql/examples/triggers/migrations/20250102_tusers_alter.down.sql new file mode 100644 index 000000000..3058138c8 --- /dev/null +++ b/database/mysql/examples/triggers/migrations/20250102_tusers_alter.down.sql @@ -0,0 +1,4 @@ +ALTER TABLE tusers + DROP COLUMN last_login, + DROP COLUMN status, + DROP COLUMN profile_picture; \ No newline at end of file diff --git a/database/mysql/examples/triggers/migrations/20250102_tusers_alter.up.sql b/database/mysql/examples/triggers/migrations/20250102_tusers_alter.up.sql new file mode 100644 index 000000000..f7d9e5709 --- /dev/null +++ b/database/mysql/examples/triggers/migrations/20250102_tusers_alter.up.sql @@ -0,0 +1,4 @@ +ALTER TABLE tusers + ADD COLUMN last_login TIMESTAMP, + ADD COLUMN status VARCHAR(20) DEFAULT 'active', + ADD COLUMN profile_picture VARCHAR(255); \ No newline at end of file diff --git a/database/mysql/mysql.go b/database/mysql/mysql.go index 711ba5187..b857a3ce6 100644 --- a/database/mysql/mysql.go +++ b/database/mysql/mysql.go @@ -43,6 +43,15 @@ type Config struct { DatabaseName string NoLock bool StatementTimeout time.Duration + + Triggers map[string]func(response interface{}) error +} + +type TriggerResponse struct { + Driver *Mysql + Config *Config + Trigger string + Detail interface{} } type Mysql struct { @@ -283,6 +292,27 @@ func (m *Mysql) Close() error { return nil } +func (m *Mysql) AddTriggers(t map[string]func(response interface{}) error) { + m.config.Triggers = t +} + +func (m *Mysql) Trigger(name string, detail interface{}) error { + if m.config.Triggers == nil { + return nil + } + + if trigger, ok := m.config.Triggers[name]; ok { + return trigger(TriggerResponse{ + Driver: m, + Config: m.config, + Trigger: name, + Detail: detail, + }) + } + + return nil +} + func (m *Mysql) Lock() error { return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error { if m.config.NoLock { @@ -347,9 +377,19 @@ func (m *Mysql) Run(migration io.Reader) error { } query := string(migr[:]) + if err := m.Trigger(database.TrigRunPre, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } if _, err := m.conn.ExecContext(ctx, query); err != nil { return database.Error{OrigErr: err, Err: "migration failed", Query: migr} } + if err := m.Trigger(database.TrigRunPost, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } return nil } @@ -360,6 +400,16 @@ func (m *Mysql) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Err: "transaction start failed"} } + if err := m.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + query := "DELETE FROM `" + m.config.MigrationsTable + "` LIMIT 1" if _, err := tx.ExecContext(context.Background(), query); err != nil { if errRollback := tx.Rollback(); errRollback != nil { @@ -381,6 +431,16 @@ func (m *Mysql) SetVersion(version int, dirty bool) error { } } + if err := m.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + if err := tx.Commit(); err != nil { return &database.Error{OrigErr: err, Err: "transaction commit failed"} } @@ -486,14 +546,25 @@ func (m *Mysql) ensureVersionTable() (err error) { return &database.Error{OrigErr: err, Query: []byte(query)} } } else { + if err := m.Trigger(database.TrigVersionTableExists, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTableExists"} + } return nil } + if err := m.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + // if not, create the empty migration table query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)" if _, err := m.conn.ExecContext(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } + + if err := m.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } return nil } diff --git a/database/neo4j/neo4j.go b/database/neo4j/neo4j.go index 179e0da60..3dff54853 100644 --- a/database/neo4j/neo4j.go +++ b/database/neo4j/neo4j.go @@ -34,6 +34,8 @@ type Config struct { MigrationsLabel string MultiStatement bool MultiStatementMaxSize int + + Triggers map[string]func(response interface{}) error } type Neo4j struct { @@ -44,6 +46,13 @@ type Neo4j struct { config *Config } +type TriggerResponse struct { + Driver *Neo4j + Config *Config + Trigger string + Detail interface{} +} + func WithInstance(driver neo4j.Driver, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -118,6 +127,27 @@ func (n *Neo4j) Close() error { return n.driver.Close() } +func (n *Neo4j) AddTriggers(t map[string]func(response interface{}) error) { + n.config.Triggers = t +} + +func (n *Neo4j) Trigger(name string, detail interface{}) error { + if n.config.Triggers == nil { + return nil + } + + if trigger, ok := n.config.Triggers[name]; ok { + return trigger(TriggerResponse{ + Driver: n, + Config: n.config, + Trigger: name, + Detail: detail, + }) + } + + return nil +} + // local locking in order to pass tests, Neo doesn't support database locking func (n *Neo4j) Lock() error { if !atomic.CompareAndSwapUint32(&n.lock, 0, 1) { @@ -158,11 +188,26 @@ func (n *Neo4j) Run(migration io.Reader) (err error) { return true } + if err = n.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(trimStmt)}); err != nil { + stmtRunErr = err + return false + } + result, err := transaction.Run(string(trimStmt), nil) if _, err := neo4j.Collect(result, err); err != nil { stmtRunErr = err return false } + + if err = n.Trigger(database.TrigRunPost, struct { + Query string + }{Query: string(trimStmt)}); err != nil { + stmtRunErr = err + return false + } + return true }); err != nil { return nil, err @@ -176,8 +221,19 @@ func (n *Neo4j) Run(migration io.Reader) (err error) { if err != nil { return err } - - _, err = neo4j.Collect(session.Run(string(body[:]), nil)) + if err = n.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(body[:])}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } + if _, err = neo4j.Collect(session.Run(string(body[:]), nil)); err != nil { + return err + } + if err = n.Trigger(database.TrigRunPost, struct { + Query string + }{Query: string(body[:])}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } return err } @@ -192,12 +248,24 @@ func (n *Neo4j) SetVersion(version int, dirty bool) (err error) { } }() + if err := n.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } query := fmt.Sprintf("MERGE (sm:%s {version: $version}) SET sm.dirty = $dirty, sm.ts = datetime()", n.config.MigrationsLabel) _, err = neo4j.Collect(session.Run(query, map[string]interface{}{"version": version, "dirty": dirty})) if err != nil { return err } + if err := n.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } return nil } @@ -292,12 +360,21 @@ func (n *Neo4j) ensureVersionConstraint() (err error) { return err } if len(res) == 1 { + if err := n.Trigger(database.TrigVersionTableExists, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTableExists"} + } return nil } + if err := n.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } query := fmt.Sprintf("CREATE CONSTRAINT ON (a:%s) ASSERT a.version IS UNIQUE", n.config.MigrationsLabel) if _, err := neo4j.Collect(session.Run(query, nil)); err != nil { return err } + if err := n.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } return nil } diff --git a/database/pgx/pgx.go b/database/pgx/pgx.go index efe8bea80..ce1fde69b 100644 --- a/database/pgx/pgx.go +++ b/database/pgx/pgx.go @@ -64,6 +64,8 @@ type Config struct { MigrationsTableQuoted bool MultiStatementEnabled bool MultiStatementMaxSize int + + Triggers map[string]func(response interface{}) error } type Postgres struct { @@ -76,6 +78,13 @@ type Postgres struct { config *Config } +type TriggerResponse struct { + Driver *Postgres + Config *Config + Trigger string + Detail interface{} +} + func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -247,6 +256,27 @@ func (p *Postgres) Close() error { return nil } +func (p *Postgres) AddTriggers(t map[string]func(response interface{}) error) { + p.config.Triggers = t +} + +func (p *Postgres) Trigger(name string, detail interface{}) error { + if p.config.Triggers == nil { + return nil + } + + if trigger, ok := p.config.Triggers[name]; ok { + return trigger(TriggerResponse{ + Driver: p, + Config: p.config, + Trigger: name, + Detail: detail, + }) + } + + return nil +} + func (p *Postgres) Lock() error { return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error { switch p.config.LockStrategy { @@ -363,9 +393,19 @@ func (p *Postgres) Run(migration io.Reader) error { if p.config.MultiStatementEnabled { var err error if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { + if e := p.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(m)}); e != nil { + return false + } if err = p.runStatement(m); err != nil { return false } + if e := p.Trigger(database.TrigRunPost, struct { + Query string + }{Query: string(m)}); e != nil { + return false + } return true }); e != nil { return e @@ -376,7 +416,20 @@ func (p *Postgres) Run(migration io.Reader) error { if err != nil { return err } - return p.runStatement(migr) + if err = p.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(migr)}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } + if err = p.runStatement(migr); err != nil { + return err + } + if err = p.Trigger(database.TrigRunPost, struct { + Query string + }{Query: string(migr)}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } + return nil } func (p *Postgres) runStatement(statement []byte) error { @@ -452,6 +505,16 @@ func (p *Postgres) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Err: "transaction start failed"} } + if err := p.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + query := `TRUNCATE ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) if _, err := tx.Exec(query); err != nil { if errRollback := tx.Rollback(); errRollback != nil { @@ -473,6 +536,16 @@ func (p *Postgres) SetVersion(version int, dirty bool) error { } } + if err := p.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + if err := tx.Commit(); err != nil { return &database.Error{OrigErr: err, Err: "transaction commit failed"} } @@ -579,14 +652,25 @@ func (p *Postgres) ensureVersionTable() (err error) { } if count == 1 { + if err := p.Trigger(database.TrigVersionTableExists, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTableExists"} + } return nil } + if err := p.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)` if _, err = p.conn.ExecContext(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } + if err := p.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } + return nil } diff --git a/database/pgx/v5/pgx.go b/database/pgx/v5/pgx.go index 303174495..eb0e81864 100644 --- a/database/pgx/v5/pgx.go +++ b/database/pgx/v5/pgx.go @@ -52,6 +52,8 @@ type Config struct { MigrationsTableQuoted bool MultiStatementEnabled bool MultiStatementMaxSize int + + Triggers map[string]func(response interface{}) error } type Postgres struct { @@ -64,6 +66,13 @@ type Postgres struct { config *Config } +type TriggerResponse struct { + Driver *Postgres + Config *Config + Trigger string + Detail interface{} +} + func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -218,6 +227,27 @@ func (p *Postgres) Close() error { return nil } +func (p *Postgres) AddTriggers(t map[string]func(response interface{}) error) { + p.config.Triggers = t +} + +func (p *Postgres) Trigger(name string, detail interface{}) error { + if p.config.Triggers == nil { + return nil + } + + if trigger, ok := p.config.Triggers[name]; ok { + return trigger(TriggerResponse{ + Driver: p, + Config: p.config, + Trigger: name, + Detail: detail, + }) + } + + return nil +} + // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS func (p *Postgres) Lock() error { return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error { @@ -254,9 +284,19 @@ func (p *Postgres) Run(migration io.Reader) error { if p.config.MultiStatementEnabled { var err error if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { + if e := p.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(m)}); e != nil { + return false + } if err = p.runStatement(m); err != nil { return false } + if e := p.Trigger(database.TrigRunPost, struct { + Query string + }{Query: string(m)}); e != nil { + return false + } return true }); e != nil { return e @@ -267,7 +307,20 @@ func (p *Postgres) Run(migration io.Reader) error { if err != nil { return err } - return p.runStatement(migr) + if err = p.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(migr)}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } + if err = p.runStatement(migr); err != nil { + return err + } + if err = p.Trigger(database.TrigRunPost, struct { + Query string + }{Query: string(migr)}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } + return nil } func (p *Postgres) runStatement(statement []byte) error { @@ -343,6 +396,16 @@ func (p *Postgres) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Err: "transaction start failed"} } + if err := p.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + query := `TRUNCATE ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) if _, err := tx.Exec(query); err != nil { if errRollback := tx.Rollback(); errRollback != nil { @@ -364,6 +427,16 @@ func (p *Postgres) SetVersion(version int, dirty bool) error { } } + if err := p.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + if err := tx.Commit(); err != nil { return &database.Error{OrigErr: err, Err: "transaction commit failed"} } @@ -464,14 +537,25 @@ func (p *Postgres) ensureVersionTable() (err error) { } if count == 1 { + if err := p.Trigger(database.TrigVersionTableExists, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTableExists"} + } return nil } + if err := p.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)` if _, err = p.conn.ExecContext(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } + if err := p.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } + return nil } diff --git a/database/postgres/postgres.go b/database/postgres/postgres.go index 5e4519115..453c39fbd 100644 --- a/database/postgres/postgres.go +++ b/database/postgres/postgres.go @@ -52,6 +52,8 @@ type Config struct { migrationsTableName string StatementTimeout time.Duration MultiStatementMaxSize int + + Triggers map[string]func(response interface{}) error } type Postgres struct { @@ -64,6 +66,13 @@ type Postgres struct { config *Config } +type TriggerResponse struct { + Driver *Postgres + Config *Config + Trigger string + Detail interface{} +} + func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Postgres, error) { if config == nil { return nil, ErrNilConfig @@ -230,6 +239,27 @@ func (p *Postgres) Close() error { return nil } +func (p *Postgres) AddTriggers(t map[string]func(response interface{}) error) { + p.config.Triggers = t +} + +func (p *Postgres) Trigger(name string, detail interface{}) error { + if p.config.Triggers == nil { + return nil + } + + if trigger, ok := p.config.Triggers[name]; ok { + return trigger(TriggerResponse{ + Driver: p, + Config: p.config, + Trigger: name, + Detail: detail, + }) + } + + return nil +} + // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS func (p *Postgres) Lock() error { return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error { @@ -267,9 +297,19 @@ func (p *Postgres) Run(migration io.Reader) error { if p.config.MultiStatementEnabled { var err error if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { + if e := p.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(m)}); e != nil { + return false + } if err = p.runStatement(m); err != nil { return false } + if e := p.Trigger(database.TrigRunPost, struct { + Query string + }{Query: string(m)}); e != nil { + return false + } return true }); e != nil { return e @@ -280,7 +320,20 @@ func (p *Postgres) Run(migration io.Reader) error { if err != nil { return err } - return p.runStatement(migr) + if err = p.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(migr)}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } + if err = p.runStatement(migr); err != nil { + return err + } + if err = p.Trigger(database.TrigRunPost, struct { + Query string + }{Query: string(migr)}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } + return nil } func (p *Postgres) runStatement(statement []byte) error { @@ -359,6 +412,16 @@ func (p *Postgres) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Err: "transaction start failed"} } + if err := p.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + query := `TRUNCATE ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) if _, err := tx.Exec(query); err != nil { if errRollback := tx.Rollback(); errRollback != nil { @@ -380,6 +443,16 @@ func (p *Postgres) SetVersion(version int, dirty bool) error { } } + if err := p.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + if err := tx.Commit(); err != nil { return &database.Error{OrigErr: err, Err: "transaction commit failed"} } @@ -480,13 +553,24 @@ func (p *Postgres) ensureVersionTable() (err error) { } if count == 1 { + if err := p.Trigger(database.TrigVersionTableExists, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTableExists"} + } return nil } + if err := p.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + query = `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)` if _, err = p.conn.ExecContext(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } + if err := p.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } + return nil } diff --git a/database/ql/ql.go b/database/ql/ql.go index 37c062455..6c37651db 100644 --- a/database/ql/ql.go +++ b/database/ql/ql.go @@ -30,6 +30,8 @@ var ( type Config struct { MigrationsTable string DatabaseName string + + Triggers map[string]func(response interface{}) error } type Ql struct { @@ -39,6 +41,13 @@ type Ql struct { config *Config } +type TriggerResponse struct { + Driver *Ql + Config *Config + Trigger string + Detail interface{} +} + func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -84,6 +93,12 @@ func (m *Ql) ensureVersionTable() (err error) { if err != nil { return err } + if err := m.Trigger(database.TrigVersionTablePre, nil); err != nil { + if err := tx.Rollback(); err != nil { + return err + } + return err + } if _, err := tx.Exec(fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s (version uint64, dirty bool); CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version); @@ -93,6 +108,12 @@ func (m *Ql) ensureVersionTable() (err error) { } return err } + if err := m.Trigger(database.TrigVersionTablePost, nil); err != nil { + if err := tx.Rollback(); err != nil { + return err + } + return err + } if err := tx.Commit(); err != nil { return err } @@ -125,6 +146,28 @@ func (m *Ql) Open(url string) (database.Driver, error) { func (m *Ql) Close() error { return m.db.Close() } + +func (m *Ql) AddTriggers(t map[string]func(response interface{}) error) { + m.config.Triggers = t +} + +func (m *Ql) Trigger(name string, detail interface{}) error { + if m.config.Triggers == nil { + return nil + } + + if trigger, ok := m.config.Triggers[name]; ok { + return trigger(TriggerResponse{ + Driver: m, + Config: m.config, + Trigger: name, + Detail: detail, + }) + } + + return nil +} + func (m *Ql) Drop() (err error) { query := `SELECT Name FROM __Table` tables, err := m.db.Query(query) @@ -184,7 +227,22 @@ func (m *Ql) Run(migration io.Reader) error { } query := string(migr[:]) - return m.executeQuery(query) + if err := m.Trigger(database.TrigRunPre, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } + + if err = m.executeQuery(query); err != nil { + return err + } + + if err := m.Trigger(database.TrigRunPost, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } + return nil } func (m *Ql) executeQuery(query string) error { tx, err := m.db.Begin() @@ -208,6 +266,16 @@ func (m *Ql) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Err: "transaction start failed"} } + if err := m.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + query := "TRUNCATE TABLE " + m.config.MigrationsTable if _, err := tx.Exec(query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} @@ -227,6 +295,16 @@ func (m *Ql) SetVersion(version int, dirty bool) error { } } + if err := m.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + if err := tx.Commit(); err != nil { return &database.Error{OrigErr: err, Err: "transaction commit failed"} } diff --git a/database/redshift/redshift.go b/database/redshift/redshift.go index 7687b9d9a..cff6973e1 100644 --- a/database/redshift/redshift.go +++ b/database/redshift/redshift.go @@ -34,6 +34,8 @@ var ( type Config struct { MigrationsTable string DatabaseName string + + Triggers map[string]func(response interface{}) error } type Redshift struct { @@ -45,6 +47,13 @@ type Redshift struct { config *Config } +type TriggerResponse struct { + Driver *Redshift + Config *Config + Trigger string + Detail interface{} +} + func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -125,6 +134,27 @@ func (p *Redshift) Close() error { return nil } +func (p *Redshift) AddTriggers(t map[string]func(response interface{}) error) { + p.config.Triggers = t +} + +func (p *Redshift) Trigger(name string, detail interface{}) error { + if p.config.Triggers == nil { + return nil + } + + if trigger, ok := p.config.Triggers[name]; ok { + return trigger(TriggerResponse{ + Driver: p, + Config: p.config, + Trigger: name, + Detail: detail, + }) + } + + return nil +} + // Redshift does not support advisory lock functions: https://docs.aws.amazon.com/redshift/latest/dg/c_unsupported-postgresql-functions.html func (p *Redshift) Lock() error { if !p.isLocked.CAS(false, true) { @@ -148,6 +178,11 @@ func (p *Redshift) Run(migration io.Reader) error { // run migration query := string(migr[:]) + if err := p.Trigger(database.TrigRunPre, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } if _, err := p.conn.ExecContext(context.Background(), query); err != nil { if pgErr, ok := err.(*pq.Error); ok { var line uint @@ -169,6 +204,11 @@ func (p *Redshift) Run(migration io.Reader) error { } return database.Error{OrigErr: err, Err: "migration failed", Query: migr} } + if err := p.Trigger(database.TrigRunPost, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } return nil } @@ -214,6 +254,16 @@ func (p *Redshift) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Err: "transaction start failed"} } + if err := p.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + query := `DELETE FROM "` + p.config.MigrationsTable + `"` if _, err := tx.Exec(query); err != nil { if errRollback := tx.Rollback(); errRollback != nil { @@ -235,6 +285,16 @@ func (p *Redshift) SetVersion(version int, dirty bool) error { } } + if err := p.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + if err := tx.Commit(); err != nil { return &database.Error{OrigErr: err, Err: "transaction commit failed"} } @@ -328,13 +388,25 @@ func (p *Redshift) ensureVersionTable() (err error) { return &database.Error{OrigErr: err, Query: []byte(query)} } if count == 1 { + if err := p.Trigger(database.TrigVersionTableExists, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTableExists"} + } return nil } + if err := p.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + // if not, create the empty migration table query = `CREATE TABLE "` + p.config.MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)` if _, err := p.conn.ExecContext(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } + + if err := p.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } + return nil } diff --git a/database/rqlite/rqlite.go b/database/rqlite/rqlite.go index af0d53007..fd3f633fa 100644 --- a/database/rqlite/rqlite.go +++ b/database/rqlite/rqlite.go @@ -39,6 +39,8 @@ type Config struct { ConnectInsecure bool // MigrationsTable configures the migrations table name MigrationsTable string + + Triggers map[string]func(response interface{}) error } type Rqlite struct { @@ -48,6 +50,13 @@ type Rqlite struct { config *Config } +type TriggerResponse struct { + Driver *Rqlite + Config *Config + Trigger string + Detail interface{} +} + // WithInstance creates a rqlite database driver with an existing gorqlite database connection // and a Config struct func WithInstance(instance *gorqlite.Connection, config *Config) (database.Driver, error) { @@ -97,6 +106,10 @@ func (r *Rqlite) ensureVersionTable() (err error) { } }() + if err := r.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + stmts := []string{ fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (version uint64, dirty bool)`, r.config.MigrationsTable), fmt.Sprintf(`CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version)`, r.config.MigrationsTable), @@ -106,6 +119,10 @@ func (r *Rqlite) ensureVersionTable() (err error) { return err } + if err := r.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } + return nil } @@ -138,6 +155,27 @@ func (r *Rqlite) Close() error { return nil } +func (r *Rqlite) AddTriggers(t map[string]func(response interface{}) error) { + r.config.Triggers = t +} + +func (r *Rqlite) Trigger(name string, detail interface{}) error { + if r.config.Triggers == nil { + return nil + } + + if trigger, ok := r.config.Triggers[name]; ok { + return trigger(TriggerResponse{ + Driver: r, + Config: r.config, + Trigger: name, + Detail: detail, + }) + } + + return nil +} + // Lock should acquire a database lock so that only one migration process // can run at a time. Migrate will call this function before Run is called. // If the implementation can't provide this functionality, return nil. @@ -166,9 +204,19 @@ func (r *Rqlite) Run(migration io.Reader) error { } query := string(migr[:]) + if err := r.Trigger(database.TrigRunPre, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } if _, err := r.db.WriteOne(query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } + if err := r.Trigger(database.TrigRunPost, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } return nil } @@ -177,6 +225,13 @@ func (r *Rqlite) Run(migration io.Reader) error { // Migrate will call this function before and after each call to Run. // version must be >= -1. -1 means NilVersion. func (r *Rqlite) SetVersion(version int, dirty bool) error { + if err := r.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + deleteQuery := fmt.Sprintf(`DELETE FROM %s`, r.config.MigrationsTable) statements := []gorqlite.ParameterizedStatement{ { @@ -210,6 +265,13 @@ func (r *Rqlite) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Query: []byte(deleteQuery + "\n" + insertQuery)} } + if err := r.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + return nil } diff --git a/database/snowflake/snowflake.go b/database/snowflake/snowflake.go index 46ce30200..9ea573da1 100644 --- a/database/snowflake/snowflake.go +++ b/database/snowflake/snowflake.go @@ -35,6 +35,8 @@ var ( type Config struct { MigrationsTable string DatabaseName string + + Triggers map[string]func(response interface{}) error } type Snowflake struct { @@ -46,6 +48,13 @@ type Snowflake struct { config *Config } +type TriggerResponse struct { + Driver *Snowflake + Config *Config + Trigger string + Detail interface{} +} + func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -158,6 +167,27 @@ func (p *Snowflake) Close() error { return nil } +func (p *Snowflake) AddTriggers(t map[string]func(response interface{}) error) { + p.config.Triggers = t +} + +func (p *Snowflake) Trigger(name string, detail interface{}) error { + if p.config.Triggers == nil { + return nil + } + + if trigger, ok := p.config.Triggers[name]; ok { + return trigger(TriggerResponse{ + Driver: p, + Config: p.config, + Trigger: name, + Detail: detail, + }) + } + + return nil +} + func (p *Snowflake) Lock() error { if !p.isLocked.CAS(false, true) { return database.ErrLocked @@ -180,6 +210,11 @@ func (p *Snowflake) Run(migration io.Reader) error { // run migration query := string(migr[:]) + if err := p.Trigger(database.TrigRunPre, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } if _, err := p.conn.ExecContext(context.Background(), query); err != nil { if pgErr, ok := err.(*pq.Error); ok { var line uint @@ -201,6 +236,11 @@ func (p *Snowflake) Run(migration io.Reader) error { } return database.Error{OrigErr: err, Err: "migration failed", Query: migr} } + if err := p.Trigger(database.TrigRunPost, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } return nil } @@ -246,6 +286,16 @@ func (p *Snowflake) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Err: "transaction start failed"} } + if err := p.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + query := `DELETE FROM "` + p.config.MigrationsTable + `"` if _, err := tx.Exec(query); err != nil { if errRollback := tx.Rollback(); errRollback != nil { @@ -269,6 +319,16 @@ func (p *Snowflake) SetVersion(version int, dirty bool) error { } } + if err := p.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + if err := tx.Commit(); err != nil { return &database.Error{OrigErr: err, Err: "transaction commit failed"} } @@ -362,9 +422,16 @@ func (p *Snowflake) ensureVersionTable() (err error) { return &database.Error{OrigErr: err, Query: []byte(query)} } if count == 1 { + if err := p.Trigger(database.TrigVersionTableExists, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTableExists"} + } return nil } + if err := p.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + // if not, create the empty migration table query = `CREATE TABLE if not exists "` + p.config.MigrationsTable + `" ( version bigint not null primary key, dirty boolean not null)` @@ -372,5 +439,9 @@ func (p *Snowflake) ensureVersionTable() (err error) { return &database.Error{OrigErr: err, Query: []byte(query)} } + if err := p.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } + return nil } diff --git a/database/spanner/spanner.go b/database/spanner/spanner.go index b733302d5..b4071b98b 100644 --- a/database/spanner/spanner.go +++ b/database/spanner/spanner.go @@ -56,6 +56,8 @@ type Config struct { // Parsing outputs clean DDL statements such as reformatted // and void of comments. CleanStatements bool + + Triggers map[string]func(response interface{}) error } // Spanner implements database.Driver for Google Cloud Spanner @@ -72,6 +74,13 @@ type DB struct { data *spanner.Client } +type TriggerResponse struct { + Driver *Spanner + Config *Config + Trigger string + Detail interface{} +} + func NewDB(admin sdb.DatabaseAdminClient, data spanner.Client) *DB { return &DB{ admin: &admin, @@ -150,6 +159,27 @@ func (s *Spanner) Close() error { return s.db.admin.Close() } +func (s *Spanner) AddTriggers(t map[string]func(response interface{}) error) { + s.config.Triggers = t +} + +func (s *Spanner) Trigger(name string, detail interface{}) error { + if s.config.Triggers == nil { + return nil + } + + if trigger, ok := s.config.Triggers[name]; ok { + return trigger(TriggerResponse{ + Driver: s, + Config: s.config, + Trigger: name, + Detail: detail, + }) + } + + return nil +} + // Lock implements database.Driver but doesn't do anything because Spanner only // enqueues the UpdateDatabaseDdlRequest. func (s *Spanner) Lock() error { @@ -174,6 +204,12 @@ func (s *Spanner) Run(migration io.Reader) error { return err } + if err := s.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(migr)}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } + stmts := []string{string(migr)} if s.config.CleanStatements { stmts, err = cleanStatements(migr) @@ -196,6 +232,12 @@ func (s *Spanner) Run(migration io.Reader) error { return &database.Error{OrigErr: err, Err: "migration failed", Query: migr} } + if err := s.Trigger(database.TrigRunPost, struct { + Query string + }{Query: string(migr)}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } + return nil } @@ -203,6 +245,13 @@ func (s *Spanner) Run(migration io.Reader) error { func (s *Spanner) SetVersion(version int, dirty bool) error { ctx := context.Background() + if err := s.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + _, err := s.db.data.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { m := []*spanner.Mutation{ @@ -217,6 +266,13 @@ func (s *Spanner) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err} } + if err := s.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + return nil } @@ -317,9 +373,16 @@ func (s *Spanner) ensureVersionTable() (err error) { tbl := s.config.MigrationsTable iter := s.db.data.Single().Read(ctx, tbl, spanner.AllKeys(), []string{"Version"}) if err := iter.Do(func(r *spanner.Row) error { return nil }); err == nil { + if err := s.Trigger(database.TrigVersionTableExists, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTableExists"} + } return nil } + if err := s.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + stmt := fmt.Sprintf(`CREATE TABLE %s ( Version INT64 NOT NULL, Dirty BOOL NOT NULL @@ -337,6 +400,10 @@ func (s *Spanner) ensureVersionTable() (err error) { return &database.Error{OrigErr: err, Query: []byte(stmt)} } + if err := s.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } + return nil } diff --git a/database/sqlcipher/sqlcipher.go b/database/sqlcipher/sqlcipher.go index f98fb3a21..70c4f1aea 100644 --- a/database/sqlcipher/sqlcipher.go +++ b/database/sqlcipher/sqlcipher.go @@ -31,6 +31,8 @@ type Config struct { MigrationsTable string DatabaseName string NoTxWrap bool + + Triggers map[string]func(response interface{}) error } type Sqlite struct { @@ -40,6 +42,13 @@ type Sqlite struct { config *Config } +type TriggerResponse struct { + Driver *Sqlite + Config *Config + Trigger string + Detail interface{} +} + func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -81,6 +90,10 @@ func (m *Sqlite) ensureVersionTable() (err error) { } }() + if err := m.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + query := fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s (version uint64,dirty bool); CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version); @@ -89,6 +102,11 @@ func (m *Sqlite) ensureVersionTable() (err error) { if _, err := m.db.Exec(query); err != nil { return err } + + if err := m.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } + return nil } @@ -133,6 +151,27 @@ func (m *Sqlite) Close() error { return m.db.Close() } +func (m *Sqlite) AddTriggers(t map[string]func(response interface{}) error) { + m.config.Triggers = t +} + +func (m *Sqlite) Trigger(name string, detail interface{}) error { + if m.config.Triggers == nil { + return nil + } + + if trigger, ok := m.config.Triggers[name]; ok { + return trigger(TriggerResponse{ + Driver: m, + Config: m.config, + Trigger: name, + Detail: detail, + }) + } + + return nil +} + func (m *Sqlite) Drop() (err error) { query := `SELECT name FROM sqlite_master WHERE type = 'table';` tables, err := m.db.Query(query) @@ -198,10 +237,29 @@ func (m *Sqlite) Run(migration io.Reader) error { } query := string(migr[:]) + if err := m.Trigger(database.TrigRunPre, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } + if m.config.NoTxWrap { - return m.executeQueryNoTx(query) + if err = m.executeQueryNoTx(query); err != nil { + return err + } + } else { + if err = m.executeQuery(query); err != nil { + return err + } } - return m.executeQuery(query) + + if err := m.Trigger(database.TrigRunPost, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } + + return nil } func (m *Sqlite) executeQuery(query string) error { @@ -234,6 +292,16 @@ func (m *Sqlite) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Err: "transaction start failed"} } + if err := m.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + query := "DELETE FROM " + m.config.MigrationsTable if _, err := tx.Exec(query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} @@ -252,6 +320,16 @@ func (m *Sqlite) SetVersion(version int, dirty bool) error { } } + if err := m.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + if err := tx.Commit(); err != nil { return &database.Error{OrigErr: err, Err: "transaction commit failed"} } diff --git a/database/sqlite/sqlite.go b/database/sqlite/sqlite.go index ce449dfa0..58b2b1966 100644 --- a/database/sqlite/sqlite.go +++ b/database/sqlite/sqlite.go @@ -31,6 +31,8 @@ type Config struct { MigrationsTable string DatabaseName string NoTxWrap bool + + Triggers map[string]func(response interface{}) error } type Sqlite struct { @@ -40,6 +42,13 @@ type Sqlite struct { config *Config } +type TriggerResponse struct { + Driver *Sqlite + Config *Config + Trigger string + Detail interface{} +} + func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -81,6 +90,10 @@ func (m *Sqlite) ensureVersionTable() (err error) { } }() + if err := m.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + query := fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s (version uint64,dirty bool); CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version); @@ -89,6 +102,10 @@ func (m *Sqlite) ensureVersionTable() (err error) { if _, err := m.db.Exec(query); err != nil { return err } + + if err := m.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } return nil } @@ -133,6 +150,27 @@ func (m *Sqlite) Close() error { return m.db.Close() } +func (m *Sqlite) AddTriggers(t map[string]func(response interface{}) error) { + m.config.Triggers = t +} + +func (m *Sqlite) Trigger(name string, detail interface{}) error { + if m.config.Triggers == nil { + return nil + } + + if trigger, ok := m.config.Triggers[name]; ok { + return trigger(TriggerResponse{ + Driver: m, + Config: m.config, + Trigger: name, + Detail: detail, + }) + } + + return nil +} + func (m *Sqlite) Drop() (err error) { query := `SELECT name FROM sqlite_master WHERE type = 'table';` tables, err := m.db.Query(query) @@ -198,10 +236,26 @@ func (m *Sqlite) Run(migration io.Reader) error { } query := string(migr[:]) + if err := m.Trigger(database.TrigRunPre, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } + if m.config.NoTxWrap { - return m.executeQueryNoTx(query) + if err = m.executeQueryNoTx(query); err != nil { + return err + } + } else if err = m.executeQuery(query); err != nil { + return err } - return m.executeQuery(query) + + if err := m.Trigger(database.TrigRunPost, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } + return nil } func (m *Sqlite) executeQuery(query string) error { @@ -234,6 +288,16 @@ func (m *Sqlite) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Err: "transaction start failed"} } + if err := m.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + query := "DELETE FROM " + m.config.MigrationsTable if _, err := tx.Exec(query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} @@ -252,6 +316,16 @@ func (m *Sqlite) SetVersion(version int, dirty bool) error { } } + if err := m.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + if err := tx.Commit(); err != nil { return &database.Error{OrigErr: err, Err: "transaction commit failed"} } diff --git a/database/sqlite3/sqlite3.go b/database/sqlite3/sqlite3.go index 56bb23338..ab449555e 100644 --- a/database/sqlite3/sqlite3.go +++ b/database/sqlite3/sqlite3.go @@ -31,6 +31,8 @@ type Config struct { MigrationsTable string DatabaseName string NoTxWrap bool + + Triggers map[string]func(response interface{}) error } type Sqlite struct { @@ -40,6 +42,13 @@ type Sqlite struct { config *Config } +type TriggerResponse struct { + Driver *Sqlite + Config *Config + Trigger string + Detail interface{} +} + func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -81,6 +90,10 @@ func (m *Sqlite) ensureVersionTable() (err error) { } }() + if err := m.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + query := fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s (version uint64,dirty bool); CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version); @@ -89,6 +102,10 @@ func (m *Sqlite) ensureVersionTable() (err error) { if _, err := m.db.Exec(query); err != nil { return err } + + if err := m.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } return nil } @@ -133,6 +150,27 @@ func (m *Sqlite) Close() error { return m.db.Close() } +func (m *Sqlite) AddTriggers(t map[string]func(response interface{}) error) { + m.config.Triggers = t +} + +func (m *Sqlite) Trigger(name string, detail interface{}) error { + if m.config.Triggers == nil { + return nil + } + + if trigger, ok := m.config.Triggers[name]; ok { + return trigger(TriggerResponse{ + Driver: m, + Config: m.config, + Trigger: name, + Detail: detail, + }) + } + + return nil +} + func (m *Sqlite) Drop() (err error) { query := `SELECT name FROM sqlite_master WHERE type = 'table';` tables, err := m.db.Query(query) @@ -198,10 +236,26 @@ func (m *Sqlite) Run(migration io.Reader) error { } query := string(migr[:]) + if err := m.Trigger(database.TrigRunPre, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } + if m.config.NoTxWrap { - return m.executeQueryNoTx(query) + if err = m.executeQueryNoTx(query); err != nil { + return err + } + } else if err = m.executeQuery(query); err != nil { + return err } - return m.executeQuery(query) + + if err := m.Trigger(database.TrigRunPost, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } + return nil } func (m *Sqlite) executeQuery(query string) error { @@ -234,6 +288,16 @@ func (m *Sqlite) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Err: "transaction start failed"} } + if err := m.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + query := "DELETE FROM " + m.config.MigrationsTable if _, err := tx.Exec(query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} @@ -252,6 +316,16 @@ func (m *Sqlite) SetVersion(version int, dirty bool) error { } } + if err := m.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + if err := tx.Commit(); err != nil { return &database.Error{OrigErr: err, Err: "transaction commit failed"} } diff --git a/database/sqlserver/sqlserver.go b/database/sqlserver/sqlserver.go index 92834d1ad..9544c86e4 100644 --- a/database/sqlserver/sqlserver.go +++ b/database/sqlserver/sqlserver.go @@ -45,6 +45,8 @@ type Config struct { MigrationsTable string DatabaseName string SchemaName string + + Triggers map[string]func(response interface{}) error } // SQL Server connection @@ -58,6 +60,13 @@ type SQLServer struct { config *Config } +type TriggerResponse struct { + Driver *SQLServer + Config *Config + Trigger string + Detail interface{} +} + // WithInstance returns a database instance from an already created database connection. // // Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver. @@ -190,6 +199,27 @@ func (ss *SQLServer) Close() error { return nil } +func (ss *SQLServer) AddTriggers(t map[string]func(response interface{}) error) { + ss.config.Triggers = t +} + +func (ss *SQLServer) Trigger(name string, detail interface{}) error { + if ss.config.Triggers == nil { + return nil + } + + if trigger, ok := ss.config.Triggers[name]; ok { + return trigger(TriggerResponse{ + Driver: ss, + Config: ss.config, + Trigger: name, + Detail: detail, + }) + } + + return nil +} + // Lock creates an advisory local on the database to prevent multiple migrations from running at the same time. func (ss *SQLServer) Lock() error { return database.CasRestoreOnErr(&ss.isLocked, false, true, database.ErrLocked, func() error { @@ -247,6 +277,11 @@ func (ss *SQLServer) Run(migration io.Reader) error { // run migration query := string(migr[:]) + if err := ss.Trigger(database.TrigRunPre, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } if _, err := ss.conn.ExecContext(context.Background(), query); err != nil { if msErr, ok := err.(mssql.Error); ok { message := fmt.Sprintf("migration failed: %s", msErr.Message) @@ -257,6 +292,11 @@ func (ss *SQLServer) Run(migration io.Reader) error { } return database.Error{OrigErr: err, Err: "migration failed", Query: migr} } + if err := ss.Trigger(database.TrigRunPost, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } return nil } @@ -269,6 +309,16 @@ func (ss *SQLServer) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Err: "transaction start failed"} } + if err := ss.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + query := `TRUNCATE TABLE ` + ss.getMigrationTable() if _, err := tx.Exec(query); err != nil { if errRollback := tx.Rollback(); errRollback != nil { @@ -294,6 +344,16 @@ func (ss *SQLServer) SetVersion(version int, dirty bool) error { } } + if err := ss.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + if err := tx.Commit(); err != nil { return &database.Error{OrigErr: err, Err: "transaction commit failed"} } @@ -368,6 +428,10 @@ func (ss *SQLServer) ensureVersionTable() (err error) { } }() + if err := ss.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + query := `IF NOT EXISTS (SELECT * FROM sysobjects @@ -380,6 +444,10 @@ func (ss *SQLServer) ensureVersionTable() (err error) { return &database.Error{OrigErr: err, Query: []byte(query)} } + if err := ss.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } + return nil } diff --git a/database/stub/stub.go b/database/stub/stub.go index ae502650b..b28017fb0 100644 --- a/database/stub/stub.go +++ b/database/stub/stub.go @@ -34,7 +34,16 @@ func (s *Stub) Open(url string) (database.Driver, error) { }, nil } -type Config struct{} +type Config struct { + Triggers map[string]func(response interface{}) error +} + +type TriggerResponse struct { + Driver *Stub + Config *Config + Trigger string + Detail interface{} +} func WithInstance(instance interface{}, config *Config) (database.Driver, error) { return &Stub{ @@ -49,6 +58,27 @@ func (s *Stub) Close() error { return nil } +func (s *Stub) AddTriggers(t map[string]func(response interface{}) error) { + s.Config.Triggers = t +} + +func (s *Stub) Trigger(name string, detail interface{}) error { + if s.Config.Triggers == nil { + return nil + } + + if trigger, ok := s.Config.Triggers[name]; ok { + return trigger(TriggerResponse{ + Driver: s, + Config: s.Config, + Trigger: name, + Detail: detail, + }) + } + + return nil +} + func (s *Stub) Lock() error { if !s.isLocked.CAS(false, true) { return database.ErrLocked @@ -68,8 +98,22 @@ func (s *Stub) Run(migration io.Reader) error { if err != nil { return err } + + if err := s.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(m[:])}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } + s.LastRunMigration = m s.MigrationSequence = append(s.MigrationSequence, string(m[:])) + + if err := s.Trigger(database.TrigRunPost, struct { + Query string + }{Query: string(m[:])}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } + return nil } diff --git a/database/testing/migrate_testing.go b/database/testing/migrate_testing.go index be8ed195f..43d0cd14e 100644 --- a/database/testing/migrate_testing.go +++ b/database/testing/migrate_testing.go @@ -4,6 +4,8 @@ package testing import ( + "github.com/golang-migrate/migrate/v4/database" + "reflect" "testing" ) @@ -28,7 +30,74 @@ func TestMigrateDrop(t *testing.T, m *migrate.Migrate) { func TestMigrateUp(t *testing.T, m *migrate.Migrate) { t.Log("UP") + + tt := &triggerTest{ + t: t, + m: m, + triggered: map[string]bool{ + migrate.TrigRunMigrationPre: false, + migrate.TrigRunMigrationPost: false, + migrate.TrigRunMigrationVersionPre: false, + migrate.TrigRunMigrationVersionPost: false, + database.TrigRunPre: false, + database.TrigRunPost: false, + }, + } + + m.Triggers = map[string]func(r migrate.TriggerResponse) error{ + migrate.TrigRunMigrationPre: tt.trigMigrationCheck, + migrate.TrigRunMigrationPost: tt.trigMigrationCheck, + migrate.TrigRunMigrationVersionPre: tt.trigMigrationCheck, + migrate.TrigRunMigrationVersionPost: tt.trigMigrationCheck, + } + + m.AddDatabaseTriggers(map[string]func(response interface{}) error{ + database.TrigRunPre: tt.trigDatabaseMigrationCheck, + database.TrigRunPost: tt.trigDatabaseMigrationCheck, + }) + if err := m.Up(); err != nil { t.Fatal(err) } + + if !tt.triggered[migrate.TrigRunMigrationPre] { + t.Fatalf("expected trigger %s to be called, but it was not", migrate.TrigRunMigrationPre) + } + if !tt.triggered[migrate.TrigRunMigrationPost] { + t.Fatalf("expected trigger %s to be called, but it was not", migrate.TrigRunMigrationPost) + } + if !tt.triggered[migrate.TrigRunMigrationVersionPre] { + t.Fatalf("expected trigger %s to be called, but it was not", migrate.TrigRunMigrationVersionPre) + } + if !tt.triggered[migrate.TrigRunMigrationVersionPost] { + t.Fatalf("expected trigger %s to be called, but it was not", migrate.TrigRunMigrationVersionPost) + } + if !tt.triggered[database.TrigRunPre] { + t.Fatalf("expected database trigger %s to be called, but it was not", database.TrigRunPre) + } + if !tt.triggered[database.TrigRunPost] { + t.Fatalf("expected database trigger %s to be called, but it was not", database.TrigRunPost) + } +} + +type triggerTest struct { + t *testing.T + m *migrate.Migrate + triggered map[string]bool +} + +func (tt *triggerTest) trigMigrationCheck(r migrate.TriggerResponse) error { + tt.triggered[r.Trigger] = true + return nil +} + +func (tt *triggerTest) trigDatabaseMigrationCheck(response interface{}) error { + val := reflect.ValueOf(response) + field := val.FieldByName("Trigger") + if !field.IsValid() { + tt.t.Fatalf("expected response to have a Trigger field, got %T", response) + } + + tt.triggered[field.String()] = true + return nil } diff --git a/database/yugabytedb/yugabytedb.go b/database/yugabytedb/yugabytedb.go index 764d23c02..3a507292c 100644 --- a/database/yugabytedb/yugabytedb.go +++ b/database/yugabytedb/yugabytedb.go @@ -49,6 +49,8 @@ type Config struct { MaxRetryInterval time.Duration MaxRetryElapsedTime time.Duration MaxRetries int + + Triggers map[string]func(response interface{}) error } type YugabyteDB struct { @@ -59,6 +61,13 @@ type YugabyteDB struct { config *Config } +type TriggerResponse struct { + Driver *YugabyteDB + Config *Config + Trigger string + Detail interface{} +} + func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -189,6 +198,27 @@ func (c *YugabyteDB) Close() error { return c.db.Close() } +func (c *YugabyteDB) AddTriggers(t map[string]func(response interface{}) error) { + c.config.Triggers = t +} + +func (c *YugabyteDB) Trigger(name string, detail interface{}) error { + if c.config.Triggers == nil { + return nil + } + + if trigger, ok := c.config.Triggers[name]; ok { + return trigger(TriggerResponse{ + Driver: c, + Config: c.config, + Trigger: name, + Detail: detail, + }) + } + + return nil +} + // Locking is done manually with a separate lock table. Implementing advisory locks in YugabyteDB is being discussed // See: https://github.com/yugabyte/yugabyte-db/issues/3642 func (c *YugabyteDB) Lock() error { @@ -263,15 +293,32 @@ func (c *YugabyteDB) Run(migration io.Reader) error { // run migration query := string(migr[:]) + if err := c.Trigger(database.TrigRunPre, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } if _, err := c.db.Exec(query); err != nil { return database.Error{OrigErr: err, Err: "migration failed", Query: migr} } + if err := c.Trigger(database.TrigRunPost, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } return nil } func (c *YugabyteDB) SetVersion(version int, dirty bool) error { return c.doTxWithRetry(context.Background(), &sql.TxOptions{Isolation: sql.LevelSerializable}, func(tx *sql.Tx) error { + if err := c.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + if _, err := tx.Exec(`DELETE FROM "` + c.config.MigrationsTable + `"`); err != nil { return err } @@ -285,6 +332,13 @@ func (c *YugabyteDB) SetVersion(version int, dirty bool) error { } } + if err := c.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + return nil }) } @@ -375,14 +429,25 @@ func (c *YugabyteDB) ensureVersionTable() (err error) { return &database.Error{OrigErr: err, Query: []byte(query)} } if count == 1 { + if err := c.Trigger(database.TrigVersionTableExists, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTableExists"} + } return nil } + if err := c.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + // if not, create the empty migration table query = `CREATE TABLE "` + c.config.MigrationsTable + `" (version INT NOT NULL PRIMARY KEY, dirty BOOL NOT NULL)` if _, err := c.db.Exec(query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } + + if err := c.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } return nil } diff --git a/migrate.go b/migrate.go index 266cc04eb..948e775c3 100644 --- a/migrate.go +++ b/migrate.go @@ -36,6 +36,11 @@ var ( ErrLockTimeout = errors.New("timeout: can't acquire database lock") ) +const TrigRunMigrationPre = "RunMigrationPre" +const TrigRunMigrationPost = "RunMigrationPost" +const TrigRunMigrationVersionPre = "RunMigrationVersionPre" +const TrigRunMigrationVersionPost = "RunMigrationVersionPost" + // ErrShortLimit is an error returned when not enough migrations // can be returned by a source for a given limit. type ErrShortLimit struct { @@ -80,64 +85,108 @@ type Migrate struct { // LockTimeout defaults to DefaultLockTimeout, // but can be set per Migrate instance. LockTimeout time.Duration + + Triggers map[string]func(response TriggerResponse) error } -// New returns a new Migrate instance from a source URL and a database URL. -// The URL scheme is defined by each driver. -func New(sourceURL, databaseURL string) (*Migrate, error) { +type Options struct { + // Source from URL + // The URL scheme is defined by each driver. + SourceURL string + + // Source from Instance + // Use any string that can serve as an identifier during logging as sourceName. + // You are responsible for closing down the underlying source if necessary. + SourceName string + SourceInstance source.Driver + + // Database from URL + // The URL scheme is defined by each driver. + DatabaseURL string + + // Database from Instance + // Use any string that can serve as an identifier during logging as databaseName. + // You are responsible for closing the underlying database client if necessary. + // You will also need to setup your own triggers if needed. + DatabaseName string + DatabaseInstance database.Driver + + // Triggers - these can be used to execute arbitrary code to meet any additional + // requirements that may be needed (i.e. some people need a history of migrations) + MigrateTriggers map[string]func(response TriggerResponse) error + DatabaseTriggers map[string]func(response interface{}) error +} + +type TriggerResponse struct { + Trigger string + Detail interface{} +} + +// NewFromOptions returns a new Migrate instance from the options provided. +func NewFromOptions(o Options) (*Migrate, error) { m := newCommon() - sourceName, err := iurl.SchemeFromURL(sourceURL) - if err != nil { - return nil, fmt.Errorf("failed to parse scheme from source URL: %w", err) - } - m.sourceName = sourceName + if o.SourceURL != "" { + sourceName, err := iurl.SchemeFromURL(o.SourceURL) + if err != nil { + return nil, fmt.Errorf("failed to parse scheme from source URL: %w", err) + } + m.sourceName = sourceName - databaseName, err := iurl.SchemeFromURL(databaseURL) - if err != nil { - return nil, fmt.Errorf("failed to parse scheme from database URL: %w", err) + sourceDrv, err := source.Open(o.SourceURL) + if err != nil { + return nil, fmt.Errorf("failed to open source, %q: %w", o.SourceURL, err) + } + m.sourceDrv = sourceDrv + } else if o.SourceName != "" && o.SourceInstance != nil { + m.sourceName = o.SourceName + m.sourceDrv = o.SourceInstance + } else { + return nil, fmt.Errorf("must specify either SourceURL or SourceName and SourceInstance") } - m.databaseName = databaseName - sourceDrv, err := source.Open(sourceURL) - if err != nil { - return nil, fmt.Errorf("failed to open source, %q: %w", sourceURL, err) - } - m.sourceDrv = sourceDrv + if o.DatabaseURL != "" { + databaseName, err := iurl.SchemeFromURL(o.DatabaseURL) + if err != nil { + return nil, fmt.Errorf("failed to parse scheme from database URL: %w", err) + } + m.databaseName = databaseName - databaseDrv, err := database.Open(databaseURL) - if err != nil { - return nil, fmt.Errorf("failed to open database: %w", err) + databaseDrv, err := database.Open(o.DatabaseURL) + if err != nil { + return nil, fmt.Errorf("failed to open database: %w", err) + } + m.databaseDrv = databaseDrv + m.databaseDrv.AddTriggers(o.DatabaseTriggers) + } else if o.DatabaseInstance != nil { + m.databaseName = o.DatabaseName + m.databaseDrv = o.DatabaseInstance } - m.databaseDrv = databaseDrv + + m.Triggers = o.MigrateTriggers return m, nil } +// New returns a new Migrate instance from a source URL and a database URL. +// The URL scheme is defined by each driver. +func New(sourceURL, databaseURL string) (*Migrate, error) { + return NewFromOptions(Options{ + SourceURL: sourceURL, + DatabaseURL: databaseURL, + }) +} + // NewWithDatabaseInstance returns a new Migrate instance from a source URL // and an existing database instance. The source URL scheme is defined by each driver. // Use any string that can serve as an identifier during logging as databaseName. // You are responsible for closing the underlying database client if necessary. func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInstance database.Driver) (*Migrate, error) { - m := newCommon() - - sourceName, err := iurl.SchemeFromURL(sourceURL) - if err != nil { - return nil, err - } - m.sourceName = sourceName - - m.databaseName = databaseName - - sourceDrv, err := source.Open(sourceURL) - if err != nil { - return nil, fmt.Errorf("failed to open source, %q: %w", sourceURL, err) - } - m.sourceDrv = sourceDrv - - m.databaseDrv = databaseInstance - - return m, nil + return NewFromOptions(Options{ + SourceURL: sourceURL, + DatabaseName: databaseName, + DatabaseInstance: databaseInstance, + }) } // NewWithSourceInstance returns a new Migrate instance from an existing source instance @@ -145,25 +194,11 @@ func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInst // Use any string that can serve as an identifier during logging as sourceName. // You are responsible for closing the underlying source client if necessary. func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, databaseURL string) (*Migrate, error) { - m := newCommon() - - databaseName, err := iurl.SchemeFromURL(databaseURL) - if err != nil { - return nil, fmt.Errorf("failed to parse scheme from database URL: %w", err) - } - m.databaseName = databaseName - - m.sourceName = sourceName - - databaseDrv, err := database.Open(databaseURL) - if err != nil { - return nil, fmt.Errorf("failed to open database: %w", err) - } - m.databaseDrv = databaseDrv - - m.sourceDrv = sourceInstance - - return m, nil + return NewFromOptions(Options{ + SourceName: sourceName, + SourceInstance: sourceInstance, + DatabaseURL: databaseURL, + }) } // NewWithInstance returns a new Migrate instance from an existing source and @@ -171,15 +206,12 @@ func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, data // as sourceName and databaseName. You are responsible for closing down // the underlying source and database client if necessary. func NewWithInstance(sourceName string, sourceInstance source.Driver, databaseName string, databaseInstance database.Driver) (*Migrate, error) { - m := newCommon() - - m.sourceName = sourceName - m.databaseName = databaseName - - m.sourceDrv = sourceInstance - m.databaseDrv = databaseInstance - - return m, nil + return NewFromOptions(Options{ + SourceName: sourceName, + SourceInstance: sourceInstance, + DatabaseName: databaseName, + DatabaseInstance: databaseInstance, + }) } func newCommon() *Migrate { @@ -191,6 +223,25 @@ func newCommon() *Migrate { } } +func (m *Migrate) Trigger(name string, detail interface{}) error { + if m.Triggers == nil { + return nil + } + + if trigger, ok := m.Triggers[name]; ok { + return trigger(TriggerResponse{ + Trigger: name, + Detail: detail, + }) + } + + return nil +} + +func (m *Migrate) AddDatabaseTriggers(t map[string]func(response interface{}) error) { + m.databaseDrv.AddTriggers(t) +} + // Close closes the source and the database. func (m *Migrate) Close() (source error, database error) { databaseSrvClose := make(chan error) @@ -723,6 +774,10 @@ func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) { // to stop execution because it might have received a stop signal on the // GracefulStop channel. func (m *Migrate) runMigrations(ret <-chan interface{}) error { + if err := m.Trigger(TrigRunMigrationPre, nil); err != nil { + return err + } + for r := range ret { if m.stop() { @@ -742,10 +797,18 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error { } if migr.Body != nil { + if err := m.Trigger(TrigRunMigrationVersionPre, struct{ Migration *Migration }{migr}); err != nil { + return err + } + m.logVerbosePrintf("Read and execute %v\n", migr.LogString()) if err := m.databaseDrv.Run(migr.BufferedBody); err != nil { return err } + + if err := m.Trigger(TrigRunMigrationVersionPost, struct{ Migration *Migration }{migr}); err != nil { + return err + } } // set clean state @@ -770,6 +833,11 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error { return fmt.Errorf("unknown type: %T with value: %+v", r, r) } } + + if err := m.Trigger(TrigRunMigrationPost, nil); err != nil { + return err + } + return nil }