Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions src/main/java/graphql/schema/idl/SchemaTypeDirectivesChecker.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,14 @@
import graphql.schema.idl.errors.MissingTypeError;
import graphql.schema.idl.errors.NotAnInputTypeError;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static graphql.introspection.Introspection.DirectiveLocation.ARGUMENT_DEFINITION;
import static graphql.introspection.Introspection.DirectiveLocation.ENUM;
Expand Down Expand Up @@ -182,6 +186,10 @@ private static boolean isNoNullArgWithoutDefaultValue(InputValueDefinition defin
}

private void commonCheck(Collection<DirectiveDefinition> directiveDefinitions, List<GraphQLError> errors) {
List<DirectiveDefinition> directiveDefinitionList = new ArrayList<>(directiveDefinitions);
Map<String, DirectiveDefinition> directiveDefinitionMap = getByName(directiveDefinitionList, DirectiveDefinition::getName, mergeFirst());
Set<String> reportedCycles = new HashSet<>();

directiveDefinitions.forEach(directiveDefinition -> {
assertTypeName(directiveDefinition, errors);
directiveDefinition.getInputValueDefinitions().forEach(inputValueDefinition -> {
Expand All @@ -191,9 +199,96 @@ private void commonCheck(Collection<DirectiveDefinition> directiveDefinitions, L
errors.add(new DirectiveIllegalReferenceError(directiveDefinition, inputValueDefinition));
}
});

// Check for indirect cycles (A -> B -> A, or A -> B -> C -> A)
List<String> cyclePath = new ArrayList<>();
cyclePath.add(directiveDefinition.getName());
if (hasDirectiveCycle(directiveDefinition, directiveDefinitionMap, cyclePath, new LinkedHashSet<>())) {
// Only report each cycle once (use a canonical representation to avoid duplicates)
String cycleKey = getCycleKey(cyclePath);
if (!reportedCycles.contains(cycleKey)) {
reportedCycles.add(cycleKey);
errors.add(new DirectiveIllegalReferenceError(directiveDefinition, cyclePath));
}
}
});
}

/**
* Detects if a directive has a cycle through applied directives on its arguments.
*
* @param directiveDefinition the directive to check
* @param directiveDefinitionMap map of all directive definitions by name
* @param path the current path being explored (for error reporting)
* @param visited set of directive names currently in the recursion stack
* @return true if a cycle is detected
*/
private boolean hasDirectiveCycle(DirectiveDefinition directiveDefinition,
Map<String, DirectiveDefinition> directiveDefinitionMap,
List<String> path,
Set<String> visited) {
String directiveName = directiveDefinition.getName();

// If already in the current path, we have found a cycle
if (visited.contains(directiveName)) {
return true;
}

visited.add(directiveName);

// Check all input value definitions (arguments) of this directive
for (InputValueDefinition inputValueDefinition : directiveDefinition.getInputValueDefinitions()) {
// Get all directives applied to this argument
for (Directive appliedDirective : inputValueDefinition.getDirectives()) {
String appliedDirectiveName = appliedDirective.getName();

// Skip self-reference (already handled separately with more specific error)
if (appliedDirectiveName.equals(directiveName)) {
continue;
}

DirectiveDefinition referencedDirective = directiveDefinitionMap.get(appliedDirectiveName);
if (referencedDirective != null) {
path.add(appliedDirectiveName);
if (hasDirectiveCycle(referencedDirective, directiveDefinitionMap, path, visited)) {
return true;
}
path.remove(path.size() - 1);
}
}
}

visited.remove(directiveName);
return false;
}

/**
* Creates a canonical key for a cycle to avoid reporting the same cycle multiple times.
* The key is the smallest rotation of the cycle path (excluding the last element which is the same as the first).
*/
private static String getCycleKey(List<String> cyclePath) {
if (cyclePath.size() <= 1) {
return String.join("->", cyclePath);
}

// Remove the last element (same as first) for comparison
List<String> cycleWithoutLast = cyclePath.subList(0, cyclePath.size() - 1);

// Find the lexicographically smallest rotation
String smallest = String.join("->", cycleWithoutLast);
for (int i = 1; i < cycleWithoutLast.size(); i++) {
List<String> rotated = new ArrayList<>();
for (int j = 0; j < cycleWithoutLast.size(); j++) {
rotated.add(cycleWithoutLast.get((i + j) % cycleWithoutLast.size()));
}
String rotatedStr = String.join("->", rotated);
if (rotatedStr.compareTo(smallest) < 0) {
smallest = rotatedStr;
}
}
return smallest;
}

