1010import graphql .schema .GraphQLInputObjectType ;
1111import graphql .schema .GraphQLInputType ;
1212import graphql .schema .GraphQLScalarType ;
13- import graphql .schema .GraphQLTypeUtil ;
1413import graphql .validation .rules .ValidationRuleEnvironment ;
14+ import graphql .validation .util .Util ;
1515
1616import java .lang .reflect .Array ;
1717import java .math .BigDecimal ;
1818import java .util .Collection ;
19- import java .util .Collections ;
2019import java .util .LinkedHashMap ;
2120import java .util .List ;
2221import java .util .Map ;
@@ -38,61 +37,79 @@ public String getName() {
3837 }
3938
4039 @ Override
41- public boolean appliesToArgument (GraphQLArgument argument , GraphQLFieldDefinition fieldDefinition , GraphQLFieldsContainer fieldsContainer ) {
42- boolean applicable = appliesToType (argument .getType ());
43- if (!applicable ) {
44- String argType = argument .getType ().getName ();
45- Assert .assertShouldNeverHappen ("The directive %s cannot be placed on arguments of type %s" , getName (), argType );
46- }
47- return true ;
48- }
49-
50- protected abstract boolean appliesToType (GraphQLInputType inputType );
51-
52-
53- @ Override
54- public List <GraphQLError > runValidation (ValidationRuleEnvironment ruleEnvironment ) {
55- return Collections .emptyList ();
40+ public boolean appliesToType (GraphQLArgument argument , GraphQLFieldDefinition fieldDefinition , GraphQLFieldsContainer fieldsContainer ) {
41+ return appliesToType (Util .unwrapNonNull (argument .getType ()));
5642 }
5743
58- protected boolean appliesToTypes (GraphQLInputType argumentType , GraphQLScalarType ... scalarTypes ) {
59- GraphQLInputType unwrappedType = unwrap ( argumentType );
44+ protected boolean isOneOfTheseTypes (GraphQLInputType inputType , GraphQLScalarType ... scalarTypes ) {
45+ GraphQLInputType unwrappedType = Util . unwrapNonNull ( inputType );
6046 for (GraphQLScalarType scalarType : scalarTypes ) {
6147 if (unwrappedType .getName ().equals (scalarType .getName ())) {
6248 return true ;
6349 }
6450 }
6551 return false ;
66-
6752 }
6853
69-
70- protected GraphQLDirective getArgDirective (ValidationRuleEnvironment ruleEnvironment , String name ) {
71- GraphQLDirective directive = ruleEnvironment .getArgument ().getDirective (name );
72- return Assert .assertNotNull (directive );
73- }
74-
75- protected int getIntArg (GraphQLDirective directive , String argName , int defaultValue ) {
54+ protected int getIntArg (GraphQLDirective directive , String argName ) {
7655 GraphQLArgument argument = directive .getArgument (argName );
7756 if (argument == null ) {
78- return defaultValue ;
57+ return assertExpectedArgType ( argName , "Int" ) ;
7958 }
8059 Number value = (Number ) argument .getValue ();
8160 if (value == null ) {
82- return defaultValue ;
61+ value = (Number ) argument .getDefaultValue ();
62+ if (value == null ) {
63+ return assertExpectedArgType (argName , "Int" );
64+ }
8365 }
8466 return value .intValue ();
8567 }
8668
87- protected String getStrArg (GraphQLDirective directive , String name ) {
88- return (String ) directive .getArgument (name ).getValue ();
69+ protected String getStrArg (GraphQLDirective directive , String argName ) {
70+ GraphQLArgument argument = directive .getArgument (argName );
71+ if (argument == null ) {
72+ return assertExpectedArgType (argName , "String" );
73+ }
74+ String value = (String ) argument .getValue ();
75+ if (value == null ) {
76+ value = (String ) argument .getDefaultValue ();
77+ if (value == null ) {
78+ return assertExpectedArgType (argName , "String" );
79+ }
80+ }
81+ return value ;
8982 }
9083
91- protected String getMessageTemplate (GraphQLDirective directive ) {
92- String msg = getStrArg (directive , "message" );
93- return Assert .assertNotNull (msg , "A validation directive MUST have a message argument with a default" );
84+ protected boolean getBoolArg (GraphQLDirective directive , String argName ) {
85+ GraphQLArgument argument = directive .getArgument (argName );
86+ if (argument == null ) {
87+ return assertExpectedArgType (argName , "Boolean" );
88+ }
89+ Object value = argument .getValue ();
90+ if (value == null ) {
91+ value = argument .getDefaultValue ();
92+ if (value == null ) {
93+ return assertExpectedArgType (argName , "Boolean" );
94+ }
95+ }
96+ return Boolean .parseBoolean (String .valueOf (value ));
9497 }
9598
99+ protected String getMessageTemplate (GraphQLDirective directive ) {
100+ String msg = null ;
101+ GraphQLArgument arg = directive .getArgument ("message" );
102+ if (arg != null ) {
103+ msg = (String ) arg .getValue ();
104+ if (msg == null ) {
105+ msg = (String ) arg .getDefaultValue ();
106+ }
107+ }
108+ if (msg == null ) {
109+ msg = "graphql.validation." + getName () + ".message" ;
110+ }
111+ return msg ;
112+ }
96113
97114 protected Map <String , Object > mkMessageParams (Object ... args ) {
98115 Assert .assertTrue (args .length % 2 == 0 , "You MUST pass in an even number of arguments" );
@@ -107,19 +124,14 @@ protected Map<String, Object> mkMessageParams(Object... args) {
107124 return params ;
108125 }
109126
110- protected GraphQLInputType unwrap (GraphQLInputType inputType ) {
111- return (GraphQLInputType ) GraphQLTypeUtil .unwrapAll (inputType );
112- }
113-
114127 protected List <GraphQLError > mkError (ValidationRuleEnvironment ruleEnvironment , GraphQLDirective directive , Map <String , Object > msgParams ) {
115128 String messageTemplate = getMessageTemplate (directive );
116129 GraphQLError error = ruleEnvironment .getInterpolator ().interpolate (messageTemplate , msgParams , ruleEnvironment );
117130 return singletonList (error );
118131 }
119132
120-
121133 protected boolean isStringOrListOrMap (GraphQLInputType argumentType ) {
122- GraphQLInputType unwrappedType = unwrap (argumentType );
134+ GraphQLInputType unwrappedType = Util . unwrapOneAndAllNonNull (argumentType );
123135 return Scalars .GraphQLString .equals (unwrappedType ) ||
124136 isList (argumentType ) ||
125137 (unwrappedType instanceof GraphQLInputObjectType );
@@ -131,25 +143,39 @@ protected Map asMap(Object value) {
131143 return (Map ) value ;
132144 }
133145
134- protected BigDecimal asBigDecimal (Object value ) {
146+ protected BigDecimal asBigDecimal (Object value ) throws NumberFormatException {
147+ if (value == null ) {
148+ return Assert .assertShouldNeverHappen ("Validation cant handle null objects BigDecimals" );
149+ }
135150 String bdStr = "" ;
136151 if (value instanceof Number ) {
137152 bdStr = value .toString ();
153+ } else if (value instanceof String ) {
154+ bdStr = value .toString ();
138155 } else {
139156 Assert .assertShouldNeverHappen ("Validation cant handle objects of type '%s' as BigDecimals" , value .getClass ().getSimpleName ());
140157 }
141158 return new BigDecimal (bdStr );
142159 }
143160
161+ protected boolean asBoolean (Object value ) {
162+ if (value == null ) {
163+ return Assert .assertShouldNeverHappen ("Validation cant handle null objects Booleans" );
164+ }
165+ if (value instanceof Boolean ) {
166+ return (Boolean ) value ;
167+ } else {
168+ return Assert .assertShouldNeverHappen ("Validation cant handle objects of type '%s' as Booleans" , value .getClass ().getSimpleName ());
169+ }
170+ }
144171
145- protected int getStringOrObjectOrMapLength (GraphQLInputType argType , Object argumentValue ) {
146- GraphQLInputType unwrappedType = unwrap (argType );
172+ protected int getStringOrObjectOrMapLength (GraphQLInputType inputType , Object argumentValue ) {
147173 int valLen ;
148174 if (argumentValue == null ) {
149175 valLen = 0 ;
150- } else if (Scalars .GraphQLString .equals (unwrappedType )) {
176+ } else if (Scalars .GraphQLString .equals (Util . unwrapNonNull ( inputType ) )) {
151177 valLen = String .valueOf (argumentValue ).length ();
152- } else if (isList (argType )) {
178+ } else if (isList (inputType )) {
153179 valLen = getListLength (argumentValue );
154180 } else {
155181 valLen = getObjectLen (argumentValue );
@@ -165,7 +191,6 @@ private int getObjectLen(Object value) {
165191 return map .size ();
166192 }
167193
168-
169194 private int getListLength (Object value ) {
170195 if (value instanceof Collection ) {
171196 return ((Collection ) value ).size ();
@@ -181,4 +206,8 @@ private int getListLength(Object value) {
181206 return 0 ;
182207 }
183208
209+ private <T > T assertExpectedArgType (String argName , String typeName ) {
210+ return Assert .assertShouldNeverHappen ("A validation directive MUST have a '%s' argument of type '%s' with a default value" , argName , typeName );
211+ }
212+
184213}
0 commit comments