|
1 # -*- coding: utf-8 -*- |
|
2 |
|
3 # Copyright (c) 2021 - 2022 Detlev Offenbach <detlev@die-offenbachs.de> |
|
4 # |
|
5 |
|
6 """ |
|
7 Module implementing a checker for import statements. |
|
8 """ |
|
9 |
|
10 import ast |
|
11 import copy |
|
12 import sys |
|
13 |
|
14 |
|
15 class ImportsChecker: |
|
16 """ |
|
17 Class implementing a checker for import statements. |
|
18 """ |
|
19 Codes = [ |
|
20 ## Local imports |
|
21 "I101", "I102", "I103", |
|
22 |
|
23 ## Imports order |
|
24 "I201", "I202", "I203", "I204", |
|
25 |
|
26 ## Various other import related |
|
27 "I901", "I902", "I903", "I904", |
|
28 ] |
|
29 |
|
30 def __init__(self, source, filename, tree, select, ignore, expected, |
|
31 repeat, args): |
|
32 """ |
|
33 Constructor |
|
34 |
|
35 @param source source code to be checked |
|
36 @type list of str |
|
37 @param filename name of the source file |
|
38 @type str |
|
39 @param tree AST tree of the source code |
|
40 @type ast.Module |
|
41 @param select list of selected codes |
|
42 @type list of str |
|
43 @param ignore list of codes to be ignored |
|
44 @type list of str |
|
45 @param expected list of expected codes |
|
46 @type list of str |
|
47 @param repeat flag indicating to report each occurrence of a code |
|
48 @type bool |
|
49 @param args dictionary of arguments for the various checks |
|
50 @type dict |
|
51 """ |
|
52 self.__select = tuple(select) |
|
53 self.__ignore = ("",) if select else tuple(ignore) |
|
54 self.__expected = expected[:] |
|
55 self.__repeat = repeat |
|
56 self.__filename = filename |
|
57 self.__source = source[:] |
|
58 self.__tree = copy.deepcopy(tree) |
|
59 self.__args = args |
|
60 |
|
61 # statistics counters |
|
62 self.counters = {} |
|
63 |
|
64 # collection of detected errors |
|
65 self.errors = [] |
|
66 |
|
67 checkersWithCodes = [ |
|
68 (self.__checkLocalImports, ("I101", "I102", "I103")), |
|
69 (self.__checkImportOrder, ("I201", "I202", "I203", "I204")), |
|
70 (self.__tidyImports, ("I901", "I902", "I903", "I904")), |
|
71 ] |
|
72 |
|
73 self.__checkers = [] |
|
74 for checker, codes in checkersWithCodes: |
|
75 if any(not (code and self.__ignoreCode(code)) |
|
76 for code in codes): |
|
77 self.__checkers.append(checker) |
|
78 |
|
79 def __ignoreCode(self, code): |
|
80 """ |
|
81 Private method to check if the message code should be ignored. |
|
82 |
|
83 @param code message code to check for |
|
84 @type str |
|
85 @return flag indicating to ignore the given code |
|
86 @rtype bool |
|
87 """ |
|
88 return (code.startswith(self.__ignore) and |
|
89 not code.startswith(self.__select)) |
|
90 |
|
91 def __error(self, lineNumber, offset, code, *args): |
|
92 """ |
|
93 Private method to record an issue. |
|
94 |
|
95 @param lineNumber line number of the issue |
|
96 @type int |
|
97 @param offset position within line of the issue |
|
98 @type int |
|
99 @param code message code |
|
100 @type str |
|
101 @param args arguments for the message |
|
102 @type list |
|
103 """ |
|
104 if self.__ignoreCode(code): |
|
105 return |
|
106 |
|
107 if code in self.counters: |
|
108 self.counters[code] += 1 |
|
109 else: |
|
110 self.counters[code] = 1 |
|
111 |
|
112 # Don't care about expected codes |
|
113 if code in self.__expected: |
|
114 return |
|
115 |
|
116 if code and (self.counters[code] == 1 or self.__repeat): |
|
117 # record the issue with one based line number |
|
118 self.errors.append( |
|
119 { |
|
120 "file": self.__filename, |
|
121 "line": lineNumber + 1, |
|
122 "offset": offset, |
|
123 "code": code, |
|
124 "args": args, |
|
125 } |
|
126 ) |
|
127 |
|
128 def run(self): |
|
129 """ |
|
130 Public method to check the given source against miscellaneous |
|
131 conditions. |
|
132 """ |
|
133 if not self.__filename: |
|
134 # don't do anything, if essential data is missing |
|
135 return |
|
136 |
|
137 if not self.__checkers: |
|
138 # don't do anything, if no codes were selected |
|
139 return |
|
140 |
|
141 for check in self.__checkers: |
|
142 check() |
|
143 |
|
144 def getStandardModules(self): |
|
145 """ |
|
146 Public method to get a list of modules of the standard library. |
|
147 |
|
148 @return set of builtin modules |
|
149 @rtype set of str |
|
150 """ |
|
151 try: |
|
152 return sys.stdlib_module_names |
|
153 except AttributeError: |
|
154 return { |
|
155 "__future__", "__main__", "_dummy_thread", "_thread", "abc", |
|
156 "aifc", "argparse", "array", "ast", "asynchat", "asyncio", |
|
157 "asyncore", "atexit", "audioop", "base64", "bdb", "binascii", |
|
158 "binhex", "bisect", "builtins", "bz2", "calendar", "cgi", |
|
159 "cgitb", "chunk", "cmath", "cmd", "code", "codecs", "codeop", |
|
160 "collections", "colorsys", "compileall", "concurrent", |
|
161 "configparser", "contextlib", "contextvars", "copy", "copyreg", |
|
162 "cProfile", "crypt", "csv", "ctypes", "curses", "dataclasses", |
|
163 "datetime", "dbm", "decimal", "difflib", "dis", "distutils", |
|
164 "doctest", "dummy_threading", "email", "encodings", |
|
165 "ensurepip", "enum", "errno", "faulthandler", "fcntl", |
|
166 "filecmp", "fileinput", "fnmatch", "formatter", "fractions", |
|
167 "ftplib", "functools", "gc", "getopt", "getpass", "gettext", |
|
168 "glob", "grp", "gzip", "hashlib", "heapq", "hmac", "html", |
|
169 "http", "imaplib", "imghdr", "imp", "importlib", "inspect", |
|
170 "io", "ipaddress", "itertools", "json", "keyword", "lib2to3", |
|
171 "linecache", "locale", "logging", "lzma", "mailbox", "mailcap", |
|
172 "marshal", "math", "mimetypes", "mmap", "modulefinder", |
|
173 "msilib", "msvcrt", "multiprocessing", "netrc", "nis", |
|
174 "nntplib", "numbers", "operator", "optparse", "os", |
|
175 "ossaudiodev", "parser", "pathlib", "pdb", "pickle", |
|
176 "pickletools", "pipes", "pkgutil", "platform", "plistlib", |
|
177 "poplib", "posix", "pprint", "profile", "pstats", "pty", "pwd", |
|
178 "py_compile", "pyclbr", "pydoc", "queue", "quopri", "random", |
|
179 "re", "readline", "reprlib", "resource", "rlcompleter", |
|
180 "runpy", "sched", "secrets", "select", "selectors", "shelve", |
|
181 "shlex", "shutil", "signal", "site", "smtpd", "smtplib", |
|
182 "sndhdr", "socket", "socketserver", "spwd", "sqlite3", "ssl", |
|
183 "stat", "statistics", "string", "stringprep", "struct", |
|
184 "subprocess", "sunau", "symbol", "symtable", "sys", |
|
185 "sysconfig", "syslog", "tabnanny", "tarfile", "telnetlib", |
|
186 "tempfile", "termios", "test", "textwrap", "threading", "time", |
|
187 "timeit", "tkinter", "token", "tokenize", "trace", "traceback", |
|
188 "tracemalloc", "tty", "turtle", "turtledemo", "types", |
|
189 "typing", "unicodedata", "unittest", "urllib", "uu", "uuid", |
|
190 "venv", "warnings", "wave", "weakref", "webbrowser", "winreg", |
|
191 "winsound", "wsgiref", "xdrlib", "xml", "xmlrpc", "zipapp", |
|
192 "zipfile", "zipimport", "zlib", "zoneinfo", |
|
193 } |
|
194 |
|
195 ####################################################################### |
|
196 ## Local imports |
|
197 ## |
|
198 ## adapted from: flake8-local-import v1.0.6 |
|
199 ####################################################################### |
|
200 |
|
201 def __checkLocalImports(self): |
|
202 """ |
|
203 Private method to check local imports. |
|
204 """ |
|
205 from .LocalImportVisitor import LocalImportVisitor |
|
206 |
|
207 visitor = LocalImportVisitor(self.__args, self) |
|
208 visitor.visit(copy.deepcopy(self.__tree)) |
|
209 for violation in visitor.violations: |
|
210 if not self.__ignoreCode(violation[1]): |
|
211 node = violation[0] |
|
212 reason = violation[1] |
|
213 self.__error(node.lineno - 1, node.col_offset, reason) |
|
214 |
|
215 ####################################################################### |
|
216 ## Import order |
|
217 ## |
|
218 ## adapted from: flake8-alphabetize v0.0.17 |
|
219 ####################################################################### |
|
220 |
|
221 def __checkImportOrder(self): |
|
222 """ |
|
223 Private method to check the order of import statements. |
|
224 """ |
|
225 from .ImportNode import ImportNode |
|
226 |
|
227 errors = [] |
|
228 imports = [] |
|
229 importNodes, listNode = self.__findNodes(self.__tree) |
|
230 |
|
231 # check for an error in '__all__' |
|
232 allError = self.__findErrorInAll(listNode) |
|
233 if allError is not None: |
|
234 errors.append(allError) |
|
235 |
|
236 for importNode in importNodes: |
|
237 if ( |
|
238 isinstance(importNode, ast.Import) and |
|
239 len(importNode.names) > 1 |
|
240 ): |
|
241 # skip suck imports because its already handled by pycodestyle |
|
242 continue |
|
243 |
|
244 imports.append(ImportNode( |
|
245 self.__args.get("ApplicationPackageNames", []), |
|
246 importNode, self)) |
|
247 |
|
248 lenImports = len(imports) |
|
249 if lenImports > 0: |
|
250 p = imports[0] |
|
251 if p.error is not None: |
|
252 errors.append(p.error) |
|
253 |
|
254 if lenImports > 1: |
|
255 for n in imports[1:]: |
|
256 if n.error is not None: |
|
257 errors.append(n.error) |
|
258 |
|
259 if n == p: |
|
260 errors.append((n.node, "I203", str(p), str(n))) |
|
261 elif n < p: |
|
262 errors.append((n.node, "I201", str(n), str(p))) |
|
263 |
|
264 p = n |
|
265 |
|
266 for error in errors: |
|
267 if not self.__ignoreCode(error[1]): |
|
268 node = error[0] |
|
269 reason = error[1] |
|
270 args = error[2:] |
|
271 self.__error(node.lineno - 1, node.col_offset, reason, *args) |
|
272 |
|
273 def __findNodes(self, tree): |
|
274 """ |
|
275 Private method to find all import and import from nodes of the given |
|
276 tree. |
|
277 |
|
278 @param tree reference to the ast node tree to be parsed |
|
279 @type ast.AST |
|
280 @return tuple containing a list of import nodes and the '__all__' node |
|
281 @rtype tuple of (ast.Import | ast.ImportFrom, ast.List | ast.Tuple) |
|
282 """ |
|
283 importNodes = [] |
|
284 listNode = None |
|
285 |
|
286 if isinstance(tree, ast.Module): |
|
287 body = tree.body |
|
288 |
|
289 for n in body: |
|
290 if isinstance(n, (ast.Import, ast.ImportFrom)): |
|
291 importNodes.append(n) |
|
292 |
|
293 elif isinstance(n, ast.Assign): |
|
294 for t in n.targets: |
|
295 if isinstance(t, ast.Name) and t.id == "__all__": |
|
296 value = n.value |
|
297 |
|
298 if isinstance(value, (ast.List, ast.Tuple)): |
|
299 listNode = value |
|
300 |
|
301 return importNodes, listNode |
|
302 |
|
303 def __findErrorInAll(self, node): |
|
304 """ |
|
305 Private method to check the '__all__' node for errors. |
|
306 |
|
307 @param node reference to the '__all__' node |
|
308 @type ast.List or ast.Tuple |
|
309 @return tuple containing a reference to the node and an error code |
|
310 @rtype rtype tuple of (ast.List | ast.Tuple, str) |
|
311 """ |
|
312 if node is not None: |
|
313 actualList = [] |
|
314 for el in node.elts: |
|
315 if isinstance(el, ast.Constant): |
|
316 actualList.append(el.value) |
|
317 elif isinstance(el, ast.Str): |
|
318 actualList.append(el.s) |
|
319 else: |
|
320 # Can't handle anything that isn't a string literal |
|
321 return None |
|
322 |
|
323 expectedList = sorted(actualList) |
|
324 if expectedList != actualList: |
|
325 return (node, "I204", ", ".join(expectedList)) |
|
326 |
|
327 return None |
|
328 |
|
329 ####################################################################### |
|
330 ## Tidy imports |
|
331 ## |
|
332 ## adapted from: flake8-tidy-imports v4.5.0 |
|
333 ####################################################################### |
|
334 |
|
335 def __tidyImports(self): |
|
336 """ |
|
337 Private method to check various other import related topics. |
|
338 """ |
|
339 self.__bannedModules = self.__args.get("BannedModules", []) |
|
340 self.__banRelativeImports = self.__args.get("BanRelativeImports", "") |
|
341 |
|
342 ruleMethods = [] |
|
343 if not self.__ignoreCode("I901"): |
|
344 ruleMethods.append(self.__checkUnnecessaryAlias) |
|
345 if ( |
|
346 not self.__ignoreCode("I902") and |
|
347 bool(self.__bannedModules) |
|
348 ): |
|
349 ruleMethods.append(self.__checkBannedImport) |
|
350 if ( |
|
351 (not self.__ignoreCode("I903") and |
|
352 self.__banRelativeImports == "parents") or |
|
353 (not self.__ignoreCode("I904") and |
|
354 self.__banRelativeImports == "true") |
|
355 ): |
|
356 ruleMethods.append(self.__checkBannedRelativeImports) |
|
357 |
|
358 for node in ast.walk(self.__tree): |
|
359 for method in ruleMethods: |
|
360 method(node) |
|
361 |
|
362 def __checkUnnecessaryAlias(self, node): |
|
363 """ |
|
364 Private method to check unnecessary import aliases. |
|
365 |
|
366 @param node reference to the node to be checked |
|
367 @type ast.AST |
|
368 """ |
|
369 if isinstance(node, ast.Import): |
|
370 for alias in node.names: |
|
371 if "." not in alias.name: |
|
372 fromName = None |
|
373 importedName = alias.name |
|
374 else: |
|
375 fromName, importedName = alias.name.rsplit(".", 1) |
|
376 |
|
377 if importedName == alias.asname: |
|
378 if fromName: |
|
379 rewritten = "from {0} import {1}".format( |
|
380 fromName, importedName) |
|
381 else: |
|
382 rewritten = "import {0}".format(importedName) |
|
383 |
|
384 self.__error(node.lineno - 1, node.col_offset, "I901", |
|
385 rewritten) |
|
386 |
|
387 elif isinstance(node, ast.ImportFrom): |
|
388 for alias in node.names: |
|
389 if alias.name == alias.asname: |
|
390 rewritten = "from {0} import {1}".format( |
|
391 node.module, alias.name) |
|
392 |
|
393 self.__error(node.lineno - 1, node.col_offset, "I901", |
|
394 rewritten) |
|
395 |
|
396 def __checkBannedImport(self, node): |
|
397 """ |
|
398 Private method to check import of banned modules. |
|
399 |
|
400 @param node reference to the node to be checked |
|
401 @type ast.AST |
|
402 """ |
|
403 if not bool(self.__bannedModules): |
|
404 return |
|
405 |
|
406 if isinstance(node, ast.Import): |
|
407 moduleNames = [alias.name for alias in node.names] |
|
408 elif isinstance(node, ast.ImportFrom): |
|
409 nodeModule = node.module or "" |
|
410 moduleNames = [nodeModule] |
|
411 for alias in node.names: |
|
412 moduleNames.append("{0}.{1}".format(nodeModule, alias.name)) |
|
413 else: |
|
414 return |
|
415 |
|
416 # Sort from most to least specific paths. |
|
417 moduleNames.sort(key=len, reverse=True) |
|
418 |
|
419 warned = set() |
|
420 |
|
421 for moduleName in moduleNames: |
|
422 if moduleName in self.__bannedModules: |
|
423 if any(mod.startswith(moduleName) for mod in warned): |
|
424 # Do not show an error for this line if we already showed |
|
425 # a more specific error. |
|
426 continue |
|
427 else: |
|
428 warned.add(moduleName) |
|
429 self.__error(node.lineno - 1, node.col_offset, "I902", |
|
430 moduleName) |
|
431 |
|
432 def __checkBannedRelativeImports(self, node): |
|
433 """ |
|
434 Private method to check if relative imports are banned. |
|
435 |
|
436 @param node reference to the node to be checked |
|
437 @type ast.AST |
|
438 """ |
|
439 if not self.__banRelativeImports: |
|
440 return |
|
441 |
|
442 elif self.__banRelativeImports == "parents": |
|
443 minNodeLevel = 1 |
|
444 msgCode = "I903" |
|
445 else: |
|
446 minNodeLevel = 0 |
|
447 msgCode = "I904" |
|
448 |
|
449 if ( |
|
450 self.__banRelativeImports and |
|
451 isinstance(node, ast.ImportFrom) and |
|
452 node.level > minNodeLevel |
|
453 ): |
|
454 self.__error(node.lineno - 1, node.col_offset, msgCode) |