Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions database/cassandra/cassandra.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ type Config struct {
KeyspaceName string
MultiStatementEnabled bool
MultiStatementMaxSize int

Triggers map[string]func(response interface{}) error
}

type TriggerResponse struct {
Driver *Cassandra
Config *Config
Trigger string
Detail interface{}
}

type Cassandra struct {
Expand Down Expand Up @@ -198,6 +207,27 @@ func (c *Cassandra) Close() error {
return nil
}

func (c *Cassandra) AddTriggers(t map[string]func(response interface{}) error) {
c.config.Triggers = t
}

func (c *Cassandra) Trigger(name string, detail interface{}) error {
if c.config.Triggers == nil {
return nil
}

if trigger, ok := c.config.Triggers[name]; ok {
return trigger(TriggerResponse{
Driver: c,
Config: c.config,
Trigger: name,
Detail: detail,
})
}

return nil
}

func (c *Cassandra) Lock() error {
if !c.isLocked.CAS(false, true) {
return database.ErrLocked
Expand All @@ -220,10 +250,22 @@ func (c *Cassandra) Run(migration io.Reader) error {
if tq == "" {
return true
}
if e := c.Trigger(database.TrigRunPre, struct {
Query string
}{Query: tq}); e != nil {
err = database.Error{OrigErr: e, Err: "failed to trigger RunPre"}
return false
}
if e := c.session.Query(tq).Exec(); e != nil {
err = database.Error{OrigErr: e, Err: "migration failed", Query: m}
return false
}
if e := c.Trigger(database.TrigRunPost, struct {
Query string
}{Query: tq}); e != nil {
err = database.Error{OrigErr: e, Err: "failed to trigger RunPost"}
return false
}
return true
}); e != nil {
return e
Expand All @@ -235,15 +277,32 @@ func (c *Cassandra) Run(migration io.Reader) error {
if err != nil {
return err
}
if err := c.Trigger(database.TrigRunPre, struct {
Query string
}{Query: string(migr)}); err != nil {
return database.Error{OrigErr: err, Err: "failed to trigger RunPre"}
}
// run migration
if err := c.session.Query(string(migr)).Exec(); err != nil {
// TODO: cast to Cassandra error and get line number
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
}
if err := c.Trigger(database.TrigRunPost, struct {
Query string
}{Query: string(migr)}); err != nil {
return database.Error{OrigErr: err, Err: "failed to trigger RunPost"}
}
return nil
}

func (c *Cassandra) SetVersion(version int, dirty bool) error {
if err := c.Trigger(database.TrigSetVersionPre, struct {
Version int
Dirty bool
}{Version: version, Dirty: dirty}); err != nil {
return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"}
}

// DELETE instead of TRUNCATE because AWS Keyspaces does not support it
// see: https://docs.aws.amazon.com/keyspaces/latest/devguide/cassandra-apis.html
squery := `SELECT version FROM "` + c.config.MigrationsTable + `"`
Expand All @@ -269,6 +328,13 @@ func (c *Cassandra) SetVersion(version int, dirty bool) error {
}
}

if err := c.Trigger(database.TrigSetVersionPost, struct {
Version int
Dirty bool
}{Version: version, Dirty: dirty}); err != nil {
return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"}
}

return nil
}

Expand Down Expand Up @@ -324,13 +390,22 @@ func (c *Cassandra) ensureVersionTable() (err error) {
}
}()

if err := c.Trigger(database.TrigVersionTablePre, nil); err != nil {
return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"}
}

err = c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec()
if err != nil {
return err
}
if _, _, err = c.Version(); err != nil {
return err
}

if err := c.Trigger(database.TrigVersionTablePost, nil); err != nil {
return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"}
}

return nil
}

Expand Down
84 changes: 82 additions & 2 deletions database/clickhouse/clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ type Config struct {
MigrationsTableEngine string
MultiStatementEnabled bool
MultiStatementMaxSize int

Triggers map[string]func(response interface{}) error
}

