|
1 # -*- coding: utf-8 -*- |
|
2 |
|
3 # Copyright (c) 2019 Detlev Offenbach <detlev@die-offenbachs.de> |
|
4 # |
|
5 |
|
6 """ |
|
7 Module implementing a checker for function type annotations. |
|
8 """ |
|
9 |
|
10 import sys |
|
11 import ast |
|
12 |
|
13 |
|
class AnnotationsChecker(object):
    """
    Class implementing a checker for function type annotations.

    The source is compiled to an AST and handed to the registered checker
    methods. Detected issues are collected in the 'errors' attribute as
    tuples of (filename, one based line number, offset, (code, arguments))
    and counted per message code in the 'counters' attribute.
    """
    Codes = [
        ## Function Annotations
        "A001", "A002", "A003",

        ## Method Annotations
        "A101", "A102",

        ## Return Annotations
        "A201", "A202", "A203", "A204", "A205", "A206",

        ## Syntax Error
        "A999",
    ]

    def __init__(self, source, filename, select, ignore, expected, repeat):
        """
        Constructor

        @param source source code to be checked
        @type list of str
        @param filename name of the source file
        @type str
        @param select list of selected codes
        @type list of str
        @param ignore list of codes to be ignored
        @type list of str
        @param expected list of expected codes
        @type list of str
        @param repeat flag indicating to report each occurrence of a code
        @type bool
        """
        self.__select = tuple(select)
        # a non-empty selection ignores everything that is not selected
        # (every code starts with the empty string)
        self.__ignore = ('',) if select else tuple(ignore)
        self.__expected = expected[:]
        self.__repeat = repeat
        self.__filename = filename
        self.__source = source[:]

        # statistics counters
        self.counters = {}

        # collection of detected errors
        self.errors = []

        checkersWithCodes = [
            (
                self.__checkFunctionAnnotations,
                ("A001", "A002", "A003", "A101", "A102",
                 "A201", "A202", "A203", "A204", "A205", "A206",)
            ),
        ]

        # only register checkers that can report at least one code, which
        # is not ignored
        self.__checkers = []
        for checker, codes in checkersWithCodes:
            if any(not (code and self.__ignoreCode(code))
                   for code in codes):
                self.__checkers.append(checker)

    def __ignoreCode(self, code):
        """
        Private method to check if the message code should be ignored.

        @param code message code to check for
        @type str
        @return flag indicating to ignore the given code
        @rtype bool
        """
        return (code.startswith(self.__ignore) and
                not code.startswith(self.__select))

    def __error(self, lineNumber, offset, code, *args):
        """
        Private method to record an issue.

        @param lineNumber line number of the issue (zero based)
        @type int
        @param offset position within line of the issue
        @type int
        @param code message code
        @type str
        @param args arguments for the message
        @type list
        """
        if self.__ignoreCode(code):
            return

        # the statistics include expected codes as well
        self.counters[code] = self.counters.get(code, 0) + 1

        # Don't care about expected codes
        if code in self.__expected:
            return

        if code and (self.counters[code] == 1 or self.__repeat):
            # record the issue with one based line number
            self.errors.append(
                (self.__filename, lineNumber + 1, offset, (code, args)))

    def __reportInvalidSyntax(self):
        """
        Private method to report a syntax error.

        It reads the exception currently being handled via sys.exc_info()
        and must therefore be called from within an exception handler.
        """
        exc_type, exc = sys.exc_info()[:2]
        if len(exc.args) > 1:
            # second argument holds the location details; elements 1 and 2
            # are used as line and column
            offset = exc.args[1]
            if len(offset) > 2:
                offset = offset[1:3]
        else:
            # no location information available; report the first line
            offset = (1, 0)
        self.__error(offset[0] - 1, offset[1] or 0,
                     'A999', exc_type.__name__, exc.args[0])

    def __generateTree(self):
        """
        Private method to generate an AST for our source.

        @return generated AST
        @rtype ast.Module
        """
        source = "".join(self.__source)
        # Check type for py2: if not str it's unicode
        if sys.version_info[0] == 2:
            try:
                source = source.encode('utf-8')
            except UnicodeError:
                pass

        return compile(source, self.__filename, 'exec', ast.PyCF_ONLY_AST)

    def run(self):
        """
        Public method to check the given source against annotation issues.

        Source that cannot be compiled is reported as a single A999 issue.
        """
        if not self.__filename:
            # don't do anything, if essential data is missing
            return

        if not self.__checkers:
            # don't do anything, if no codes were selected
            return

        try:
            self.__tree = self.__generateTree()
        except (SyntaxError, TypeError, ValueError):
            # compile() raises ValueError for source containing null bytes
            self.__reportInvalidSyntax()
            return

        for check in self.__checkers:
            check()

    def __checkFunctionAnnotations(self):
        """
        Private method to check for function annotation issues.
        """
        visitor = FunctionVisitor(self.__source)
        visitor.visit(self.__tree)
        for issue in visitor.issues:
            # each issue is (node, message code, optional arguments...)
            node = issue[0]
            reason = issue[1]
            params = issue[2:]
            self.__error(node.lineno - 1, node.col_offset, reason, *params)
