|
1 # -*- coding: utf-8 -*- |
|
2 |
|
3 # Copyright (c) 2019 - 2020 Detlev Offenbach <detlev@die-offenbachs.de> |
|
4 # |
|
5 |
|
6 """ |
|
7 Module implementing a checker for function type annotations. |
|
8 """ |
|
9 |
|
10 import sys |
|
11 import ast |
|
12 |
|
13 import AstUtilities |
|
14 |
|
15 |
|
class AnnotationsChecker(object):
    """
    Class implementing a checker for function type annotations.
    """
    Codes = [
        ## Function Annotations
        "A001", "A002", "A003",

        ## Method Annotations
        "A101", "A102",

        ## Return Annotations
        "A201", "A202", "A203", "A204", "A205", "A206",

        ## Annotation Coverage
        "A881",

        ## Annotation Complexity
        "A891",

        ## Syntax Error
        "A999",
    ]

    def __init__(self, source, filename, select, ignore, expected, repeat,
                 args):
        """
        Constructor

        @param source source code to be checked
        @type list of str
        @param filename name of the source file
        @type str
        @param select list of selected codes
        @type list of str
        @param ignore list of codes to be ignored
        @type list of str
        @param expected list of expected codes
        @type list of str
        @param repeat flag indicating to report each occurrence of a code
        @type bool
        @param args dictionary of arguments for the annotation checks
        @type dict
        """
        self.__select = tuple(select)
        # When codes are selected explicitly, every code is ignored ('' is a
        # prefix of all codes) unless it matches the selection, see
        # __ignoreCode().
        self.__ignore = ('',) if select else tuple(ignore)
        self.__expected = expected[:]
        self.__repeat = repeat
        self.__filename = filename
        self.__source = source[:]
        self.__args = args

        # AST of the source; created by run()
        self.__tree = None

        # statistics counters
        self.counters = {}

        # collection of detected errors
        self.errors = []

        checkersWithCodes = [
            (
                self.__checkFunctionAnnotations,
                ("A001", "A002", "A003", "A101", "A102",
                 "A201", "A202", "A203", "A204", "A205", "A206",)
            ),
            (self.__checkAnnotationsCoverage, ("A881",)),
            (self.__checkAnnotationComplexity, ("A891",)),
        ]

        self.__defaultArgs = {
            "MinimumCoverage": 75,      # % of type annotation coverage
            "MaximumComplexity": 3,
        }

        # keep only those checkers able to report at least one code that
        # is not ignored
        self.__checkers = []
        for checker, codes in checkersWithCodes:
            if any(not (code and self.__ignoreCode(code))
                    for code in codes):
                self.__checkers.append(checker)

    def __ignoreCode(self, code):
        """
        Private method to check if the message code should be ignored.

        @param code message code to check for
        @type str
        @return flag indicating to ignore the given code
        @rtype bool
        """
        return (code.startswith(self.__ignore) and
                not code.startswith(self.__select))

    def __error(self, lineNumber, offset, code, *args):
        """
        Private method to record an issue.

        @param lineNumber zero based line number of the issue
        @type int
        @param offset position within line of the issue
        @type int
        @param code message code
        @type str
        @param args arguments for the message
        @type list
        """
        if self.__ignoreCode(code):
            return

        if code in self.counters:
            self.counters[code] += 1
        else:
            self.counters[code] = 1

        # Don't care about expected codes
        if code in self.__expected:
            return

        if code and (self.counters[code] == 1 or self.__repeat):
            # record the issue with one based line number
            self.errors.append(
                {
                    "file": self.__filename,
                    "line": lineNumber + 1,
                    "offset": offset,
                    "code": code,
                    "args": args,
                }
            )

    def __reportInvalidSyntax(self):
        """
        Private method to report a syntax error.

        It must be called from within an exception handler, because the
        exception is retrieved via sys.exc_info().
        """
        excType, exc = sys.exc_info()[:2]
        lineno, column = 1, 0
        if len(exc.args) > 1 and isinstance(exc.args[1], (tuple, list)):
            location = exc.args[1]
            if len(location) > 2:
                # SyntaxError details: (filename, lineno, offset, text, ...)
                location = location[1:3]
            if len(location) == 2:
                # line number or offset may be None for some errors
                lineno = location[0] or 1
                column = location[1] or 0
        message = exc.args[0] if exc.args else ""
        self.__error(lineno - 1, column, 'A999', excType.__name__, message)

    def __generateTree(self):
        """
        Private method to generate an AST for our source.

        @return generated AST
        @rtype ast.Module
        """
        source = "".join(self.__source)
        return compile(source, self.__filename, 'exec', ast.PyCF_ONLY_AST)

    def run(self):
        """
        Public method to check the given source against annotation issues.
        """
        if not self.__filename:
            # don't do anything, if essential data is missing
            return

        if not self.__checkers:
            # don't do anything, if no codes were selected
            return

        try:
            self.__tree = self.__generateTree()
        except (SyntaxError, TypeError, ValueError):
            # ValueError: compile() rejects source containing null bytes
            # on Python versions before 3.12
            self.__reportInvalidSyntax()
            return

        for check in self.__checkers:
            check()

    def __checkFunctionAnnotations(self):
        """
        Private method to check for function annotation issues.
        """
        visitor = FunctionVisitor(self.__source)
        visitor.visit(self.__tree)
        for issue in visitor.issues:
            node = issue[0]
            reason = issue[1]
            params = issue[2:]
            self.__error(node.lineno - 1, node.col_offset, reason, *params)

    def __checkAnnotationsCoverage(self):
        """
        Private method to check for function annotation coverage.

        The coverage is the percentage of functions/methods carrying at
        least one type annotation.
        """
        minAnnotationsCoverage = self.__args.get(
            "MinimumCoverage", self.__defaultArgs["MinimumCoverage"])
        if minAnnotationsCoverage == 0:
            # 0 means it is switched off
            return

        functionDefs = [
            f for f in ast.walk(self.__tree)
            if isinstance(f, (ast.AsyncFunctionDef, ast.FunctionDef))
        ]
        if not functionDefs:
            # no functions/methods at all
            return

        functionDefAnnotationsInfo = [
            hasTypeAnnotations(f) for f in functionDefs
        ]
        annotationsCoverage = int(
            len(list(filter(None, functionDefAnnotationsInfo))) /
            len(functionDefAnnotationsInfo) * 100
        )
        if annotationsCoverage < minAnnotationsCoverage:
            self.__error(0, 0, "A881", annotationsCoverage)

    def __checkAnnotationComplexity(self):
        """
        Private method to check the type annotation complexity.

        All argument annotations (positional-only, regular, *args,
        keyword-only and **kwargs), return annotations and annotated
        assignments are checked.
        """
        maxAnnotationComplexity = self.__args.get(
            "MaximumComplexity", self.__defaultArgs["MaximumComplexity"])
        typeAnnotations = []

        functionDefs = [
            f for f in ast.walk(self.__tree)
            if isinstance(f, (ast.AsyncFunctionDef, ast.FunctionDef))
        ]
        for functionDef in functionDefs:
            arguments = functionDef.args
            # 'posonlyargs' exists on Python 3.8 and newer only;
            # vararg/kwarg are None if not present
            argNodes = (
                list(getattr(arguments, "posonlyargs", [])) +
                list(arguments.args) +
                [arguments.vararg] +
                list(arguments.kwonlyargs) +
                [arguments.kwarg]
            )
            typeAnnotations += [
                argNode.annotation for argNode in argNodes
                if argNode is not None and argNode.annotation
            ]
            if functionDef.returns:
                typeAnnotations.append(functionDef.returns)
        typeAnnotations += [a.annotation for a in ast.walk(self.__tree)
                            if isinstance(a, ast.AnnAssign) and a.annotation]
        for annotation in typeAnnotations:
            complexity = getAnnotationComplexity(annotation)
            if complexity > maxAnnotationComplexity:
                self.__error(annotation.lineno - 1, annotation.col_offset,
                             "A891", complexity, maxAnnotationComplexity)