func init() {
Expand Down Expand Up @@ -67,6 +69,13 @@ type ClickHouse struct {
isLocked atomic.Bool
}

type TriggerResponse struct {
Driver *ClickHouse
Config *Config
Trigger string
Detail interface{}
}

func (ch *ClickHouse) Open(dsn string) (database.Driver, error) {
purl, err := url.Parse(dsn)
if err != nil {
Expand Down Expand Up @@ -141,10 +150,22 @@ func (ch *ClickHouse) Run(r io.Reader) error {
if tq == "" {
return true
}
if e := ch.Trigger(database.TrigRunPre, struct {
Query string
}{Query: tq}); e != nil {
err = database.Error{OrigErr: e, Err: "failed to trigger RunPre"}
return false
}
if _, e := ch.conn.Exec(string(m)); e != nil {
err = database.Error{OrigErr: e, Err: "migration failed", Query: m}
return false
}
if e := ch.Trigger(database.TrigRunPost, struct {
Query string
}{Query: tq}); e != nil {
err = database.Error{OrigErr: e, Err: "failed to trigger RunPost"}
return false
}
return true
}); e != nil {
return e
Expand All @@ -157,10 +178,22 @@ func (ch *ClickHouse) Run(r io.Reader) error {
return err
}

if err := ch.Trigger(database.TrigRunPre, struct {
Query string
}{Query: string(migration)}); err != nil {
return database.Error{OrigErr: err, Err: "failed to trigger RunPre"}
}

if _, err := ch.conn.Exec(string(migration)); err != nil {
return database.Error{OrigErr: err, Err: "migration failed", Query: migration}
}

if err := ch.Trigger(database.TrigRunPost, struct {
Query string
}{Query: string(migration)}); err != nil {
return database.Error{OrigErr: err, Err: "failed to trigger RunPost"}
}

return nil
}
func (ch *ClickHouse) Version() (int, bool, error) {
Expand All @@ -180,7 +213,7 @@ func (ch *ClickHouse) Version() (int, bool, error) {

func (ch *ClickHouse) SetVersion(version int, dirty bool) error {
var (
bool = func(v bool) uint8 {
booln = func(v bool) uint8 {
if v {
return 1
}
Expand All @@ -192,11 +225,25 @@ func (ch *ClickHouse) SetVersion(version int, dirty bool) error {
return err
}

if err := ch.Trigger(database.TrigSetVersionPre, struct {
Version int
Dirty bool
}{Version: version, Dirty: dirty}); err != nil {
return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"}
}

query := "INSERT INTO " + ch.config.MigrationsTable + " (version, dirty, sequence) VALUES (?, ?, ?)"
if _, err := tx.Exec(query, version, bool(dirty), time.Now().UnixNano()); err != nil {
if _, err := tx.Exec(query, version, booln(dirty), time.Now().UnixNano()); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}

if err := ch.Trigger(database.TrigSetVersionPost, struct {
Version int
Dirty bool
}{Version: version, Dirty: dirty}); err != nil {
return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"}
}

return tx.Commit()
}

Expand Down Expand Up @@ -228,9 +275,16 @@ func (ch *ClickHouse) ensureVersionTable() (err error) {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
} else {
if err := ch.Trigger(database.TrigVersionTableExists, nil); err != nil {
return &database.Error{OrigErr: err, Err: "failed to trigger VersionTableExists"}
}
return nil
}

if err := ch.Trigger(database.TrigVersionTablePre, nil); err != nil {
return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"}
}

// if not, create the empty migration table
if len(ch.config.ClusterName) > 0 {
query = fmt.Sprintf(`
Expand All @@ -255,6 +309,11 @@ func (ch *ClickHouse) ensureVersionTable() (err error) {
if _, err := ch.conn.Exec(query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}

if err := ch.Trigger(database.TrigVersionTablePost, nil); err != nil {
return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"}
}

return nil
}

Expand Down Expand Up @@ -306,6 +365,27 @@ func (ch *ClickHouse) Unlock() error {
}
func (ch *ClickHouse) Close() error { return ch.conn.Close() }

func (ch *ClickHouse) AddTriggers(t map[string]func(response interface{}) error) {
ch.config.Triggers = t
}

func (ch *ClickHouse) Trigger(name string, detail interface{}) error {
if ch.config.Triggers == nil {
return nil
}

if trigger, ok := ch.config.Triggers[name]; ok {
return trigger(TriggerResponse{
Driver: ch,
Config: ch.config,
Trigger: name,
Detail: detail,
})
}

return nil
}

// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
func quoteIdentifier(name string) string {
end := strings.IndexRune(name, 0)
Expand Down
Loading