|
1 # -*- coding: utf-8 -*- |
|
2 |
|
3 # Copyright (c) 2021 - 2022 Detlev Offenbach <detlev@die-offenbachs.de> |
|
4 # |
|
5 |
|
6 """ |
|
7 Module implementing a node visitor for checking local import statements. |
|
8 """ |
|
9 |
|
10 import ast |
|
11 |
|
12 # |
|
13 # The visitor is adapted from flake8-local-import v1.0.6 |
|
14 # |
|
15 |
|
16 |
|
class LocalImportVisitor(ast.NodeVisitor):
    """
    Class implementing a node visitor for checking local import statements.

    While traversing the tree each node gets annotated with these helper
    attributes:
    <ul>
    <li>parent - the node owning this node</li>
    <li>previous - the preceding sibling (None for the first child)</li>
    <li>is_local - flag indicating a position somewhere inside a (regular
        or asynchronous) function definition</li>
    <li>is_doc_str - set on the first statement of a function, if it is the
        docstring of the function</li>
    </ul>

    Detected issues are collected in the 'violations' attribute as tuples
    of the offending node and the message code ("I101", "I102" or "I103").
    """

    # Node types opening a local scope; asynchronous functions have to be
    # treated exactly like regular ones.
    _FunctionTypes = (ast.FunctionDef, ast.AsyncFunctionDef)

    def __init__(self, args, checker):
        """
        Constructor

        @param args dictionary containing the checker arguments
        @type dict
        @param checker reference to the checker
        @type ImportsChecker
        """
        self.__appImportNames = args.get("ApplicationPackageNames", [])
        self.__checker = checker

        self.violations = []

    def visit(self, node):
        """
        Public method to traverse the tree of an AST node.

        The direct children are annotated with their parent, their
        preceding sibling and the 'is_local' flag before the node specific
        visitor method is called.

        @param node AST node to parse
        @type ast.AST
        @return result of the node specific visitor method
        @rtype Any
        """
        previous = None
        # the flag is inherited, so that nested statements stay 'local'
        isLocal = (
            isinstance(node, self._FunctionTypes) or
            getattr(node, 'is_local', False)
        )
        for child in ast.iter_child_nodes(node):
            child.parent = node
            child.previous = previous
            child.is_local = isLocal
            previous = child

        return super().visit(node)

    def visit_FunctionDef(self, node):
        """
        Public method to handle a function definition.

        The first statement of the function body is marked as docstring, if
        it is an expression statement consisting of a string constant only.

        @param node reference to the node to be processed
        @type ast.FunctionDef or ast.AsyncFunctionDef
        """
        body = getattr(node, 'body', None)
        if body:
            firstStatement = body[0]

            if isinstance(firstStatement, ast.Expr):
                value = getattr(firstStatement, 'value', None)
                # only string constants are docstrings; other constants
                # (numbers, bytes, Ellipsis) are ordinary statements
                if (
                    isinstance(value, ast.Constant) and
                    isinstance(value.value, str)
                ):
                    firstStatement.is_doc_str = True

        self.generic_visit(node)

    def visit_AsyncFunctionDef(self, node):
        """
        Public method to handle an asynchronous function definition.

        @param node reference to the node to be processed
        @type ast.AsyncFunctionDef
        """
        self.visit_FunctionDef(node)

    def visit_Import(self, node):
        """
        Public method to handle an import statement.

        @param node reference to the node to be processed
        @type ast.Import
        """
        if not getattr(node, 'is_local', False):
            self.generic_visit(node)
            return

        # each imported name is checked (and reported) individually
        for name in node.names:
            self.__assertExternalModule(node, name.name or '')

        self.__visitImportNode(node)

    def visit_ImportFrom(self, node):
        """
        Public method to handle an import from statement.

        @param node reference to the node to be processed
        @type ast.ImportFrom
        """
        if not getattr(node, 'is_local', False):
            self.generic_visit(node)
            return

        # 'module' is None for statements like 'from . import name'
        self.__assertExternalModule(node, node.module or '')

        self.__visitImportNode(node)

    def __visitImportNode(self, node):
        """
        Private method to handle an import or import from statement.

        An "I101" violation is recorded, if the statement is not a direct
        child of a function definition or if it is preceded by anything
        else than the function arguments, the function docstring or another
        import statement.

        @param node reference to the node to be processed
        @type ast.Import or ast.ImportFrom
        """
        parent = getattr(node, 'parent', None)
        if isinstance(parent, ast.Module):
            self.generic_visit(node)
            return

        previous = getattr(node, 'previous', None)

        # The first body statement of a function has the 'arguments' node
        # as its preceding sibling.
        isAllowedPrevious = (
            (isinstance(previous, ast.Expr) and
             getattr(previous, 'is_doc_str', False)) or
            isinstance(previous, (ast.Import, ast.ImportFrom, ast.arguments))
        )

        if (
            not isinstance(parent, self._FunctionTypes) or
            not isAllowedPrevious
        ):
            self.violations.append((node, "I101"))

        self.generic_visit(node)

    def __assertExternalModule(self, node, module):
        """
        Private method to assert the given node.

        Relative imports and imports of the configured application packages
        are accepted. For all other modules an "I102" violation is recorded,
        if the top level package is not contained in the standard modules
        reported by the checker, and an "I103" violation otherwise.

        @param node reference to the node to be processed
        @type ast.stmt
        @param module name of the module
        @type str
        """
        parent = getattr(node, 'parent', None)
        if isinstance(parent, ast.Module):
            return

        # The trailing dot ensures that only complete package names match
        # (e.g. 'app' must not match 'application').
        modulePrefix = module + '.'

        if (
            getattr(node, 'level', 0) != 0 or
            any(modulePrefix.startswith(appModule + '.')
                for appModule in self.__appImportNames)
        ):
            return

        if module.split('.')[0] not in self.__checker.getStandardModules():
            self.violations.append((node, "I102"))
        else:
            self.violations.append((node, "I103"))