Skip to content

Commit 7ae613a

Browse files
committed
perf(validation): optimize field overlap check
1 parent 5ca91cb commit 7ae613a

9 files changed

+429
-23
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
[Unreleased]
44

55
* [BUGFIX] Reject object, interface, and input object type definitions that declare zero fields/input values (spec compliance).
6+
* [IMPROVEMENT] Optimize overlapping field validation to avoid quadratic memory blowups on large sibling field lists.
7+
* [FEATURE] Add a configurable safety valve for the number of overlapping field comparisons via the `OverlapValidationLimit(n)` schema option (0 disables the cap). When the limit is exceeded, validation aborts early and reports an error with rule `OverlapValidationLimitExceeded`.
8+
* [TEST] Add stress benchmarks & randomized overlap stress test for mixed field/fragment patterns.
69

710
[v1.7.0](https://github.com/graph-gophers/graphql-go/releases/tag/v1.7.0) Release v1.7.0
811

graphql.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ type Schema struct {
8686
subscribeResolverTimeout time.Duration
8787
useFieldResolvers bool
8888
disableFieldSelections bool
89+
overlapPairLimit int
8990
}
9091

9192
// AST returns the abstract syntax tree of the GraphQL schema definition.
@@ -152,6 +153,14 @@ func MaxQueryLength(n int) SchemaOpt {
152153
}
153154
}
154155