private static void assertTypeName(NamedNode<?> node, List<GraphQLError> errors) {
if (node.getName().length() >= 2 && node.getName().startsWith("__")) {
errors.add((new IllegalNameError(node)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import graphql.language.DirectiveDefinition;
import graphql.language.NamedNode;

import java.util.List;

@Internal
public class DirectiveIllegalReferenceError extends BaseError {
public DirectiveIllegalReferenceError(DirectiveDefinition directive, NamedNode location) {
Expand All @@ -12,4 +14,11 @@ public DirectiveIllegalReferenceError(DirectiveDefinition directive, NamedNode l
directive.getName(), location.getName(), lineCol(location)
));
}

public DirectiveIllegalReferenceError(DirectiveDefinition directive, List<String> cyclePath) {
super(directive,
String.format("'%s' forms a directive cycle via: %s '%s'",
directive.getName(), String.join(" -> ", cyclePath), lineCol(directive)
));
}
}
54 changes: 54 additions & 0 deletions src/test/groovy/graphql/schema/idl/SchemaGeneratorTest.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -2541,4 +2541,58 @@ class SchemaGeneratorTest extends Specification {
inputObjectType.isOneOf()
inputObjectType.hasAppliedDirective("oneOf")
}

def "mutually referencing directives (two-way cycle) throws SchemaProblem instead of StackOverflowError"() {
given:
def sdl = '''
directive @foo(x: Int @bar(y: 1)) on FIELD_DEFINITION | ARGUMENT_DEFINITION
directive @bar(y: Int @foo(x: 2)) on FIELD_DEFINITION | ARGUMENT_DEFINITION

type Query { field: String @foo(x: 10) @bar(y: 20) }
'''

when:
def registry = new SchemaParser().parse(sdl)
UnExecutableSchemaGenerator.makeUnExecutableSchema(registry)

then:
def ex = thrown(SchemaProblem)
ex.message.contains("forms a directive cycle via:")
}

def "three-way directive cycle throws SchemaProblem instead of StackOverflowError"() {
given:
def sdl = '''
directive @dirA(x: Int @dirB(y: 1)) on FIELD_DEFINITION | ARGUMENT_DEFINITION
directive @dirB(y: Int @dirC(z: 2)) on FIELD_DEFINITION | ARGUMENT_DEFINITION
directive @dirC(z: Int @dirA(x: 3)) on FIELD_DEFINITION | ARGUMENT_DEFINITION

type Query { field: String @dirA(x: 10) @dirB(y: 20) @dirC(z: 30) }
'''

when:
def registry = new SchemaParser().parse(sdl)
UnExecutableSchemaGenerator.makeUnExecutableSchema(registry)

then:
def ex = thrown(SchemaProblem)
ex.message.contains("forms a directive cycle via:")
}

def "directive self-reference still correctly throws SchemaProblem"() {
given:
def sdl = '''
directive @recursive(depth: Int @recursive(depth: 0)) on FIELD_DEFINITION | ARGUMENT_DEFINITION

type Query { field: String @recursive(depth: 5) }
'''

when:
def registry = new SchemaParser().parse(sdl)
UnExecutableSchemaGenerator.makeUnExecutableSchema(registry)

then:
def ex = thrown(SchemaProblem)
ex.message.contains("must not reference itself")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -337,4 +337,76 @@ class SchemaTypeDirectivesCheckerTest extends Specification {
then:
errors.size() == 0
}

def "two directives must not reference each other (two-way cycle)"() {
given:
def spec = '''
directive @foo(x: Int @bar(y: 1)) on FIELD_DEFINITION | ARGUMENT_DEFINITION
directive @bar(y: Int @foo(x: 2)) on FIELD_DEFINITION | ARGUMENT_DEFINITION

type Query {
f1 : String
}
'''
def registry = parse(spec)
def errors = []

when:
new SchemaTypeDirectivesChecker(registry, RuntimeWiring.newRuntimeWiring().build()).checkTypeDirectives(errors)

then:
errors.size() == 1
errors.get(0) instanceof DirectiveIllegalReferenceError
// The cycle path should be: bar -> foo -> bar (or foo -> bar -> foo)
def msg = errors.get(0).getMessage()
msg.contains("forms a directive cycle via:")
(msg.contains("bar -> foo -> bar") || msg.contains("foo -> bar -> foo"))
}

def "three directives must not form a cycle (three-way cycle)"() {
given:
def spec = '''
directive @dirA(x: Int @dirB(y: 1)) on FIELD_DEFINITION | ARGUMENT_DEFINITION
directive @dirB(y: Int @dirC(z: 2)) on FIELD_DEFINITION | ARGUMENT_DEFINITION
directive @dirC(z: Int @dirA(x: 3)) on FIELD_DEFINITION | ARGUMENT_DEFINITION

type Query {
f1 : String
}
'''
def registry = parse(spec)
def errors = []

when:
new SchemaTypeDirectivesChecker(registry, RuntimeWiring.newRuntimeWiring().build()).checkTypeDirectives(errors)

then:
errors.size() == 1
errors.get(0) instanceof DirectiveIllegalReferenceError
// The cycle path should include all three directives
def msg = errors.get(0).getMessage()
msg.contains("forms a directive cycle via:")
msg.contains("dirA") && msg.contains("dirB") && msg.contains("dirC")
}

def "directives referencing without cycles are allowed"() {
given:
def spec = '''
directive @leaf on ARGUMENT_DEFINITION
directive @foo(x: Int @leaf) on FIELD_DEFINITION | ARGUMENT_DEFINITION
directive @bar(y: Int @leaf) on FIELD_DEFINITION | ARGUMENT_DEFINITION

type Query {
f1 : String @foo(x: 1) @bar(y: 2)
}
'''
def registry = parse(spec)
def errors = []

when:
new SchemaTypeDirectivesChecker(registry, RuntimeWiring.newRuntimeWiring().build()).checkTypeDirectives(errors)

then:
errors.size() == 0
}
}