Skip to content

Commit 9b368a7

Browse files
author
Shlomi Noach
committed
Merge pull request #8 from github/sql-queries-manipulations
merging so I can use this on other branches
2 parents 20a74d5 + 39ebc75 commit 9b368a7

File tree

3 files changed

+309
-0
lines changed

3 files changed

+309
-0
lines changed

go/base/context.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
/*
2+
Copyright 2016 GitHub Inc.
3+
See https://github.com/github/gh-osc/blob/master/LICENSE
4+
*/
5+
6+
package base
7+
8+
import ()
9+
10+
type MigrationContext struct {
11+
DatabaseName string
12+
OriginalTableName string
13+
GhostTableName string
14+
}

go/sql/builder.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
/*
2+
Copyright 2016 GitHub Inc.
3+
See https://github.com/github/gh-osc/blob/master/LICENSE
4+
*/
5+
6+
package sql
7+
8+
import (
9+
"fmt"
10+
"strconv"
11+
"strings"
12+
)
13+
14+
type ValueComparisonSign string
15+
16+
const (
17+
LessThanComparisonSign ValueComparisonSign = "<"
18+
LessThanOrEqualsComparisonSign = "<="
19+
EqualsComparisonSign = "="
20+
GreaterThanOrEqualsComparisonSign = ">="
21+
GreaterThanComparisonSign = ">"
22+
NotEqualsComparisonSign = "!="
23+
)
24+
25+
// EscapeName will escape a db/table/column/... name by wrapping with backticks.
26+
// It is not fool proof. I'm just trying to do the right thing here, not solving
27+
// SQL injection issues, which should be irrelevant for this tool.
28+
func EscapeName(name string) string {
29+
if unquoted, err := strconv.Unquote(name); err == nil {
30+
name = unquoted
31+
}
32+
return fmt.Sprintf("`%s`", name)
33+
}
34+
35+
func BuildValueComparison(column string, value string, comparisonSign ValueComparisonSign) (result string, err error) {
36+
if column == "" {
37+
return "", fmt.Errorf("Empty column in GetValueComparison")
38+
}
39+
if value == "" {
40+
return "", fmt.Errorf("Empty value in GetValueComparison")
41+
}
42+
comparison := fmt.Sprintf("(%s %s %s)", EscapeName(column), string(comparisonSign), value)
43+
return comparison, err
44+
}
45+
46+
func BuildEqualsComparison(columns []string, values []string) (result string, err error) {
47+
if len(columns) == 0 {
48+
return "", fmt.Errorf("Got 0 columns in GetEqualsComparison")
49+
}
50+
if len(columns) != len(values) {
51+
return "", fmt.Errorf("Got %d columns but %d values in GetEqualsComparison", len(columns), len(values))
52+
}
53+
comparisons := []string{}
54+
for i, column := range columns {
55+
value := values[i]
56+
comparison, err := BuildValueComparison(column, value, EqualsComparisonSign)
57+
if err != nil {
58+
return "", err
59+
}
60+
comparisons = append(comparisons, comparison)
61+
}
62+
result = strings.Join(comparisons, " and ")
63+
result = fmt.Sprintf("(%s)", result)
64+
return result, nil
65+
}
66+
67+
func BuildRangeComparison(columns []string, values []string, comparisonSign ValueComparisonSign) (result string, err error) {
68+
if len(columns) == 0 {
69+
return "", fmt.Errorf("Got 0 columns in GetRangeComparison")
70+
}
71+
if len(columns) != len(values) {
72+
return "", fmt.Errorf("Got %d columns but %d values in GetEqualsComparison", len(columns), len(values))
73+
}
74+
includeEquals := false
75+
if comparisonSign == LessThanOrEqualsComparisonSign {
76+
comparisonSign = LessThanComparisonSign
77+
includeEquals = true
78+
}
79+
if comparisonSign == GreaterThanOrEqualsComparisonSign {
80+
comparisonSign = GreaterThanComparisonSign
81+
includeEquals = true
82+
}
83+
comparisons := []string{}
84+
85+
for i, column := range columns {
86+
//
87+
value := values[i]
88+
rangeComparison, err := BuildValueComparison(column, value, comparisonSign)
89+
if err != nil {
90+
return "", err
91+
}
92+
if len(columns[0:i]) > 0 {
93+
equalitiesComparison, err := BuildEqualsComparison(columns[0:i], values[0:i])
94+
if err != nil {
95+
return "", err
96+
}
97+
comparison := fmt.Sprintf("(%s AND %s)", equalitiesComparison, rangeComparison)
98+
comparisons = append(comparisons, comparison)
99+
} else {
100+
comparisons = append(comparisons, rangeComparison)
101+
}
102+
}
103+
104+
if includeEquals {
105+
comparison, err := BuildEqualsComparison(columns, values)
106+
if err != nil {
107+
return "", nil
108+
}
109+
comparisons = append(comparisons, comparison)
110+
}
111+
result = strings.Join(comparisons, " or ")
112+
result = fmt.Sprintf("(%s)", result)
113+
return result, nil
114+
}
115+
116+
func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, uniqueKey string, uniqueKeyColumns, rangeStartValues, rangeEndValues []string) (string, error) {
117+
if len(sharedColumns) == 0 {
118+
return "", fmt.Errorf("Got 0 shared columns in BuildRangeInsertQuery")
119+
}
120+
sharedColumnsListing := strings.Join(sharedColumns, ", ")
121+
rangeStartComparison, err := BuildRangeComparison(uniqueKeyColumns, rangeStartValues, GreaterThanOrEqualsComparisonSign)
122+
if err != nil {
123+
return "", err
124+
}
125+
rangeEndComparison, err := BuildRangeComparison(uniqueKeyColumns, rangeEndValues, LessThanOrEqualsComparisonSign)
126+
if err != nil {
127+
return "", err
128+
}
129+
query := fmt.Sprintf(`
130+
insert /* gh-osc %s.%s */ ignore into %s.%s (%s)
131+
(select %s from %s.%s force index (%s)
132+
where (%s and %s)
133+
)
134+
`, databaseName, originalTableName, databaseName, ghostTableName, sharedColumnsListing,
135+
sharedColumnsListing, databaseName, originalTableName, uniqueKey,
136+
rangeStartComparison, rangeEndComparison)
137+
return query, nil
138+
}