|
181 |
|
182 |
|
class FunctionVisitor(ast.NodeVisitor):
    """
    Class implementing a node visitor to check function annotations.

    Detected issues are collected in the 'issues' attribute as tuples of
    (node, message code, optional message arguments...), where node
    carries the 'lineno' and 'col_offset' of the issue.

    Note: this class is modelled after flake8-annotations checker.
    """
    def __init__(self, sourceLines):
        """
        Constructor

        @param sourceLines lines of source code
        @type list of str
        """
        super(FunctionVisitor, self).__init__()

        self.__sourceLines = sourceLines

        self.issues = []

    def visit_FunctionDef(self, node):
        """
        Public method to handle a function or method definition.

        @param node reference to the node to be processed
        @type ast.FunctionDef
        """
        self.__checkFunctionNode(node)
        self.generic_visit(node)

    def visit_AsyncFunctionDef(self, node):
        """
        Public method to handle an async function or method definition.

        @param node reference to the node to be processed
        @type ast.AsyncFunctionDef
        """
        self.__checkFunctionNode(node)
        self.generic_visit(node)

    def visit_ClassDef(self, node):
        """
        Public method to handle class definitions.

        Only the functions defined directly in the class body are checked
        (as methods). This handler does not call generic_visit(), so the
        bodies of these methods and any nested classes are not visited.

        @param node reference to the node to be processed
        @type ast.ClassDef
        """
        methodNodes = [
            childNode for childNode in node.body
            if isinstance(childNode, (ast.FunctionDef, ast.AsyncFunctionDef))
        ]
        for methodNode in methodNodes:
            self.__checkFunctionNode(methodNode, classMethod=True)

    def __checkFunctionNode(self, node, classMethod=False):
        """
        Private method to check an individual function definition node.

        @param node reference to the node to be processed
        @type ast.FunctionDef or ast.AsyncFunctionDef
        @param classMethod flag indicating a class method
        @type bool
        """
        if node.name.startswith("__") and node.name.endswith("__"):
            visibilityType = "special"
        elif node.name.startswith("__"):
            visibilityType = "private"
        elif node.name.startswith("_"):
            visibilityType = "protected"
        else:
            visibilityType = "public"

        if classMethod:
            # only plain name decorators (e.g. @classmethod) are recognized
            decorators = [
                decorator.id for decorator in node.decorator_list
                if isinstance(decorator, ast.Name)
            ]
            if "classmethod" in decorators:
                # must match the value tested by the classify methods
                classMethodType = "classmethod"
            elif "staticmethod" in decorators:
                classMethodType = "staticmethod"
            else:
                classMethodType = ""
        else:
            classMethodType = "function"

        # check argument annotations
        argTypes = ("posonlyargs", "args", "vararg", "kwonlyargs", "kwarg")
        for argType in argTypes:
            # 'posonlyargs' does not exist on older Python versions
            args = getattr(node.args, argType, None)
            if args:
                # 'vararg' and 'kwarg' are single nodes, not lists
                if not isinstance(args, list):
                    args = [args]

                for arg in args:
                    if not arg.annotation:
                        self.__classifyArgumentError(
                            arg, argType, classMethodType)

        # check function return annotation
        if not node.returns:
            # the issue is located on the line of the first body statement,
            # just behind the first colon of that line (column 0, if that
            # line contains no colon)
            lineno = node.body[0].lineno
            colOffset = self.__sourceLines[lineno - 1].find(":") + 1
            self.__classifyReturnError(classMethodType, visibilityType,
                                       lineno, colOffset)

    def __classifyReturnError(self, methodType, visibilityType, lineno,
                              colOffset):
        """
        Private method to classify and record a return annotation issue.

        @param methodType type of method/function the argument belongs to
        @type str
        @param visibilityType visibility of the function
        @type str
        @param lineno line number
        @type int
        @param colOffset column number
        @type int
        """
        # create a dummy AST node to report line and column
        node = ast.AST()
        node.lineno = lineno
        node.col_offset = colOffset

        # now classify the issue; the method type takes precedence over
        # the visibility
        if methodType == "classmethod":
            self.issues.append((node, "A206"))
        elif methodType == "staticmethod":
            self.issues.append((node, "A205"))
        elif visibilityType == "special":
            self.issues.append((node, "A204"))
        elif visibilityType == "private":
            self.issues.append((node, "A203"))
        elif visibilityType == "protected":
            self.issues.append((node, "A202"))
        else:
            self.issues.append((node, "A201"))

    def __classifyArgumentError(self, argNode, argType, methodType):
        """
        Private method to classify and record an argument annotation issue.

        @param argNode reference to the argument node
        @type ast.arg
        @param argType type of the argument node
        @type str
        @param methodType type of method/function the argument belongs to
        @type str
        """
        # check class method issues
        if methodType != "function":
            if argNode.arg in ("cls", "self"):
                if methodType == "classmethod":
                    self.issues.append((argNode, "A102"))
                    return
                elif methodType != "staticmethod":
                    self.issues.append((argNode, "A101"))
                    return

        # check all other arguments
        if argType == "kwarg":
            self.issues.append((argNode, "A003", argNode.arg))
        elif argType == "vararg":
            self.issues.append((argNode, "A002", argNode.arg))
        else:
            # posonlyargs, args and kwonlyargs
            self.issues.append((argNode, "A001", argNode.arg))