Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
name: Tests

on:
push:
branches: [ master ]
pull_request:
branches: [ master ]

jobs:
test:
name: Test
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- name: Check out code
uses: actions/checkout@main

- name: Set up Go
uses: actions/setup-go@main
with:
go-version-file: go.mod

- name: Run tests
run: go test -v -race ./...
Comment thread Fixed
217 changes: 217 additions & 0 deletions ddos_vulnerability_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
package graphql_test

import (
"context"
"strings"
"testing"
"time"

"github.com/graph-gophers/graphql-go"
)

const simpleSchema = `
schema {
query: Query
}

type Query {
a: String
}
`

type simpleResolver struct{}

func (r *simpleResolver) A() *string {
val := "value"
return &val
}

// TestDDoSVulnerability_ManyFieldsAtSameLevel tests the vulnerability where
// a query with thousands of fields at the same level causes CPU overload.
// This test demonstrates the vulnerability and is skipped by default.
func TestDDoSVulnerability_ManyFieldsAtSameLevel(t *testing.T) {
t.Skip("Skipping vulnerability demonstration test - it would timeout without the fix")
// Create a query with many duplicate fields at the same level
// This is the attack vector from the user's report
numFields := 5000
fields := make([]string, numFields)
for i := 0; i < numFields; i++ {
fields[i] = "a"
}

maliciousQuery := "query { " + strings.Join(fields, " ") + " }"

schema := graphql.MustParseSchema(simpleSchema, &simpleResolver{})

// Set a timeout to prevent the test from hanging indefinitely
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

// This should complete quickly, but without a fix it will cause CPU overload
done := make(chan struct{})
go func() {
result := schema.Exec(ctx, maliciousQuery, "", nil)
// We expect either:
// 1. An error indicating the query is too complex (with fix)
// 2. Success (but it should be fast)
if result.Errors != nil {
t.Logf("Query returned errors (expected with fix): %v", result.Errors)
}
close(done)
}()

select {
case <-done:
t.Log("Query completed")
case <-ctx.Done():
t.Fatal("Query timed out - DDoS vulnerability confirmed")
}
}

// TestDDoSVulnerability_ExtremeCase tests an even more extreme case
// This test demonstrates the vulnerability and is skipped by default.
func TestDDoSVulnerability_ExtremeCase(t *testing.T) {
t.Skip("Skipping extreme vulnerability demonstration test - it would timeout without the fix")
// Create a query with an extreme number of fields (like the user's example)
// The user's query had roughly 100,000+ fields
numFields := 100000
fields := make([]string, numFields)
for i := 0; i < numFields; i++ {
fields[i] = "a"
}

maliciousQuery := "query { " + strings.Join(fields, " ") + " }"

schema := graphql.MustParseSchema(simpleSchema, &simpleResolver{})

// Set a strict timeout
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

done := make(chan struct{})
var testErr error
go func() {
start := time.Now()
result := schema.Exec(ctx, maliciousQuery, "", nil)
duration := time.Since(start)

t.Logf("Query took %v to complete", duration)

// With a fix, this should be rejected quickly (< 100ms)
// Without a fix, this will timeout
if duration > 1*time.Second {
testErr = nil // Will be caught by timeout
}

if result.Errors != nil {
t.Logf("Query returned errors: %v", result.Errors)
}
close(done)
}()

select {
case <-done:
if testErr != nil {
t.Fatal(testErr)
}
t.Log("Query completed (should be fast with fix)")
case <-ctx.Done():
t.Fatal("Query timed out - DDoS vulnerability confirmed. This query with 100k fields should be rejected immediately.")
}
}

// TestDDoSVulnerability_ValidationOnly tests that the validation phase itself is vulnerable
// This test demonstrates the vulnerability and is skipped by default.
func TestDDoSVulnerability_ValidationOnly(t *testing.T) {
t.Skip("Skipping validation-only vulnerability demonstration test - it would timeout without the fix")
// Test just the validation without execution
numFields := 10000
fields := make([]string, numFields)
for i := 0; i < numFields; i++ {
fields[i] = "a"
}

maliciousQuery := "query { " + strings.Join(fields, " ") + " }"

schema := graphql.MustParseSchema(simpleSchema, &simpleResolver{})

start := time.Now()
errors := schema.Validate(maliciousQuery)
duration := time.Since(start)

t.Logf("Validation took %v", duration)
t.Logf("Validation errors: %v", errors)

// Without a fix, validation can take seconds for 10k fields
// With a fix, it should be rejected immediately (< 100ms)
if duration > 500*time.Millisecond {
t.Errorf("Validation took too long (%v). This indicates a DDoS vulnerability in the validation phase.", duration)
}
}