|
253 |
|
254 |
|
class FunctionVisitor(ast.NodeVisitor):
    """
    Class implementing a node visitor to check function annotations.

    Note: this class is modelled after flake8-annotations checker.
    """
    def __init__(self, sourceLines):
        """
        Constructor

        @param sourceLines lines of source code
        @type list of str
        """
        super(FunctionVisitor, self).__init__()

        self.__sourceLines = sourceLines

        # detected issues as tuples of (node, code, *message arguments)
        self.issues = []

    def visit_FunctionDef(self, node):
        """
        Public method to handle a function or method definition.

        @param node reference to the node to be processed
        @type ast.FunctionDef
        """
        self.__checkFunctionNode(node)
        self.generic_visit(node)

    def visit_AsyncFunctionDef(self, node):
        """
        Public method to handle an async function or method definition.

        @param node reference to the node to be processed
        @type ast.AsyncFunctionDef
        """
        self.__checkFunctionNode(node)
        self.generic_visit(node)

    def visit_ClassDef(self, node):
        """
        Public method to handle class definitions.

        Functions defined directly in the class body are checked as
        methods. Their bodies and all other class body statements are
        visited as well, so nested functions and classes get checked too.

        @param node reference to the node to be processed
        @type ast.ClassDef
        """
        for childNode in node.body:
            if isinstance(childNode, (ast.FunctionDef, ast.AsyncFunctionDef)):
                self.__checkFunctionNode(childNode, classMethod=True)
                # visit only the children of the method; visiting the method
                # node itself would check it a second time as a function
                self.generic_visit(childNode)
            else:
                self.visit(childNode)

    def __checkFunctionNode(self, node, classMethod=False):
        """
        Private method to check an individual function definition node.

        @param node reference to the node to be processed
        @type ast.FunctionDef or ast.AsyncFunctionDef
        @param classMethod flag indicating a class method
        @type bool
        """
        if node.name.startswith("__") and node.name.endswith("__"):
            visibilityType = "special"
        elif node.name.startswith("__"):
            visibilityType = "private"
        elif node.name.startswith("_"):
            visibilityType = "protected"
        else:
            visibilityType = "public"

        if classMethod:
            decorators = [
                decorator.id for decorator in node.decorator_list
                if isinstance(decorator, ast.Name)
            ]
            if "classmethod" in decorators:
                # must match the type name tested by the classify methods
                classMethodType = "classmethod"
            elif "staticmethod" in decorators:
                classMethodType = "staticmethod"
            else:
                classMethodType = ""
        else:
            classMethodType = "function"

        # check argument annotations; 'posonlyargs' exists on Python 3.8+
        # only, 'vararg' and 'kwarg' are single nodes or None
        argTypes = ("posonlyargs", "args", "vararg", "kwonlyargs", "kwarg")
        for argType in argTypes:
            args = getattr(node.args, argType, None)
            if not args:
                continue
            if not isinstance(args, list):
                args = [args]

            for arg in args:
                if not arg.annotation:
                    self.__classifyArgumentError(
                        arg, argType, classMethodType)

        # check function return annotation
        if not node.returns:
            lineno = node.lineno
            # report at the position after the last colon of the 'def' line
            colOffset = self.__sourceLines[lineno - 1].rfind(":") + 1
            self.__classifyReturnError(classMethodType, visibilityType,
                                       lineno, colOffset)

    def __classifyReturnError(self, methodType, visibilityType, lineno,
                              colOffset):
        """
        Private method to classify and record a return annotation issue.

        @param methodType type of method/function the argument belongs to
        @type str
        @param visibilityType visibility of the function
        @type str
        @param lineno line number
        @type int
        @param colOffset column number
        @type int
        """
        # create a dummy AST node to report line and column
        node = ast.AST()
        node.lineno = lineno
        node.col_offset = colOffset

        # now classify the issue
        if methodType == "classmethod":
            self.issues.append((node, "A206"))
        elif methodType == "staticmethod":
            self.issues.append((node, "A205"))
        elif visibilityType == "special":
            self.issues.append((node, "A204"))
        elif visibilityType == "private":
            self.issues.append((node, "A203"))
        elif visibilityType == "protected":
            self.issues.append((node, "A202"))
        else:
            self.issues.append((node, "A201"))

    def __classifyArgumentError(self, argNode, argType, methodType):
        """
        Private method to classify and record an argument annotation issue.

        @param argNode reference to the argument node
        @type ast.arg
        @param argType type of the argument node (one of 'posonlyargs',
            'args', 'vararg', 'kwonlyargs' or 'kwarg')
        @type str
        @param methodType type of method/function the argument belongs to
        @type str
        """
        # check class method issues
        if methodType != "function":
            if argNode.arg in ("cls", "self"):
                if methodType == "classmethod":
                    self.issues.append((argNode, "A102"))
                    return
                elif methodType != "staticmethod":
                    self.issues.append((argNode, "A101"))
                    return

        # check all other arguments
        if argType == "kwarg":
            self.issues.append((argNode, "A003", argNode.arg))
        elif argType == "vararg":
            self.issues.append((argNode, "A002", argNode.arg))
        else:
            # posonlyargs, args and kwonlyargs
            self.issues.append((argNode, "A001", argNode.arg))
