Skip to content

Commit f4dfc49

Browse files
committed
Add repetition detection to the GraphQL package
Signed-off-by: peterdeme <demepeter93@gmail.com>
1 parent d028a05 commit f4dfc49

3 files changed

Lines changed: 127 additions & 3 deletions

File tree

ddos_vulnerability_test.go

Lines changed: 98 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,16 +175,15 @@ func TestDDoSVulnerability_WithFix(t *testing.T) {
175175
t.Fatal("Expected validation errors, but got none")
176176
}
177177

178-
// Check that the error message mentions the max selection set size
179178
found := false
180179
for _, err := range errors {
181-
if strings.Contains(err.Message, "exceeds the maximum allowed size") {
180+
if strings.Contains(err.Message, "invalid query") {
182181
found = true
183182
break
184183
}
185184
}
186185
if !found {
187-
t.Errorf("Expected error about exceeding max selection set size, but got: %v", errors)
186+
t.Errorf("Expected error about exceeding max selection set size or repeated tokens, but got: %v", errors)
188187
}
189188

190189
// Validation should be fast (< 100ms)
@@ -215,3 +214,99 @@ func TestDDoSVulnerability_FixWithReasonableQuery(t *testing.T) {
215214
}
216215
}
217216
}
217+
218+
// TestDDoSVulnerability_RepeatedTokenDetection tests early detection of repeated field patterns
219+
func TestDDoSVulnerability_RepeatedTokenDetection(t *testing.T) {
220+
// Create a query with 150 repeated "a" fields - should be rejected during parsing
221+
numFields := 150
222+
fields := make([]string, numFields)
223+
for i := 0; i < numFields; i++ {
224+
fields[i] = "a"
225+
}
226+
227+
maliciousQuery := "query { " + strings.Join(fields, " ") + " }"
228+
229+
schema := graphql.MustParseSchema(simpleSchema, &simpleResolver{})
230+
231+
start := time.Now()
232+
result := schema.Exec(context.Background(), maliciousQuery, "", nil)
233+
duration := time.Since(start)
234+
235+
t.Logf("Query with %d repeated tokens took %v", numFields, duration)
236+
t.Logf("Errors: %v", result.Errors)
237+
238+
// Should have errors
239+
if len(result.Errors) == 0 {
240+
t.Fatal("Expected errors for repeated token attack, but got none")
241+
}
242+
243+
// Should be rejected quickly (parsing phase)
244+
if duration > 100*time.Millisecond {
245+
t.Errorf("Query took too long (%v) - should be rejected during parsing", duration)
246+
}
247+
248+
// Check error message mentions repetition or DDoS
249+
foundRelevantError := false
250+
for _, err := range result.Errors {
251+
if strings.Contains(err.Message, "syntax error: invalid query") {
252+
foundRelevantError = true
253+
t.Logf("Found expected error: %s", err.Message)
254+
break
255+
}
256+
}
257+
258+
if !foundRelevantError {
259+
t.Errorf("Expected error about repeated tokens, got: %v", result.Errors)
260+
}
261+
}
262+
263+
// TestDDoSVulnerability_RepeatedTokenAtThreshold tests behavior at the threshold
264+
func TestDDoSVulnerability_RepeatedTokenAtThreshold(t *testing.T) {
265+
// Create a query with exactly 20 repeated "a" fields - should be allowed
266+
numFields := 20
267+
fields := make([]string, numFields)
268+
for i := 0; i < numFields; i++ {
269+
fields[i] = "a"
270+
}
271+
272+
queryAtThreshold := "query { " + strings.Join(fields, " ") + " }"
273+
274+
schema := graphql.MustParseSchema(simpleSchema, &simpleResolver{})
275+
276+
start := time.Now()
277+
result := schema.Exec(context.Background(), queryAtThreshold, "", nil)
278+
duration := time.Since(start)
279+
280+
t.Logf("Query with %d repeated tokens (at threshold) took %v", numFields, duration)
281+
t.Logf("Errors: %v", result.Errors)
282+
283+
// Should be fast
284+
if duration > 100*time.Millisecond {
285+
t.Errorf("Query took too long (%v)", duration)
286+
}
287+
288+
if len(result.Errors) > 0 {
289+
t.Errorf("Expected no errors for query at threshold, but got: %v", result.Errors)
290+
}
291+
}
292+
293+
// TestDDoSVulnerability_NonRepeatedFields tests that different fields don't trigger detection
294+
func TestDDoSVulnerability_NonRepeatedFields(t *testing.T) {
295+
// Create a query with many different fields - should not be blocked by repetition detection
296+
// Note: This will still be caught by MaxSelectionSetSize if configured
297+
queryWithDifferentFields := "query { a b c d e f g h i j k l m n o p q r s t u v w x y z }"
298+
299+
schema := graphql.MustParseSchema(simpleSchema, &simpleResolver{})
300+
301+
start := time.Now()
302+
result := schema.Exec(context.Background(), queryWithDifferentFields, "", nil)
303+
duration := time.Since(start)
304+
305+
t.Logf("Query with different fields took %v", duration)
306+
t.Logf("Errors: %v", result.Errors)
307+
308+
err := result.Errors[0]
309+
if !strings.Contains(err.Message, "Cannot query field") {
310+
t.Fatalf("Expected error about unknown fields, got: %s", err.Message)
311+
}
312+
}

internal/common/lexer.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ func (l *Lexer) Peek() rune {
5757
return l.next
5858
}
5959

60+
// TokenText returns the text of the current token without consuming it.
//
// It only forwards to l.sc.TokenText(). NOTE(review): l.sc is presumably a
// text/scanner.Scanner, whose TokenText reports the most recently scanned
// token — confirm against the Lexer definition.
61+
func (l *Lexer) TokenText() string {
62+
return l.sc.TokenText()
63+
}
64+
6065
// ConsumeWhitespace consumes whitespace and tokens equivalent to whitespace (e.g. commas and comments).
6166
//
6267
// Consumed comment characters will build the description for the next type or field encountered.

internal/query/query.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,33 @@ func parseFragment(l *common.Lexer) *ast.FragmentDefinition {
9696
}
9797

9898
func parseSelectionSet(l *common.Lexer) []ast.Selection {
99+
const maxRepeatedTokens = 20 // Prevent DDoS attacks with repeated fields like "a a a a..."
100+
99101
var sels []ast.Selection
102+
var prevFieldName string
103+
repeatCount := 0
104+
100105
l.ConsumeToken('{')
101106
for l.Peek() != '}' {
107+
// Detect runs of consecutive identical field names to prevent DDoS attacks; the count is local to this selection set and resets whenever the name changes.
108+
// Check the token text before consuming to detect patterns like "a a a a..."
109+
if l.Peek() == scanner.Ident {
110+
currentFieldName := l.TokenText()
111+
if currentFieldName == prevFieldName {
112+
repeatCount++
113+
if repeatCount >= maxRepeatedTokens {
114+
l.SyntaxError("invalid query")
115+
}
116+
} else {
117+
repeatCount = 0
118+
prevFieldName = currentFieldName
119+
}
120+
} else {
121+
// Reset for non-field tokens (fragments, etc.)
122+
repeatCount = 0
123+
prevFieldName = ""
124+
}
125+
102126
sels = append(sels, parseSelection(l))
103127
}
104128
l.ConsumeToken('}')

0 commit comments

Comments
 (0)