// TestDDoSVulnerability_WithFix tests that the fix prevents the attack
func TestDDoSVulnerability_WithFix(t *testing.T) {
// Create a query with many fields
numFields := 10000
fields := make([]string, numFields)
for i := 0; i < numFields; i++ {
fields[i] = "a"
}

maliciousQuery := "query { " + strings.Join(fields, " ") + " }"

// Create schema with MaxSelectionSetSize limit
schema := graphql.MustParseSchema(simpleSchema, &simpleResolver{}, graphql.MaxSelectionSetSize(100))

start := time.Now()
errors := schema.Validate(maliciousQuery)
duration := time.Since(start)

t.Logf("Validation with fix took %v", duration)
t.Logf("Validation errors: %v", errors)

// With the fix, the query should be rejected immediately
if len(errors) == 0 {
t.Fatal("Expected validation errors, but got none")
}

// Check that the error message mentions the max selection set size
found := false
for _, err := range errors {
if strings.Contains(err.Message, "exceeds the maximum allowed size") {
found = true
break
}
}
if !found {
t.Errorf("Expected error about exceeding max selection set size, but got: %v", errors)
}

// Validation should be fast (< 100ms)
if duration > 100*time.Millisecond {
t.Errorf("Validation took too long (%v) even with the fix", duration)
}
}

// TestDDoSVulnerability_FixWithReasonableQuery tests that the fix doesn't break reasonable queries
func TestDDoSVulnerability_FixWithReasonableQuery(t *testing.T) {
// Create a reasonable query with just a few fields
reasonableQuery := "query { a a a a a }"

// Create schema with MaxSelectionSetSize limit
schema := graphql.MustParseSchema(simpleSchema, &simpleResolver{}, graphql.MaxSelectionSetSize(100))

errors := schema.Validate(reasonableQuery)

t.Logf("Validation errors for reasonable query: %v", errors)

// This should not be blocked
if len(errors) > 0 {
// Check if there's an error about max selection set size
for _, err := range errors {
if strings.Contains(err.Message, "exceeds the maximum allowed size") {
t.Errorf("Reasonable query was incorrectly blocked by MaxSelectionSetSize")
}
}
}
}
7 changes: 6 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
module github.com/graph-gophers/graphql-go

go 1.16
go 1.18

require (
github.com/opentracing/opentracing-go v1.2.0
go.opentelemetry.io/otel v1.6.3
go.opentelemetry.io/otel/trace v1.6.3
)

require (
github.com/go-logr/logr v1.2.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
)
14 changes: 12 additions & 2 deletions graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ type Schema struct {
directives []directives.Directive
maxQueryLength int
maxDepth int
maxSelectionSetSize int
maxParallelism int
rateLimiter ratelimit.RateLimiter
tracer tracer.Tracer
Expand Down Expand Up @@ -137,6 +138,15 @@ func MaxDepth(n int) SchemaOpt {
}
}

// MaxSelectionSetSize specifies the maximum number of selections (fields) allowed in a single selection set.
// This helps prevent DDoS attacks where an attacker sends queries with thousands of fields at the same level.
// The default is 0 which disables this check. A recommended value is 100-1000 depending on your schema complexity.
func MaxSelectionSetSize(n int) SchemaOpt {
return func(s *Schema) {
s.maxSelectionSetSize = n
}
}

// MaxParallelism specifies the maximum number of resolvers per request allowed to run in parallel. The default is 10.
func MaxParallelism(n int) SchemaOpt {
return func(s *Schema) {
Expand Down Expand Up @@ -270,7 +280,7 @@ func (s *Schema) ValidateWithVariables(queryString string, variables map[string]
return []*errors.QueryError{qErr}
}

return validation.Validate(s.schema, doc, variables, s.maxDepth)
return validation.Validate(s.schema, doc, variables, s.maxDepth, s.maxSelectionSetSize)
}

// Exec executes the given query with the schema's resolver. It panics if the schema was created
Expand All @@ -297,7 +307,7 @@ func (s *Schema) exec(ctx context.Context, queryString string, operationName str
}

validationFinish := s.validationTracer.TraceValidation(ctx)
errs := validation.Validate(s.schema, doc, variables, s.maxDepth)
errs := validation.Validate(s.schema, doc, variables, s.maxDepth, s.maxSelectionSetSize)
validationFinish(errs)
if len(errs) != 0 {
return &Response{Errors: errs}
Expand Down
4 changes: 2 additions & 2 deletions internal/validation/validate_max_depth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func (tc maxDepthTestCase) Run(t *testing.T, s *ast.Schema) {
t.Fatal(qErr)
}

errs := Validate(s, doc, nil, tc.depth)
errs := Validate(s, doc, nil, tc.depth, 0)
if len(tc.expectedErrors) > 0 {
if len(errs) > 0 {
for _, expected := range tc.expectedErrors {
Expand Down Expand Up @@ -489,7 +489,7 @@ func TestMaxDepthValidation(t *testing.T) {
t.Fatal(err)
}

context := newContext(s, doc, tc.maxDepth)
context := newContext(s, doc, tc.maxDepth, 0)
op := doc.Operations[0]

opc := &opContext{context: context, ops: doc.Operations}
Expand Down
Loading
Loading