diff -r 3fc8dfeb6ebe -r b99e7fd55fd3 src/eric7/Plugins/CheckerPlugins/CodeStyleChecker/Security/SecurityNodeVisitor.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/eric7/Plugins/CheckerPlugins/CodeStyleChecker/Security/SecurityNodeVisitor.py Thu Jul 07 11:23:56 2022 +0200 @@ -0,0 +1,309 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2020 - 2022 Detlev Offenbach <detlev@die-offenbachs.de> +# + +""" +Module implementing an AST node visitor for security checks. +""" + +import ast + +from . import SecurityUtils +from .SecurityContext import SecurityContext + + +class SecurityNodeVisitor: + """ + Class implementing an AST node visitor for security checks. + """ + def __init__(self, checker, secCheckers, filename): + """ + Constructor + + @param checker reference to the main security checker object + @type SecurityChecker + @param secCheckers dictionary containing the available checker routines + @type dict + @param filename name of the checked file + @type str + """ + self.__checker = checker + self.__securityCheckers = secCheckers + + self.seen = 0 + self.depth = 0 + self.filename = filename + self.imports = set() + self.import_aliases = {} + + # in some cases we can't determine a qualified name + try: + self.namespace = SecurityUtils.getModuleQualnameFromPath(filename) + except SecurityUtils.InvalidModulePath: + self.namespace = "" + + def __runChecks(self, checkType): + """ + Private method to run all enabled checks for a given check type. + + @param checkType type of checks to be run + @type str + """ + if checkType in self.__securityCheckers: + for check in self.__securityCheckers[checkType]: + check(self.__checker.reportError, + SecurityContext(self.__context), + self.__checker.getConfig()) + + def visit_ClassDef(self, node): + """ + Public method defining a visitor for AST ClassDef nodes. + + Add class name to current namespace for all descendants. + + @param node reference to the node being inspected + @type ast.ClassDef + """ + # For all child nodes, add this class name to current namespace + self.namespace = SecurityUtils.namespacePathJoin( + self.namespace, node.name) + + def visit_FunctionDef(self, node): + """ + Public method defining a visitor for AST FunctionDef nodes. + + @param node reference to the node being inspected + @type ast.FunctionDef + """ + self.__visitFunctionDefinition(node) + + def visit_AsyncFunctionDef(self, node): + """ + Public method defining a visitor for AST AsyncFunctionDef nodes. + + @param node reference to the node being inspected + @type ast.AsyncFunctionDef + """ + self.__visitFunctionDefinition(node) + + def __visitFunctionDefinition(self, node): + """ + Private method defining a visitor for AST FunctionDef and + AsyncFunctionDef nodes. + + Add relevant information about the node to the context for use in tests + which inspect function definitions. Add the function name to the + current namespace for all descendants. + + @param node reference to the node being inspected + @type ast.FunctionDef, ast.AsyncFunctionDef + """ + self.__context['function'] = node + qualname = SecurityUtils.namespacePathJoin(self.namespace, node.name) + name = qualname.split('.')[-1] + self.__context['qualname'] = qualname + self.__context['name'] = name + + # For all child nodes and any tests run, add this function name to + # current namespace + self.namespace = SecurityUtils.namespacePathJoin( + self.namespace, node.name) + + self.__runChecks("FunctionDef") + + def visit_Call(self, node): + """ + Public method defining a visitor for AST Call nodes. + + Add relevant information about the node to the context for use in tests + which inspect function calls. + + @param node reference to the node being inspected + @type ast.Call + """ + self.__context['call'] = node + qualname = SecurityUtils.getCallName(node, self.import_aliases) + name = qualname.split('.')[-1] + self.__context['qualname'] = qualname + self.__context['name'] = name + self.__runChecks("Call") + + def visit_Import(self, node): + """ + Public method defining a visitor for AST Import nodes. + + @param node reference to the node being inspected + @type ast.Import + """ + for nodename in node.names: + if nodename.asname: + self.import_aliases[nodename.asname] = nodename.name + self.imports.add(nodename.name) + self.__context['module'] = nodename.name + self.__runChecks("Import") + + def visit_ImportFrom(self, node): + """ + Public method defining a visitor for AST Import nodes. + + This adds relevant information about the node to + the context for use in tests which inspect imports. + + @param node reference to the node being inspected + @type ast.ImportFrom + """ + module = node.module + if module is None: + self.visit_Import(node) + return + + for nodename in node.names: + if nodename.asname: + self.import_aliases[nodename.asname] = ( + module + "." + nodename.name + ) + else: + # Even if import is not aliased we need an entry that maps + # name to module.name. For example, with 'from a import b' + # b should be aliased to the qualified name a.b + self.import_aliases[nodename.name] = ( + module + '.' + nodename.name) + self.imports.add(module + "." + nodename.name) + self.__context['module'] = module + self.__context['name'] = nodename.name + self.__runChecks("ImportFrom") + + def visit_Constant(self, node): + """ + Public method defining a visitor for Constant nodes. + + This calls the appropriate method for the node type. + It maintains compatibility with <3.6 and 3.8+ + + @param node reference to the node being inspected + @type ast.Constant + """ + if isinstance(node.value, str): + self.visit_Str(node) + elif isinstance(node.value, bytes): + self.visit_Bytes(node) + + def visit_Str(self, node): + """ + Public method defining a visitor for String nodes. + + This adds relevant information about node to + the context for use in tests which inspect strings. + + @param node reference to the node being inspected + @type ast.Str + """ + self.__context['str'] = node.s + if not isinstance(node._securityParent, ast.Expr): # docstring + self.__context['linerange'] = SecurityUtils.linerange_fix( + node._securityParent + ) + self.__runChecks("Str") + + def visit_Bytes(self, node): + """ + Public method defining a visitor for Bytes nodes. + + This adds relevant information about node to + the context for use in tests which inspect strings. + + @param node reference to the node being inspected + @type ast.Bytes + """ + self.__context['bytes'] = node.s + if not isinstance(node._securityParent, ast.Expr): # docstring + self.__context['linerange'] = SecurityUtils.linerange_fix( + node._securityParent + ) + self.__runChecks("Bytes") + + def __preVisit(self, node): + """ + Private method to set up a context for the visit method. + + @param node node to base the context on + @type ast.AST + @return flag indicating to visit the node + @rtype bool + """ + self.__context = {} + self.__context['imports'] = self.imports + self.__context['import_aliases'] = self.import_aliases + + if hasattr(node, 'lineno'): + self.__context['lineno'] = node.lineno + + self.__context['node'] = node + self.__context['linerange'] = SecurityUtils.linerange_fix(node) + self.__context['filename'] = self.filename + + self.seen += 1 + self.depth += 1 + + return True + + def visit(self, node): + """ + Public method to inspected an AST node. + + @param node AST node to be inspected + @type ast.AST + """ + name = node.__class__.__name__ + method = 'visit_' + name + visitor = getattr(self, method, None) + if visitor is not None: + visitor(node) + else: + self.__runChecks(name) + + def __postVisit(self, node): + """ + Private method to clean up after a node was visited. + + @param node AST node that was visited + @type ast.AST + """ + self.depth -= 1 + # Clean up post-recursion stuff that gets setup in the visit methods + # for these node types. + if isinstance(node, (ast.FunctionDef, ast.ClassDef)): + self.namespace = SecurityUtils.namespacePathSplit( + self.namespace)[0] + + def generic_visit(self, node): + """ + Public method to drive the node visitor. + + @param node node to be inspected + @type ast.AST + """ + for _, value in ast.iter_fields(node): + if isinstance(value, list): + maxIndex = len(value) - 1 + for index, item in enumerate(value): + if isinstance(item, ast.AST): + if index < maxIndex: + item._securitySibling = value[index + 1] + else: + item._securitySibling = None + item._securityParent = node + + if self.__preVisit(item): + self.visit(item) + self.generic_visit(item) + self.__postVisit(item) + + elif isinstance(value, ast.AST): + value._securitySibling = None + value._securityParent = node + if self.__preVisit(value): + self.visit(value) + self.generic_visit(value) + self.__postVisit(value)