go/sql/builder_test.go

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
/*
2+
Copyright 2016 GitHub Inc.
3+
See https://github.com/github/gh-osc/blob/master/LICENSE
4+
*/
5+
6+
package sql
7+
8+
import (
9+
"testing"
10+
11+
"regexp"
12+
"strings"
13+
14+
"github.com/outbrain/golib/log"
15+
test "github.com/outbrain/golib/tests"
16+
)
17+
18+
var (
19+
spacesRegexp = regexp.MustCompile(`[ \t\n\r]+`)
20+
)
21+
22+
func init() {
23+
log.SetLevel(log.ERROR)
24+
}
25+
26+
func normalizeQuery(name string) string {
27+
name = strings.Replace(name, "`", "", -1)
28+
name = spacesRegexp.ReplaceAllString(name, " ")
29+
name = strings.TrimSpace(name)
30+
return name
31+
}
32+
33+
func TestEscapeName(t *testing.T) {
34+
names := []string{"my_table", `"my_table"`, "`my_table`"}
35+
for _, name := range names {
36+
escaped := EscapeName(name)
37+
test.S(t).ExpectEquals(escaped, "`my_table`")
38+
}
39+
}
40+
41+
func TestBuildEqualsComparison(t *testing.T) {
42+
{
43+
columns := []string{"c1"}
44+
values := []string{"@v1"}
45+
comparison, err := BuildEqualsComparison(columns, values)
46+
test.S(t).ExpectNil(err)
47+
test.S(t).ExpectEquals(comparison, "((`c1` = @v1))")
48+
}
49+
{
50+
columns := []string{"c1", "c2"}
51+
values := []string{"@v1", "@v2"}
52+
comparison, err := BuildEqualsComparison(columns, values)
53+
test.S(t).ExpectNil(err)
54+
test.S(t).ExpectEquals(comparison, "((`c1` = @v1) and (`c2` = @v2))")
55+
}
56+
{
57+
columns := []string{"c1"}
58+
values := []string{"@v1", "@v2"}
59+
_, err := BuildEqualsComparison(columns, values)
60+
test.S(t).ExpectNotNil(err)
61+
}
62+
{
63+
columns := []string{}
64+
values := []string{}
65+
_, err := BuildEqualsComparison(columns, values)
66+
test.S(t).ExpectNotNil(err)
67+
}
68+
}
69+
70+
func TestBuildRangeComparison(t *testing.T) {
71+
{
72+
columns := []string{"c1"}
73+
values := []string{"@v1"}
74+
comparison, err := BuildRangeComparison(columns, values, LessThanComparisonSign)
75+
test.S(t).ExpectNil(err)
76+
test.S(t).ExpectEquals(comparison, "((`c1` < @v1))")
77+
}
78+
{
79+
columns := []string{"c1"}
80+
values := []string{"@v1"}
81+
comparison, err := BuildRangeComparison(columns, values, LessThanOrEqualsComparisonSign)
82+
test.S(t).ExpectNil(err)
83+
test.S(t).ExpectEquals(comparison, "((`c1` < @v1) or ((`c1` = @v1)))")
84+
}
85+
{
86+
columns := []string{"c1", "c2"}
87+
values := []string{"@v1", "@v2"}
88+
comparison, err := BuildRangeComparison(columns, values, LessThanComparisonSign)
89+
test.S(t).ExpectNil(err)
90+
test.S(t).ExpectEquals(comparison, "((`c1` < @v1) or (((`c1` = @v1)) AND (`c2` < @v2)))")
91+
}
92+
{
93+
columns := []string{"c1", "c2"}
94+
values := []string{"@v1", "@v2"}
95+
comparison, err := BuildRangeComparison(columns, values, LessThanOrEqualsComparisonSign)
96+
test.S(t).ExpectNil(err)
97+
test.S(t).ExpectEquals(comparison, "((`c1` < @v1) or (((`c1` = @v1)) AND (`c2` < @v2)) or ((`c1` = @v1) and (`c2` = @v2)))")
98+
}
99+
{
100+
columns := []string{"c1", "c2", "c3"}
101+
values := []string{"@v1", "@v2", "@v3"}
102+
comparison, err := BuildRangeComparison(columns, values, LessThanOrEqualsComparisonSign)
103+
test.S(t).ExpectNil(err)
104+
test.S(t).ExpectEquals(comparison, "((`c1` < @v1) or (((`c1` = @v1)) AND (`c2` < @v2)) or (((`c1` = @v1) and (`c2` = @v2)) AND (`c3` < @v3)) or ((`c1` = @v1) and (`c2` = @v2) and (`c3` = @v3)))")
105+
}
106+
{
107+
columns := []string{"c1"}
108+
values := []string{"@v1", "@v2"}
109+
_, err := BuildRangeComparison(columns, values, LessThanOrEqualsComparisonSign)
110+
test.S(t).ExpectNotNil(err)
111+
}
112+
{
113+
columns := []string{}
114+
values := []string{}
115+
_, err := BuildRangeComparison(columns, values, LessThanOrEqualsComparisonSign)
116+
test.S(t).ExpectNotNil(err)
117+
}
118+
}
119+
120+
func TestBuildRangeInsertQuery(t *testing.T) {
121+
databaseName := "mydb"
122+
originalTableName := "tbl"
123+
ghostTableName := "ghost"
124+
sharedColumns := []string{"id", "name", "position"}
125+
{
126+
uniqueKey := "PRIMARY"
127+
uniqueKeyColumns := []string{"id"}
128+
rangeStartValues := []string{"@v1s"}
129+
rangeEndValues := []string{"@v1e"}
130+
131+
query, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues)
132+
test.S(t).ExpectNil(err)
133+
expected := `
134+
insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position)
135+
(select id, name, position from mydb.tbl force index (PRIMARY)
136+
where (((id > @v1s) or ((id = @v1s))) and ((id < @v1e) or ((id = @v1e))))
137+
)
138+
`
139+
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
140+
}
141+
{
142+
uniqueKey := "name_position_uidx"
143+
uniqueKeyColumns := []string{"name", "position"}
144+
rangeStartValues := []string{"@v1s", "@v2s"}
145+
rangeEndValues := []string{"@v1e", "@v2e"}
146+
147+
query, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues)
148+
test.S(t).ExpectNil(err)
149+
expected := `
150+
insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position)
151+
(select id, name, position from mydb.tbl force index (name_position_uidx)
152+
where (((name > @v1s) or (((name = @v1s)) AND (position > @v2s)) or ((name = @v1s) and (position = @v2s))) and ((name < @v1e) or (((name = @v1e)) AND (position < @v2e)) or ((name = @v1e) and (position = @v2e))))
153+
)
154+
`
155+
test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected))
156+
}
157+
}

0 commit comments

Comments
 (0)