|
421 |
|
422 ###################################################################### |
|
423 ## some utility functions below |
|
424 ###################################################################### |
|
425 |
|
426 |
|
def hasTypeAnnotations(funcNode):
    """
    Function to check for type annotations.

    A function counts as annotated if its return value or at least one of
    its arguments (positional-only, regular, *args, keyword-only or
    **kwargs) carries a type annotation.

    @param funcNode reference to the function definition node to be checked
    @type ast.AsyncFunctionDef or ast.FunctionDef
    @return flag indicating the presence of type annotations
    @rtype bool
    """
    if funcNode.returns is not None:
        return True

    arguments = funcNode.args
    # 'posonlyargs' exists on Python 3.8 and newer only; vararg and kwarg
    # are None if the function does not define them
    argNodes = (
        list(getattr(arguments, "posonlyargs", [])) +
        list(arguments.args) +
        list(arguments.kwonlyargs) +
        [arguments.vararg, arguments.kwarg]
    )
    return any(
        argNode is not None and argNode.annotation is not None
        for argNode in argNodes
    )
|
447 |
|
448 |
|
def getAnnotationComplexity(annotationNode):
    """
    Function to determine the annotation complexity.

    The complexity is the nesting depth of subscripted types. For example,
    'int' has complexity 1 and 'List[Dict[str, int]]' has complexity 3.
    String annotations (forward references) are parsed and measured the
    same way. Strings that are not valid expressions count as 1.

    @param annotationNode reference to the node to determine the annotation
        complexity for
    @type ast.AST
    @return annotation complexity
    @rtype int
    """
    if AstUtilities.isString(annotationNode):
        # ast.Constant holds the text in 'value', the old ast.Str in 's'
        text = getattr(annotationNode, "value", None)
        if not isinstance(text, str):
            text = annotationNode.s
        try:
            annotationNode = ast.parse(text.strip(), mode="eval").body
        except (SyntaxError, ValueError):
            # not a parsable expression (e.g. free text or null bytes)
            return 1
    if isinstance(annotationNode, ast.Subscript):
        sliceNode = annotationNode.slice
        if sys.version_info < (3, 9) and isinstance(sliceNode, ast.Index):
            # before Python 3.9 the subscript expression is wrapped in an
            # ast.Index node; from 3.9 on 'slice' is the expression itself
            sliceNode = sliceNode.value
        return 1 + getAnnotationComplexity(sliceNode)
    if isinstance(annotationNode, ast.Tuple):
        # an empty tuple (e.g. 'Tuple[()]') counts as a simple element
        return max((getAnnotationComplexity(n) for n in annotationNode.elts),
                   default=1)
    return 1