|
1 # -*- coding: utf-8 -*- |
|
2 |
|
3 # Copyright (c) 2020 - 2022 Detlev Offenbach <detlev@die-offenbachs.de> |
|
4 # |
|
5 |
|
6 """ |
|
7 Module implementing an AST node visitor for security checks. |
|
8 """ |
|
9 |
|
10 import ast |
|
11 |
|
12 from . import SecurityUtils |
|
13 from .SecurityContext import SecurityContext |
|
14 |
|
15 |
|
class SecurityNodeVisitor:
    """
    Class implementing an AST node visitor for security checks.
    """
    # Node types whose visitor methods append the node name to
    # self.namespace; __postVisit() must remove it again for exactly
    # these types, otherwise the namespace leaks into later siblings.
    __NamespaceNodeTypes = (
        ast.ClassDef,
        ast.FunctionDef,
        ast.AsyncFunctionDef,
    )

    def __init__(self, checker, secCheckers, filename):
        """
        Constructor

        @param checker reference to the main security checker object
        @type SecurityChecker
        @param secCheckers dictionary containing the available checker routines
        @type dict
        @param filename name of the checked file
        @type str
        """
        self.__checker = checker
        self.__securityCheckers = secCheckers

        # number of nodes pre-visited so far
        self.seen = 0
        # current nesting depth of the traversal
        self.depth = 0
        self.filename = filename
        # names recorded by the import visitors
        self.imports = set()
        # maps a locally bound import name to its qualified name
        self.import_aliases = {}

        # context of the node currently being visited; it is rebuilt for
        # every node by __preVisit() but initialized here so the visitor
        # methods never see a missing attribute
        self.__context = {}

        # in some cases we can't determine a qualified name
        try:
            self.namespace = SecurityUtils.getModuleQualnameFromPath(filename)
        except SecurityUtils.InvalidModulePath:
            self.namespace = ""

    def __runChecks(self, checkType):
        """
        Private method to run all enabled checks for a given check type.

        Each registered check is called with the error reporting callback,
        a SecurityContext wrapping the current node context and the checker
        configuration.

        @param checkType type of checks to be run
        @type str
        """
        if checkType in self.__securityCheckers:
            for check in self.__securityCheckers[checkType]:
                check(self.__checker.reportError,
                      SecurityContext(self.__context),
                      self.__checker.getConfig())

    @staticmethod
    def __nodeValue(node):
        """
        Private method to get the literal value of a string or bytes node.

        Python 3.8+ produces ast.Constant nodes holding the literal in the
        'value' attribute. Older versions produce ast.Str/ast.Bytes nodes
        holding it in the 's' attribute. The 's' alias on ast.Constant is
        deprecated since Python 3.8 and removed in Python 3.14.

        @param node reference to the node being inspected
        @type ast.Constant, ast.Str or ast.Bytes
        @return literal value of the node
        @rtype str or bytes
        """
        try:
            return node.value
        except AttributeError:
            # legacy ast.Str / ast.Bytes node (Python < 3.8)
            return node.s

    def visit_ClassDef(self, node):
        """
        Public method defining a visitor for AST ClassDef nodes.

        Add class name to current namespace for all descendants.

        @param node reference to the node being inspected
        @type ast.ClassDef
        """
        # For all child nodes, add this class name to current namespace
        self.namespace = SecurityUtils.namespacePathJoin(
            self.namespace, node.name)

    def visit_FunctionDef(self, node):
        """
        Public method defining a visitor for AST FunctionDef nodes.

        @param node reference to the node being inspected
        @type ast.FunctionDef
        """
        self.__visitFunctionDefinition(node)

    def visit_AsyncFunctionDef(self, node):
        """
        Public method defining a visitor for AST AsyncFunctionDef nodes.

        @param node reference to the node being inspected
        @type ast.AsyncFunctionDef
        """
        self.__visitFunctionDefinition(node)

    def __visitFunctionDefinition(self, node):
        """
        Private method defining a visitor for AST FunctionDef and
        AsyncFunctionDef nodes.

        Add relevant information about the node to the context for use in tests
        which inspect function definitions. Add the function name to the
        current namespace for all descendants.

        @param node reference to the node being inspected
        @type ast.FunctionDef, ast.AsyncFunctionDef
        """
        self.__context['function'] = node
        qualname = SecurityUtils.namespacePathJoin(self.namespace, node.name)
        name = qualname.split('.')[-1]
        self.__context['qualname'] = qualname
        self.__context['name'] = name

        # For all child nodes and any tests run, add this function name to
        # current namespace; __postVisit() removes it again
        self.namespace = SecurityUtils.namespacePathJoin(
            self.namespace, node.name)

        self.__runChecks("FunctionDef")

    def visit_Call(self, node):
        """
        Public method defining a visitor for AST Call nodes.

        Add relevant information about the node to the context for use in tests
        which inspect function calls.

        @param node reference to the node being inspected
        @type ast.Call
        """
        self.__context['call'] = node
        qualname = SecurityUtils.getCallName(node, self.import_aliases)
        name = qualname.split('.')[-1]
        self.__context['qualname'] = qualname
        self.__context['name'] = name
        self.__runChecks("Call")

    def visit_Import(self, node):
        """
        Public method defining a visitor for AST Import nodes.

        Every imported name is recorded (and its alias, if any) and the
        "Import" checks are run once per imported name.

        @param node reference to the node being inspected
        @type ast.Import
        """
        for nodename in node.names:
            if nodename.asname:
                self.import_aliases[nodename.asname] = nodename.name
            self.imports.add(nodename.name)
            self.__context['module'] = nodename.name
            self.__runChecks("Import")

    def visit_ImportFrom(self, node):
        """
        Public method defining a visitor for AST ImportFrom nodes.

        This adds relevant information about the node to
        the context for use in tests which inspect imports.

        @param node reference to the node being inspected
        @type ast.ImportFrom
        """
        module = node.module
        if module is None:
            # e.g. 'from . import x' has no module; treat it like 'import x'
            self.visit_Import(node)
            return

        for nodename in node.names:
            if nodename.asname:
                self.import_aliases[nodename.asname] = (
                    module + "." + nodename.name
                )
            else:
                # Even if import is not aliased we need an entry that maps
                # name to module.name. For example, with 'from a import b'
                # b should be aliased to the qualified name a.b
                self.import_aliases[nodename.name] = (
                    module + '.' + nodename.name)
            self.imports.add(module + "." + nodename.name)
            self.__context['module'] = module
            self.__context['name'] = nodename.name
            self.__runChecks("ImportFrom")

    def visit_Constant(self, node):
        """
        Public method defining a visitor for Constant nodes.

        This dispatches str and bytes constants to visit_Str() and
        visit_Bytes() respectively, so the same checks run for the
        Constant nodes of Python 3.8+ and the Str/Bytes nodes of older
        versions. Other constants are ignored.

        @param node reference to the node being inspected
        @type ast.Constant
        """
        if isinstance(node.value, str):
            self.visit_Str(node)
        elif isinstance(node.value, bytes):
            self.visit_Bytes(node)

    def visit_Str(self, node):
        """
        Public method defining a visitor for String nodes.

        This adds relevant information about node to
        the context for use in tests which inspect strings.

        @param node reference to the node being inspected
        @type ast.Constant or ast.Str
        """
        self.__context['str'] = self.__nodeValue(node)
        # A string used as an expression statement (e.g. a docstring) keeps
        # its own line range; otherwise report the range of the enclosing
        # node.
        if not isinstance(node._securityParent, ast.Expr):
            self.__context['linerange'] = SecurityUtils.linerange_fix(
                node._securityParent
            )
        self.__runChecks("Str")

    def visit_Bytes(self, node):
        """
        Public method defining a visitor for Bytes nodes.

        This adds relevant information about node to
        the context for use in tests which inspect bytes literals.

        @param node reference to the node being inspected
        @type ast.Constant or ast.Bytes
        """
        self.__context['bytes'] = self.__nodeValue(node)
        # A bytes literal used as an expression statement keeps its own
        # line range; otherwise report the range of the enclosing node.
        if not isinstance(node._securityParent, ast.Expr):
            self.__context['linerange'] = SecurityUtils.linerange_fix(
                node._securityParent
            )
        self.__runChecks("Bytes")

    def __preVisit(self, node):
        """
        Private method to set up a context for the visit method.

        @param node node to base the context on
        @type ast.AST
        @return flag indicating to visit the node (always True)
        @rtype bool
        """
        self.__context = {}
        self.__context['imports'] = self.imports
        self.__context['import_aliases'] = self.import_aliases

        if hasattr(node, 'lineno'):
            self.__context['lineno'] = node.lineno

        self.__context['node'] = node
        self.__context['linerange'] = SecurityUtils.linerange_fix(node)
        self.__context['filename'] = self.filename

        self.seen += 1
        self.depth += 1

        return True

    def visit(self, node):
        """
        Public method to inspect an AST node.

        Dispatches to the 'visit_<NodeClass>' method if one exists; otherwise
        the checks registered for the node class name are run directly.

        @param node AST node to be inspected
        @type ast.AST
        """
        name = node.__class__.__name__
        method = 'visit_' + name
        visitor = getattr(self, method, None)
        if visitor is not None:
            visitor(node)
        else:
            self.__runChecks(name)

    def __postVisit(self, node):
        """
        Private method to clean up after a node was visited.

        @param node AST node that was visited
        @type ast.AST
        """
        self.depth -= 1
        # Undo the namespace extension done by the visitor methods of class,
        # function and async function definitions.
        if isinstance(node, self.__NamespaceNodeTypes):
            self.namespace = SecurityUtils.namespacePathSplit(
                self.namespace)[0]

    def generic_visit(self, node):
        """
        Public method to drive the node visitor.

        Every AST child of the node gets '_securityParent' (the node) and
        '_securitySibling' (the next element of the same field list, or
        None) attributes, then is pre-visited, visited, recursed into and
        post-visited.

        @param node node to be inspected
        @type ast.AST
        """
        for _, value in ast.iter_fields(node):
            if isinstance(value, list):
                maxIndex = len(value) - 1
                for index, item in enumerate(value):
                    if isinstance(item, ast.AST):
                        if index < maxIndex:
                            item._securitySibling = value[index + 1]
                        else:
                            item._securitySibling = None
                        item._securityParent = node

                        if self.__preVisit(item):
                            self.visit(item)
                            self.generic_visit(item)
                            self.__postVisit(item)

            elif isinstance(value, ast.AST):
                value._securitySibling = None
                value._securityParent = node
                if self.__preVisit(value):
                    self.visit(value)
                    self.generic_visit(value)
                    self.__postVisit(value)
309 self.__postVisit(value) |