Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 38 additions & 11 deletions sqlutils/sqlutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (
"sync"
"time"

_ "github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3"
"github.com/openark/golib/log"
)

Expand Down Expand Up @@ -170,11 +172,18 @@ func (this *RowMap) GetTime(key string) time.Time {

// knownDBs is a DB cache by uri
var knownDBs map[string]*sql.DB = make(map[string]*sql.DB)
var knownDBsMutex = &sync.Mutex{}
var knownDBsMutex = &sync.RWMutex{}

type Logger interface {
OnError(context string, query string, err error) error
}
// it is also protected by knownDBsMutex
var DB2logger map[*sql.DB]Logger = make(map[*sql.DB]Logger)

// GetDB returns a DB instance based on uri.
// logger parameter is optional. If nil, internal logging will be used.
// bool result indicates whether the DB was returned from cache; err
func GetGenericDB(driverName, dataSourceName string) (*sql.DB, bool, error) {
func GetGenericDB(driverName, dataSourceName string ,logger Logger) (*sql.DB, bool, error) {
knownDBsMutex.Lock()
defer func() {
knownDBsMutex.Unlock()
Expand All @@ -188,19 +197,23 @@ func GetGenericDB(driverName, dataSourceName string) (*sql.DB, bool, error) {
return db, exists, err
}
}
return knownDBs[dataSourceName], exists, nil
db := knownDBs[dataSourceName]
DB2logger[db] = logger
return db, exists, nil
}

// GetDB returns a MySQL DB instance based on uri.
// logger parameter is optional. If nil, internal logging will be used.
// bool result indicates whether the DB was returned from cache; err
func GetDB(mysql_uri string) (*sql.DB, bool, error) {
return GetGenericDB("mysql", mysql_uri)
func GetDB(mysql_uri string, logger Logger) (*sql.DB, bool, error) {
return GetGenericDB("mysql", mysql_uri, logger)
}

// GetDB returns a SQLite DB instance based on DB file name.
// logger parameter is optional. If nil, internal logging will be used.
// bool result indicates whether the DB was returned from cache; err
func GetSQLiteDB(dbFile string) (*sql.DB, bool, error) {
return GetGenericDB("sqlite3", dbFile)
func GetSQLiteDB(dbFile string, logger Logger) (*sql.DB, bool, error) {
return GetGenericDB("sqlite3", dbFile, logger)
}

// RowToArray is a convenience function, typically not called directly, which maps a
Expand Down Expand Up @@ -252,6 +265,20 @@ func ScanRowsToMaps(rows *sql.Rows, on_row func(RowMap) error) error {
return err
}

func logErrorInternal(context string, db *sql.DB, query string, err error) error{
// find logger registered by the client
knownDBsMutex.RLock()
defer func() {
knownDBsMutex.RUnlock()
}()

if logger, exists := DB2logger[db]; exists && logger != nil {
return logger.OnError(context, query, err)
}

return log.Errore(err)
}

// QueryRowsMap is a convenience function allowing querying a result set while poviding a callback
// function activated per read row.
func QueryRowsMap(db *sql.DB, query string, on_row func(RowMap) error, args ...interface{}) (err error) {
Expand All @@ -267,7 +294,7 @@ func QueryRowsMap(db *sql.DB, query string, on_row func(RowMap) error, args ...i
defer rows.Close()
}
if err != nil && err != sql.ErrNoRows {
return log.Errore(err)
return logErrorInternal("QueryRowsMap", db, query, err)
}
err = ScanRowsToMaps(rows, on_row)
return
Expand All @@ -285,7 +312,7 @@ func queryResultData(db *sql.DB, query string, retrieveColumns bool, args ...int
rows, err = db.Query(query, args...)
defer rows.Close()
if err != nil && err != sql.ErrNoRows {
return EmptyResultData, columns, err
return EmptyResultData, columns, logErrorInternal("queryResultData", db, query, err)
}
if retrieveColumns {
// Don't pay if you don't want to
Expand Down Expand Up @@ -339,7 +366,7 @@ func ExecNoPrepare(db *sql.DB, query string, args ...interface{}) (res sql.Resul

res, err = db.Exec(query, args...)
if err != nil {
log.Errore(err)
logErrorInternal("ExecNoPrepare", db, query, err)
}
return res, err
}
Expand All @@ -360,7 +387,7 @@ func execInternal(silent bool, db *sql.DB, query string, args ...interface{}) (r
defer stmt.Close()
res, err = stmt.Exec(args...)
if err != nil && !silent {
log.Errore(err)
logErrorInternal("execInternal", db, query, err)
}
return res, err
}
Expand Down