156+
// OverlapValidationLimit caps the number of overlapping selection pairs that will be examined
157+
// during validation of a single operation (including fragments). A value of 0 disables the cap.
158+
// When the cap is exceeded validation aborts early with an error (rule: OverlapValidationLimitExceeded)
159+
// to protect against maliciously constructed queries designed to exhaust memory/CPU.
160+
func OverlapValidationLimit(n int) SchemaOpt {
161+
return func(s *Schema) { s.overlapPairLimit = n }
162+
}
163+
155164
// Tracer is used to trace queries and fields. It defaults to [noop.Tracer].
156165
func Tracer(t tracer.Tracer) SchemaOpt {
157166
return func(s *Schema) {
@@ -247,7 +256,7 @@ func (s *Schema) ValidateWithVariables(queryString string, variables map[string]
247256
return []*errors.QueryError{errors.Errorf("executable document must contain at least one operation")}
248257
}
249258

250-
return validation.Validate(s.schema, doc, variables, s.maxDepth)
259+
return validation.Validate(s.schema, doc, variables, s.maxDepth, s.overlapPairLimit)
251260
}
252261

253262
// Exec executes the given query with the schema's resolver. It panics if the schema was created
@@ -270,7 +279,7 @@ func (s *Schema) exec(ctx context.Context, queryString string, operationName str
270279
}
271280

272281
validationFinish := s.validationTracer.TraceValidation(ctx)
273-
errs := validation.Validate(s.schema, doc, variables, s.maxDepth)
282+
errs := validation.Validate(s.schema, doc, variables, s.maxDepth, s.overlapPairLimit)
274283
validationFinish(errs)
275284
if len(errs) != 0 {
276285
return &Response{Errors: errs}
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package validation_test
2+
3+
import (
4+
"math/rand"
5+
"testing"
6+
"time"
7+
8+
"github.com/graph-gophers/graphql-go/internal/query"
9+
"github.com/graph-gophers/graphql-go/internal/schema"
10+
v "github.com/graph-gophers/graphql-go/internal/validation"
11+
)
12+
13+
// FuzzValidateOverlapMixed exercises the overlap validation logic with randomly generated queries
14+
// containing many sibling fields and fragment spreads to ensure it does not panic or explode in memory.
15+
// It uses a modest overlap pair cap to keep each iteration bounded.
16+
func FuzzValidateOverlapMixed(f *testing.F) {
17+
baseQueries := []string{
18+
"query{root{id}}",
19+
"query Q{root{id name}}",
20+
}
21+
for _, q := range baseQueries {
22+
f.Add(q)
23+
}
24+
25+
s := schema.New()
26+
_ = schema.Parse(s, `schema{query:Query} type Query{root: Thing} type Thing { id: ID name: String value: String }`, false)
27+
28+
randSource := rand.New(rand.NewSource(time.Now().UnixNano()))
29+
30+
f.Fuzz(func(t *testing.T, seed string) {
31+
// Use hash of seed to deterministically generate but bound complexity.
32+
r := rand.New(rand.NewSource(int64(len(seed)) + randSource.Int63()))
33+
fieldCount := 50 + r.Intn(150) // 50-199
34+
fragCount := 1 + r.Intn(5)
35+
36+
// Build fragments.
37+
fragBodies := make([]string, fragCount)
38+
for i := 0; i < fragCount; i++ {
39+
// each fragment gets subset of fields
40+
var body string
41+
innerFields := 5 + r.Intn(20)
42+
for j := 0; j < innerFields; j++ {
43+
body += " f" + nameIdx(r.Intn(500)) + ":id"
44+
}
45+
fragBodies[i] = "fragment F" + nameIdx(i) + " on Thing{" + body + " }"
46+
}
47+
48+
// Root selection
49+
sel := "query{root{"
50+
for i := 0; i < fieldCount; i++ {
51+
sel += " a" + nameIdx(r.Intn(1000)) + ":id"
52+
}
53+
// Sprinkle fragment spreads
54+
for i := 0; i < fragCount; i++ {
55+
sel += " ...F" + nameIdx(i)
56+
}
57+
sel += "}}"
58+
queryText := sel
59+
for _, fb := range fragBodies {
60+
queryText += fb
61+
}
62+
63+
doc, err := query.Parse(queryText)
64+
if err != nil {
65+
return
66+
} // parser fuzzing not our goal
67+
if len(doc.Operations) == 0 {
68+
return
69+
}
70+
// Use overlap limit to bound cost.
71+
errs := v.Validate(s, doc, nil, 0, 10_000)
72+
// Ensure no panic (implicit). Optionally sanity check: errors slice must not be ridiculously huge.
73+
if len(errs) > 1000 {
74+
t.Fatalf("too many errors: %d", len(errs))
75+
}
76+
})
77+
}
78+
79+
// nameIdx maps a non-negative index to a short lowercase identifier fragment. The index is
// written in base 26 using the letters a-z, least significant digit first, so 0 yields "a",
// 25 yields "z" and 26 yields "ab".
func nameIdx(i int) string {
	const alphabet = "abcdefghijklmnopqrstuvwxyz"
	const base = len(alphabet)

	var out []byte
	// Emit the low-order digits while more than one digit remains.
	for i >= base {
		out = append(out, alphabet[i%base])
		i /= base
	}
	// The final (most significant) digit.
	out = append(out, alphabet[i])
	return string(out)
}

internal/validation/validate_max_depth_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ func (tc maxDepthTestCase) Run(t *testing.T, s *ast.Schema) {
8383
t.Fatal(qErr)
8484
}
8585

86-
errs := Validate(s, doc, nil, tc.depth)
86+
errs := Validate(s, doc, nil, tc.depth, 0)
8787
if len(tc.expectedErrors) > 0 {
8888
if len(errs) > 0 {
8989
for _, expected := range tc.expectedErrors {
@@ -489,7 +489,7 @@ func TestMaxDepthValidation(t *testing.T) {
489489
t.Fatal(err)
490490
}
491491

492-
context := newContext(s, doc, tc.maxDepth)
492+
context := newContext(s, doc, tc.maxDepth, 0)
493493
op := doc.Operations[0]
494494

495495
opc := &opContext{context: context, ops: doc.Operations}

internal/validation/validation.go

Lines changed: 154 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,17 @@ type fieldInfo struct {
2626
}
2727

2828
type context struct {
29-
schema *ast.Schema
30-
doc *ast.ExecutableDefinition
31-
errs []*errors.QueryError
32-
opErrs map[*ast.OperationDefinition][]*errors.QueryError
33-
usedVars map[*ast.OperationDefinition]varSet
34-
fieldMap map[*ast.Field]fieldInfo
35-
overlapValidated map[selectionPair]struct{}
36-
maxDepth int
29+
schema *ast.Schema
30+
doc *ast.ExecutableDefinition
31+
errs []*errors.QueryError
32+
opErrs map[*ast.OperationDefinition][]*errors.QueryError
33+
usedVars map[*ast.OperationDefinition]varSet
34+
fieldMap map[*ast.Field]fieldInfo
35+
overlapValidated map[selectionPair]struct{}
36+
maxDepth int
37+
overlapPairLimit int
38+
overlapPairsObserved int
39+
overlapLimitHit bool
3740
}
3841

3942
func (c *context) addErr(loc errors.Location, rule string, format string, a ...interface{}) {
@@ -53,7 +56,7 @@ type opContext struct {
5356
ops []*ast.OperationDefinition
5457
}
5558

56-
func newContext(s *ast.Schema, doc *ast.ExecutableDefinition, maxDepth int) *context {
59+
func newContext(s *ast.Schema, doc *ast.ExecutableDefinition, maxDepth int, overlapPairLimit int) *context {
5760
return &context{
5861
schema: s,
5962
doc: doc,
@@ -62,11 +65,12 @@ func newContext(s *ast.Schema, doc *ast.ExecutableDefinition, maxDepth int) *con
6265
fieldMap: make(map[*ast.Field]fieldInfo),
6366
overlapValidated: make(map[selectionPair]struct{}),
6467
maxDepth: maxDepth,
68+
overlapPairLimit: overlapPairLimit,
6569
}
6670
}
6771

68-
func Validate(s *ast.Schema, doc *ast.ExecutableDefinition, variables map[string]interface{}, maxDepth int) []*errors.QueryError {
69-
c := newContext(s, doc, maxDepth)
72+
func Validate(s *ast.Schema, doc *ast.ExecutableDefinition, variables map[string]interface{}, maxDepth int, overlapPairLimit int) []*errors.QueryError {
73+
c := newContext(s, doc, maxDepth, overlapPairLimit)
7074

7175
opNames := make(nameSet, len(doc.Operations))
7276
fragUsedBy := make(map[*ast.FragmentDefinition][]*ast.OperationDefinition)
@@ -303,13 +307,76 @@ func validateMaxDepth(c *opContext, sels []ast.Selection, visited map[*ast.Fragm
303307
}
304308

305309
func validateSelectionSet(c *opContext, sels []ast.Selection, t ast.NamedType) {
310+
if len(sels) == 0 {
311+
return
312+
}
313+
314+
// First pass: validate each selection and bucket fields by response name (alias or name).
315+
fieldGroups := make(map[string][]ast.Selection)
316+
var fragments []ast.Selection // fragment spreads & inline fragments
306317
for _, sel := range sels {
318+
if c.overlapLimitHit {
319+
return
320+
}
307321
validateSelection(c, sel, t)
322+
switch s := sel.(type) {
323+
case *ast.Field:
324+
name := s.Alias.Name
325+
if name == "" {
326+
name = s.Name.Name
327+
}
328+
fieldGroups[name] = append(fieldGroups[name], sel)
329+
default:
330+
fragments = append(fragments, sel)
331+
}
308332
}
309333

310-
for i, a := range sels {
311-
for _, b := range sels[i+1:] {
312-
c.validateOverlap(a, b, nil, nil)
334+
// Compare fields only within same response name group (was O(n^2) across all fields previously).
335+
for _, group := range fieldGroups {
336+
if c.overlapLimitHit {
337+
break
338+
}
339+
if len(group) < 2 {
340+
continue
341+
}
342+
for i, a := range group {
343+
if c.overlapLimitHit {
344+
break
345+
}
346+
for _, b := range group[i+1:] {
347+
if c.overlapLimitHit {
348+
break
349+
}
350+
c.validateOverlap(a, b, nil, nil)
351+
}
352+
}
353+
}
354+
355+
// Fragments can introduce any field names, so we must compare them with all fields and each other.
356+
if len(fragments) > 0 && !c.overlapLimitHit {
357+
// Flatten fields for fragment comparison.
358+
var allFields []ast.Selection
359+
for _, group := range fieldGroups {
360+
allFields = append(allFields, group...)
361+
}
362+
for i, fa := range fragments {
363+
if c.overlapLimitHit {
364+
break
365+
}
366+
// Compare fragment with all fields
367+
for _, fld := range allFields {
368+
if c.overlapLimitHit {
369+
break
370+
}
371+
c.validateOverlap(fa, fld, nil, nil)
372+
}
373+
// Compare fragment with following fragments
374+
for _, fb := range fragments[i+1:] {
375+
if c.overlapLimitHit {
376+
break
377+
}
378+
c.validateOverlap(fa, fb, nil, nil)
379+
}
313380
}
314381
}
315382
}
@@ -523,11 +590,38 @@ func (c *context) validateOverlap(a, b ast.Selection, reasons *[]string, locs *[
523590
return
524591
}
525592

526-
if _, ok := c.overlapValidated[selectionPair{a, b}]; ok {
593+
// Optimisation 1: store only one direction of the pair to halve memory and lookups.
594+
pa := reflect.ValueOf(a).Pointer()
595+
pb := reflect.ValueOf(b).Pointer()
596+
if pb < pa { // canonical ordering
597+
a, b = b, a
598+
}
599+
key := selectionPair{a: a, b: b}
600+
if _, ok := c.overlapValidated[key]; ok {
527601
return
528602
}
529-
c.overlapValidated[selectionPair{a, b}] = struct{}{}
530-
c.overlapValidated[selectionPair{b, a}] = struct{}{}
603+
c.overlapValidated[key] = struct{}{}
604+
605+
if c.overlapPairLimit > 0 && !c.overlapLimitHit {
606+
c.overlapPairsObserved++
607+
if c.overlapPairsObserved > c.overlapPairLimit {
608+
c.overlapLimitHit = true
609+
// determine a representative location for error reporting
610+
var loc errors.Location
611+
switch sel := a.(type) {
612+
case *ast.Field:
613+
loc = sel.Alias.Loc
614+
case *ast.InlineFragment:
615+
loc = sel.Loc
616+
case *ast.FragmentSpread:
617+
loc = sel.Loc
618+
default:
619+
// leave zero value
620+
}
621+
c.addErr(loc, "OverlapValidationLimitExceeded", "Overlapping field validation aborted after examining %d pairs (limit %d). Consider restructuring the query or increasing the limit.", c.overlapPairsObserved-1, c.overlapPairLimit)
622+
return
623+
}
624+
}
531625

532626
switch a := a.(type) {
533627
case *ast.Field:
@@ -608,11 +702,54 @@ func (c *context) validateFieldOverlap(a, b *ast.Field) ([]string, []errors.Loca
608702

609703
var reasons []string
610704
var locs []errors.Location
705+
706+
// Fast-path: if either side has no subselections we are done.
707+
if len(a.SelectionSet) == 0 || len(b.SelectionSet) == 0 {
708+
return nil, nil
709+
}
710+
711+
// Optimisation 2: avoid O(m*n) cartesian product for large sibling lists with mostly
712+
// distinct response names (common & exploitable for DoS). Instead, index B's field
713+
// selections by response name (alias/name). For each field in A we only compare
714+
// against fields in B with the same response name plus all fragment spreads / inline
715+
// fragments (which can expand to any field names and must be compared exhaustively).
716+
bFieldIndex := make(map[string][]ast.Selection, len(b.SelectionSet))
717+
var bNonField []ast.Selection
718+
for _, bs := range b.SelectionSet {
719+
if f, ok := bs.(*ast.Field); ok {
720+
name := f.Alias.Name
721+
if name == "" { // alias may be empty, fall back to field name
722+
name = f.Name.Name
723+
}
724+
bFieldIndex[name] = append(bFieldIndex[name], bs)
725+
continue
726+
}
727+
bNonField = append(bNonField, bs)
728+
}
729+
611730
for _, a2 := range a.SelectionSet {
731+
if af, ok := a2.(*ast.Field); ok {
732+
name := af.Alias.Name
733+
if name == "" {
734+
name = af.Name.Name
735+
}
736+
// Compare only against same-name fields + all non-field selections.
737+
if matches := bFieldIndex[name]; len(matches) != 0 {
738+
for _, bMatch := range matches {
739+
c.validateOverlap(a2, bMatch, &reasons, &locs)
740+
}
741+
}
742+
for _, bnf := range bNonField {
743+
c.validateOverlap(a2, bnf, &reasons, &locs)
744+
}
745+
continue
746+
}
747+
// For fragments / inline fragments we still need to compare against every selection in B.
612748
for _, b2 := range b.SelectionSet {
613749
c.validateOverlap(a2, b2, &reasons, &locs)
614750
}
615751
}
752+
616753
return reasons, locs
617754
}
618755

internal/validation/validation_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ func TestValidate(t *testing.T) {
8888
if err != nil {
8989
t.Fatalf("failed to parse query: %s", err)
9090
}
91-
errs := validation.Validate(schemas[test.Schema], d, test.Vars, 0)
91+
errs := validation.Validate(schemas[test.Schema], d, test.Vars, 0, 0)
9292
got := []*errors.QueryError{}
9393
for _, err := range errs {
9494
if err.Rule == test.Rule {

0 commit comments

Comments
 (0)