diff --git a/sqlutils/sqlutils.go b/sqlutils/sqlutils.go index 0a2eda2..6b6a43d 100644 --- a/sqlutils/sqlutils.go +++ b/sqlutils/sqlutils.go @@ -26,6 +26,8 @@ import ( "sync" "time" + _ "github.com/go-sql-driver/mysql" + _ "github.com/mattn/go-sqlite3" "github.com/openark/golib/log" ) @@ -170,11 +172,18 @@ func (this *RowMap) GetTime(key string) time.Time { // knownDBs is a DB cache by uri var knownDBs map[string]*sql.DB = make(map[string]*sql.DB) -var knownDBsMutex = &sync.Mutex{} +var knownDBsMutex = &sync.RWMutex{} + +type Logger interface { + OnError(context string, query string, err error) error +} +// it is also protected by knownDBsMutex +var DB2logger map[*sql.DB]Logger = make(map[*sql.DB]Logger) // GetDB returns a DB instance based on uri. +// logger parameter is optional. If nil, internal logging will be used. // bool result indicates whether the DB was returned from cache; err -func GetGenericDB(driverName, dataSourceName string) (*sql.DB, bool, error) { +func GetGenericDB(driverName, dataSourceName string ,logger Logger) (*sql.DB, bool, error) { knownDBsMutex.Lock() defer func() { knownDBsMutex.Unlock() @@ -188,19 +197,23 @@ func GetGenericDB(driverName, dataSourceName string) (*sql.DB, bool, error) { return db, exists, err } } - return knownDBs[dataSourceName], exists, nil + db := knownDBs[dataSourceName] + DB2logger[db] = logger + return db, exists, nil } // GetDB returns a MySQL DB instance based on uri. +// logger parameter is optional. If nil, internal logging will be used. // bool result indicates whether the DB was returned from cache; err -func GetDB(mysql_uri string) (*sql.DB, bool, error) { - return GetGenericDB("mysql", mysql_uri) +func GetDB(mysql_uri string, logger Logger) (*sql.DB, bool, error) { + return GetGenericDB("mysql", mysql_uri, logger) } // GetDB returns a SQLite DB instance based on DB file name. +// logger parameter is optional. If nil, internal logging will be used. // bool result indicates whether the DB was returned from cache; err -func GetSQLiteDB(dbFile string) (*sql.DB, bool, error) { - return GetGenericDB("sqlite3", dbFile) +func GetSQLiteDB(dbFile string, logger Logger) (*sql.DB, bool, error) { + return GetGenericDB("sqlite3", dbFile, logger) } // RowToArray is a convenience function, typically not called directly, which maps a @@ -252,6 +265,20 @@ func ScanRowsToMaps(rows *sql.Rows, on_row func(RowMap) error) error { return err } +func logErrorInternal(context string, db *sql.DB, query string, err error) error{ + // find logger registered by the client + knownDBsMutex.RLock() + defer func() { + knownDBsMutex.RUnlock() + }() + + if logger, exists := DB2logger[db]; exists && logger != nil { + return logger.OnError(context, query, err) + } + + return log.Errore(err) +} + // QueryRowsMap is a convenience function allowing querying a result set while poviding a callback // function activated per read row. func QueryRowsMap(db *sql.DB, query string, on_row func(RowMap) error, args ...interface{}) (err error) { @@ -267,7 +294,7 @@ func QueryRowsMap(db *sql.DB, query string, on_row func(RowMap) error, args ...i defer rows.Close() } if err != nil && err != sql.ErrNoRows { - return log.Errore(err) + return logErrorInternal("QueryRowsMap", db, query, err) } err = ScanRowsToMaps(rows, on_row) return @@ -285,7 +312,7 @@ func queryResultData(db *sql.DB, query string, retrieveColumns bool, args ...int rows, err = db.Query(query, args...) defer rows.Close() if err != nil && err != sql.ErrNoRows { - return EmptyResultData, columns, err + return EmptyResultData, columns, logErrorInternal("queryResultData", db, query, err) } if retrieveColumns { // Don't pay if you don't want to @@ -339,7 +366,7 @@ func ExecNoPrepare(db *sql.DB, query string, args ...interface{}) (res sql.Resul res, err = db.Exec(query, args...) if err != nil { - log.Errore(err) + logErrorInternal("ExecNoPrepare", db, query, err) } return res, err } @@ -360,7 +387,7 @@ func execInternal(silent bool, db *sql.DB, query string, args ...interface{}) (r defer stmt.Close() res, err = stmt.Exec(args...) if err != nil && !silent { - log.Errore(err) + logErrorInternal("execInternal", db, query, err) } return res, err }