--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/eric7/Plugins/CheckerPlugins/CodeStyleChecker/Miscellaneous/BugBearVisitor.py Thu Feb 27 14:42:39 2025 +0100 @@ -0,0 +1,2101 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2025 Detlev Offenbach <detlev@die-offenbachs.de> +# + +""" +Module implementing a visitor to check for various potential issues. +""" + +import ast +import builtins +import contextlib +import itertools +import math +import re + +from collections import Counter, namedtuple +from dataclasses import dataclass +from keyword import iskeyword + +import AstUtilities + +####################################################################### +## adapted from: flake8-bugbear v24.12.12 +## +## Original: Copyright (c) 2016 Ćukasz Langa +####################################################################### + +BugbearMutableLiterals = ("Dict", "List", "Set") +BugbearMutableComprehensions = ("ListComp", "DictComp", "SetComp") +BugbearMutableCalls = ( + "Counter", + "OrderedDict", + "collections.Counter", + "collections.OrderedDict", + "collections.defaultdict", + "collections.deque", + "defaultdict", + "deque", + "dict", + "list", + "set", +) +BugbearImmutableCalls = ( + "tuple", + "frozenset", + "types.MappingProxyType", + "MappingProxyType", + "re.compile", + "operator.attrgetter", + "operator.itemgetter", + "operator.methodcaller", + "attrgetter", + "itemgetter", + "methodcaller", +) + + +BugBearContext = namedtuple("BugBearContext", ["node", "stack"]) + + +def composeCallPath(node): + """ + Generator function to assemble the call path of a given node. + + @param node node to assemble call path for + @type ast.Node + @yield call path components + @ytype str + """ + if isinstance(node, ast.Attribute): + yield from composeCallPath(node.value) + yield node.attr + elif isinstance(node, ast.Call): + yield from composeCallPath(node.func) + elif isinstance(node, ast.Name): + yield node.id + + +@dataclass +class M540CaughtException: + """ + Class to hold the data for a caught exception. + """ + + name: str + hasNote: bool + + +class M541UnhandledKeyType: + """ + Class to hold a dictionary key of a type that we do not check for duplicates. + """ + + +class M541VariableKeyType: + """ + Class to hold the name of a variable key type. + """ + + def __init__(self, name): + """ + Constructor + + @param name name of the variable key type + @type str + """ + self.name = name + + +class BugBearVisitor(ast.NodeVisitor): + """ + Class implementing a node visitor to check for various topics. + """ + + CONTEXTFUL_NODES = ( + ast.Module, + ast.ClassDef, + ast.AsyncFunctionDef, + ast.FunctionDef, + ast.Lambda, + ast.ListComp, + ast.SetComp, + ast.DictComp, + ast.GeneratorExp, + ) + + FUNCTION_NODES = ( + ast.AsyncFunctionDef, + ast.FunctionDef, + ast.Lambda, + ) + + NodeWindowSize = 4 + + def __init__(self): + """ + Constructor + """ + super().__init__() + + self.nodeWindow = [] + self.violations = [] + self.contexts = [] + + self.__M523Seen = set() + self.__M505Imports = set() + self.__M540CaughtException = None + + self.__inTryStar = "" + + @property + def nodeStack(self): + """ + Public method to get a reference to the most recent node stack. + + @return reference to the most recent node stack + @rtype list + """ + if len(self.contexts) == 0: + return [] + + context, stack = self.contexts[-1] + return stack + + def __isIdentifier(self, arg): + """ + Private method to check if arg is a valid identifier. + + See https://docs.python.org/2/reference/lexical_analysis.html#identifiers + + @param arg reference to an argument node + @type ast.Node + @return flag indicating a valid identifier + @rtype TYPE + """ + if not AstUtilities.isString(arg): + return False + + return ( + re.match(r"^[A-Za-z_][A-Za-z0-9_]*$", AstUtilities.getValue(arg)) + is not None + ) + + def toNameStr(self, node): + """ + Public method to turn Name and Attribute nodes to strings, handling any + depth of attribute accesses. + + + @param node reference to the node + @type ast.Name or ast.Attribute + @return string representation + @rtype str + """ + if isinstance(node, ast.Name): + return node.id + elif isinstance(node, ast.Call): + return self.toNameStr(node.func) + elif isinstance(node, ast.Attribute): + inner = self.toNameStr(node.value) + if inner is None: + return None + return f"{inner}.{node.attr}" + else: + return None + + def __typesafeIssubclass(self, obj, classOrTuple): + """ + Private method implementing a type safe issubclass() function. + + @param obj reference to the object to be tested + @type Any + @param classOrTuple type to check against + @type type + @return flag indicating a subclass + @rtype bool + """ + try: + return issubclass(obj, classOrTuple) + except TypeError: + # User code specifies a type that is not a type in our current run. + # Might be their error, might be a difference in our environments. + # We don't know so we ignore this. + return False + + def __getAssignedNames(self, loopNode): + """ + Private method to get the names of a for loop. + + @param loopNode reference to the node to be processed + @type ast.For + @yield DESCRIPTION + @ytype TYPE + """ + loopTargets = (ast.For, ast.AsyncFor, ast.comprehension) + for node in self.__childrenInScope(loopNode): + if isinstance(node, (ast.Assign)): + for child in node.targets: + yield from self.__namesFromAssignments(child) + if isinstance(node, loopTargets + (ast.AnnAssign, ast.AugAssign)): + yield from self.__namesFromAssignments(node.target) + + def __namesFromAssignments(self, assignTarget): + """ + Private method to get names of an assignment. + + @param assignTarget reference to the node to be processed + @type ast.Node + @yield name of the assignment + @ytype str + """ + if isinstance(assignTarget, ast.Name): + yield assignTarget.id + elif isinstance(assignTarget, ast.Starred): + yield from self.__namesFromAssignments(assignTarget.value) + elif isinstance(assignTarget, (ast.List, ast.Tuple)): + for child in assignTarget.elts: + yield from self.__namesFromAssignments(child) + + def __childrenInScope(self, node): + """ + Private method to get all child nodes in the given scope. + + @param node reference to the node to be processed + @type ast.Node + @yield reference to a child node + @ytype ast.Node + """ + yield node + if not isinstance(node, BugBearVisitor.FUNCTION_NODES): + for child in ast.iter_child_nodes(node): + yield from self.__childrenInScope(child) + + def __flattenExcepthandler(self, node): + """ + Private method to flatten the list of exceptions handled by an except handler. + + @param node reference to the node to be processed + @type ast.Node + @yield reference to the exception type node + @ytype ast.Node + """ + if not isinstance(node, ast.Tuple): + yield node + return + + exprList = node.elts.copy() + while len(exprList): + expr = exprList.pop(0) + if isinstance(expr, ast.Starred) and isinstance( + expr.value, (ast.List, ast.Tuple) + ): + exprList.extend(expr.value.elts) + continue + yield expr + + def __checkRedundantExcepthandlers(self, names, node, inTryStar): + """ + Private method to check for redundant exception types in an exception handler. + + @param names list of exception types to be checked + @type list of ast.Name + @param node reference to the exception handler node + @type ast.ExceptionHandler + @param inTryStar character indicating an 'except*' handler + @type str + @return tuple containing the error data + @rtype tuple of (ast.Node, str, str, str, str) + """ + redundantExceptions = { + "OSError": { + # All of these are actually aliases of OSError since Python 3.3 + "IOError", + "EnvironmentError", + "WindowsError", + "mmap.error", + "socket.error", + "select.error", + }, + "ValueError": { + "binascii.Error", + }, + } + + # See if any of the given exception names could be removed, e.g. from: + # (MyError, MyError) # duplicate names # noqa: M-891 + # (MyError, BaseException) # everything derives from the Base # noqa: M-891 + # (Exception, TypeError) # builtins where one subclasses another + # noqa: M-891 + # (IOError, OSError) # IOError is an alias of OSError since Python3.3 + # noqa: M-891 + # but note that other cases are impractical to handle from the AST. + # We expect this is mostly useful for users who do not have the + # builtin exception hierarchy memorised, and include a 'shadowed' + # subtype without realising that it's redundant. + good = sorted(set(names), key=names.index) + if "BaseException" in good: + good = ["BaseException"] + # Remove redundant exceptions that the automatic system either handles + # poorly (usually aliases) or can't be checked (e.g. it's not an + # built-in exception). + for primary, equivalents in redundantExceptions.items(): + if primary in good: + good = [g for g in good if g not in equivalents] + + for name, other in itertools.permutations(tuple(good), 2): + if ( + self.__typesafeIssubclass( + getattr(builtins, name, type), getattr(builtins, other, ()) + ) + and name in good + ): + good.remove(name) + if good != names: + desc = good[0] if len(good) == 1 else "({0})".format(", ".join(good)) + as_ = " as " + node.name if node.name is not None else "" + return (node, "M-514", ", ".join(names), as_, desc, inTryStar) + + return None + + def __walkList(self, nodes): + """ + Private method to walk a given list of nodes. + + @param nodes list of nodes to walk + @type list of ast.Node + @yield node references as determined by the ast.walk() function + @ytype ast.Node + """ + for node in nodes: + yield from ast.walk(node) + + def __getNamesFromTuple(self, node): + """ + Private method to get the names from an ast.Tuple node. + + @param node ast node to be processed + @type ast.Tuple + @yield names + @ytype str + """ + for dim in node.elts: + if isinstance(dim, ast.Name): + yield dim.id + elif isinstance(dim, ast.Tuple): + yield from self.__getNamesFromTuple(dim) + + def __getDictCompLoopAndNamedExprVarNames(self, node): + """ + Private method to get the names of comprehension loop variables. + + @param node ast node to be processed + @type ast.DictComp + @yield loop variable names + @ytype str + """ + finder = NamedExprFinder() + for gen in node.generators: + if isinstance(gen.target, ast.Name): + yield gen.target.id + elif isinstance(gen.target, ast.Tuple): + yield from self.__getNamesFromTuple(gen.target) + + finder.visit(gen.ifs) + + yield from finder.getNames().keys() + + def __inClassInit(self): + """ + Private method to check, if we are inside an '__init__' method. + + @return flag indicating being within the '__init__' method + @rtype bool + """ + return ( + len(self.contexts) >= 2 + and isinstance(self.contexts[-2].node, ast.ClassDef) + and isinstance(self.contexts[-1].node, ast.FunctionDef) + and self.contexts[-1].node.name == "__init__" + ) + + def visit_Return(self, node): + """ + Public method to handle 'Return' nodes. + + @param node reference to the node to be processed + @type ast.Return + """ + if self.__inClassInit() and node.value is not None: + self.violations.append((node, "M-537")) + + self.generic_visit(node) + + def visit_Yield(self, node): + """ + Public method to handle 'Yield' nodes. + + @param node reference to the node to be processed + @type ast.Yield + """ + if self.__inClassInit(): + self.violations.append((node, "M-537")) + + self.generic_visit(node) + + def visit_YieldFrom(self, node) -> None: + """ + Public method to handle 'YieldFrom' nodes. + + @param node reference to the node to be processed + @type ast.YieldFrom + """ + if self.__inClassInit(): + self.violations.append((node, "M-537")) + + self.generic_visit(node) + + def visit(self, node): + """ + Public method to traverse a given AST node. + + @param node AST node to be traversed + @type ast.Node + """ + isContextful = isinstance(node, BugBearVisitor.CONTEXTFUL_NODES) + + if isContextful: + context = BugBearContext(node, []) + self.contexts.append(context) + + self.nodeStack.append(node) + self.nodeWindow.append(node) + self.nodeWindow = self.nodeWindow[-BugBearVisitor.NodeWindowSize :] + + super().visit(node) + + self.nodeStack.pop() + + if isContextful: + self.contexts.pop() + + self.__checkForM518(node) + + def visit_ExceptHandler(self, node): + """ + Public method to handle exception handlers. + + @param node reference to the node to be processed + @type ast.ExceptHandler + """ + if node.type is None: + # bare except is handled by pycodestyle already + self.generic_visit(node) + return + + oldM540CaughtException = self.__M540CaughtException + if node.name is None: + self.__M540CaughtException = None + else: + self.__M540CaughtException = M540CaughtException(node.name, False) + + names = self.__checkForM513_M514_M529_M530(node) + + if "BaseException" in names and not ExceptBaseExceptionVisitor(node).reRaised(): + self.violations.append((node, "M-536")) + + self.generic_visit(node) + + if ( + self.__M540CaughtException is not None + and self.__M540CaughtException.hasNote + ): + self.violations.append((node, "M-540")) + self.__M540CaughtException = oldM540CaughtException + + def visit_UAdd(self, node): + """ + Public method to handle unary additions. + + @param node reference to the node to be processed + @type ast.UAdd + """ + trailingNodes = list(map(type, self.nodeWindow[-4:])) + if trailingNodes == [ast.UnaryOp, ast.UAdd, ast.UnaryOp, ast.UAdd]: + originator = self.nodeWindow[-4] + self.violations.append((originator, "M-502")) + + self.generic_visit(node) + + def visit_Call(self, node): + """ + Public method to handle a function call. + + @param node reference to the node to be processed + @type ast.Call + """ + isM540AddNote = False + + if isinstance(node.func, ast.Attribute): + self.__checkForM505(node) + isM540AddNote = self.__checkForM540AddNote(node.func) + else: + with contextlib.suppress(AttributeError, IndexError): + # bad super() call + if isinstance(node.func, ast.Name) and node.func.id == "super": + args = node.args + if ( + len(args) == 2 + and isinstance(args[0], ast.Attribute) + and isinstance(args[0].value, ast.Name) + and args[0].value.id == "self" + and args[0].attr == "__class__" + ): + self.violations.append((node, "M-582")) + + # bad getattr and setattr + if ( + node.func.id in ("getattr", "hasattr") + and node.args[1].value == "__call__" + ): + self.violations.append((node, "M-504")) + if ( + node.func.id == "getattr" + and len(node.args) == 2 + and self.__isIdentifier(node.args[1]) + and iskeyword(AstUtilities.getValue(node.args[1])) + ): + self.violations.append((node, "M-509")) + elif ( + node.func.id == "setattr" + and len(node.args) == 3 + and self.__isIdentifier(node.args[1]) + and iskeyword(AstUtilities.getValue(node.args[1])) + ): + self.violations.append((node, "M-510")) + + self.__checkForM526(node) + + self.__checkForM528(node) + self.__checkForM534(node) + self.__checkForM539(node) + + # no need for copying, if used in nested calls it will be set to None + currentM540CaughtException = self.__M540CaughtException + if not isM540AddNote: + self.__checkForM540Usage(node.args) + self.__checkForM540Usage(node.keywords) + + self.generic_visit(node) + + if isM540AddNote: + # Avoid nested calls within the parameter list using the variable itself. + # e.g. `e.add_note(str(e))` + self.__M540CaughtException = currentM540CaughtException + + def visit_Module(self, node): + """ + Public method to handle a module node. + + @param node reference to the node to be processed + @type ast.Module + """ + self.generic_visit(node) + + def visit_Assign(self, node): + """ + Public method to handle assignments. + + @param node reference to the node to be processed + @type ast.Assign + """ + self.__checkForM540Usage(node.value) + if len(node.targets) == 1: + target = node.targets[0] + if ( + isinstance(target, ast.Attribute) + and isinstance(target.value, ast.Name) + and (target.value.id, target.attr) == ("os", "environ") + ): + self.violations.append((node, "M-503")) + + self.generic_visit(node) + + def visit_For(self, node): + """ + Public method to handle 'for' statements. + + @param node reference to the node to be processed + @type ast.For + """ + self.__checkForM507(node) + self.__checkForM520(node) + self.__checkForM523(node) + self.__checkForM531(node) + self.__checkForM569(node) + + self.generic_visit(node) + + def visit_AsyncFor(self, node): + """ + Public method to handle 'for' statements. + + @param node reference to the node to be processed + @type ast.AsyncFor + """ + self.__checkForM507(node) + self.__checkForM520(node) + self.__checkForM523(node) + self.__checkForM531(node) + + self.generic_visit(node) + + def visit_While(self, node): + """ + Public method to handle 'while' statements. + + @param node reference to the node to be processed + @type ast.While + """ + self.__checkForM523(node) + + self.generic_visit(node) + + def visit_ListComp(self, node): + """ + Public method to handle list comprehensions. + + @param node reference to the node to be processed + @type ast.ListComp + """ + self.__checkForM523(node) + + self.generic_visit(node) + + def visit_SetComp(self, node): + """ + Public method to handle set comprehensions. + + @param node reference to the node to be processed + @type ast.SetComp + """ + self.__checkForM523(node) + + self.generic_visit(node) + + def visit_DictComp(self, node): + """ + Public method to handle dictionary comprehensions. + + @param node reference to the node to be processed + @type ast.DictComp + """ + self.__checkForM523(node) + self.__checkForM535(node) + + self.generic_visit(node) + + def visit_GeneratorExp(self, node): + """ + Public method to handle generator expressions. + + @param node reference to the node to be processed + @type ast.GeneratorExp + """ + self.__checkForM523(node) + + self.generic_visit(node) + + def visit_Assert(self, node): + """ + Public method to handle 'assert' statements. + + @param node reference to the node to be processed + @type ast.Assert + """ + if ( + AstUtilities.isNameConstant(node.test) + and AstUtilities.getValue(node.test) is False + ): + self.violations.append((node, "M-511")) + + self.generic_visit(node) + + def visit_AsyncFunctionDef(self, node): + """ + Public method to handle async function definitions. + + @param node reference to the node to be processed + @type ast.AsyncFunctionDef + """ + self.__checkForM506_M508(node) + + self.generic_visit(node) + + def visit_FunctionDef(self, node): + """ + Public method to handle function definitions. + + @param node reference to the node to be processed + @type ast.FunctionDef + """ + self.__checkForM506_M508(node) + self.__checkForM519(node) + self.__checkForM521(node) + + self.generic_visit(node) + + def visit_ClassDef(self, node): + """ + Public method to handle class definitions. + + @param node reference to the node to be processed + @type ast.ClassDef + """ + self.__checkForM521(node) + self.__checkForM524_M527(node) + + self.generic_visit(node) + + def visit_Try(self, node): + """ + Public method to handle 'try' statements. + + @param node reference to the node to be processed + @type ast.Try + """ + self.__checkForM512(node) + self.__checkForM525(node) + + self.generic_visit(node) + + def visit_TryStar(self, node): + """ + Public method to handle 'except*' statements. + + @param node reference to the node to be processed + @type ast.TryStar + """ + outerTryStar = self.__inTryStar + self.__inTryStar = "*" + self.visit_Try(node) + self.__inTryStar = outerTryStar + + def visit_Compare(self, node): + """ + Public method to handle comparison statements. + + @param node reference to the node to be processed + @type ast.Compare + """ + self.__checkForM515(node) + + self.generic_visit(node) + + def visit_Raise(self, node): + """ + Public method to handle 'raise' statements. + + @param node reference to the node to be processed + @type ast.Raise + """ + if node.exc is None: + self.__M540CaughtException = None + else: + self.__checkForM540Usage(node.exc) + self.__checkForM540Usage(node.cause) + self.__checkForM516(node) + + self.generic_visit(node) + + def visit_With(self, node): + """ + Public method to handle 'with' statements. + + @param node reference to the node to be processed + @type ast.With + """ + self.__checkForM517(node) + self.__checkForM522(node) + + self.generic_visit(node) + + def visit_JoinedStr(self, node): + """ + Public method to handle f-string arguments. + + @param node reference to the node to be processed + @type ast.JoinedStr + """ + for value in node.values: + if isinstance(value, ast.FormattedValue): + return + + self.violations.append((node, "M-581")) + + def visit_AnnAssign(self, node): + """ + Public method to check annotated assign statements. + + @param node reference to the node to be processed + @type ast.AnnAssign + """ + self.__checkForM532(node) + self.__checkForM540Usage(node.value) + + self.generic_visit(node) + + def visit_Import(self, node): + """ + Public method to check imports. + + @param node reference to the node to be processed + @type ast.Import + """ + self.__checkForM505(node) + + self.generic_visit(node) + + def visit_ImportFrom(self, node): + """ + Public method to check from imports. + + @param node reference to the node to be processed + @type ast.Import + """ + self.visit_Import(node) + + def visit_Set(self, node): + """ + Public method to check a set. + + @param node reference to the node to be processed + @type ast.Set + """ + self.__checkForM533(node) + + self.generic_visit(node) + + def visit_Dict(self, node): + """ + Public method to check a dictionary. + + @param node reference to the node to be processed + @type ast.Dict + """ + self.__checkForM541(node) + + self.generic_visit(node) + + def __checkForM505(self, node): + """ + Private method to check the use of *strip(). + + @param node reference to the node to be processed + @type ast.Call + """ + if isinstance(node, ast.Import): + for name in node.names: + self.__M505Imports.add(name.asname or name.name) + elif isinstance(node, ast.ImportFrom): + for name in node.names: + self.__M505Imports.add(f"{node.module}.{name.name or name.asname}") + elif isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute): + if node.func.attr not in ("lstrip", "rstrip", "strip"): + return # method name doesn't match + + if ( + isinstance(node.func.value, ast.Name) + and node.func.value.id in self.__M505Imports + ): + return # method is being run on an imported module + + if len(node.args) != 1 or not AstUtilities.isString(node.args[0]): + return # used arguments don't match the builtin strip + + value = AstUtilities.getValue(node.args[0]) + if len(value) == 1: + return # stripping just one character + + if len(value) == len(set(value)): + return # no characters appear more than once + + self.violations.append((node, "M-505")) + + def __checkForM506_M508(self, node): + """ + Private method to check the use of mutable literals, comprehensions and calls. + + @param node reference to the node to be processed + @type ast.AsyncFunctionDef or ast.FunctionDef + """ + visitor = FunctionDefDefaultsVisitor("M-506", "M-508") + visitor.visit(node.args.defaults + node.args.kw_defaults) + self.violations.extend(visitor.errors) + + def __checkForM507(self, node): + """ + Private method to check for unused loop variables. + + @param node reference to the node to be processed + @type ast.For or ast.AsyncFor + """ + targets = NameFinder() + targets.visit(node.target) + ctrlNames = set(filter(lambda s: not s.startswith("_"), targets.getNames())) + body = NameFinder() + for expr in node.body: + body.visit(expr) + usedNames = set(body.getNames()) + for name in sorted(ctrlNames - usedNames): + n = targets.getNames()[name][0] + self.violations.append((n, "M-507", name)) + + def __checkForM512(self, node): + """ + Private method to check for return/continue/break inside finally blocks. + + @param node reference to the node to be processed + @type ast.Try + """ + + def _loop(node, badNodeTypes): + if isinstance(node, (ast.AsyncFunctionDef, ast.FunctionDef)): + return + + if isinstance(node, (ast.While, ast.For)): + badNodeTypes = (ast.Return,) + + elif isinstance(node, badNodeTypes): + self.violations.append((node, "M-512", self.__inTryStar)) + + for child in ast.iter_child_nodes(node): + _loop(child, badNodeTypes) + + for child in node.finalbody: + _loop(child, (ast.Return, ast.Continue, ast.Break)) + + def __checkForM513_M514_M529_M530(self, node): + """ + Private method to check various exception handler situations. + + @param node reference to the node to be processed + @type ast.ExceptHandler + @return list of exception handler names + @rtype list of str + """ + handlers = self.__flattenExcepthandler(node.type) + names = [] + badHandlers = [] + ignoredHandlers = [] + + for handler in handlers: + if isinstance(handler, (ast.Name, ast.Attribute)): + name = self.toNameStr(handler) + if name is None: + ignoredHandlers.append(handler) + else: + names.append(name) + elif isinstance(handler, (ast.Call, ast.Starred)): + ignoredHandlers.append(handler) + else: + badHandlers.append(handler) + if badHandlers: + self.violations.append((node, "M-530")) + if len(names) == 0 and not badHandlers and not ignoredHandlers: + self.violations.append((node, "M-529", self.__inTryStar)) + elif ( + len(names) == 1 + and not badHandlers + and not ignoredHandlers + and isinstance(node.type, ast.Tuple) + ): + self.violations.append((node, "M-513", *names, self.__inTryStar)) + else: + maybeError = self.__checkRedundantExcepthandlers( + names, node, self.__inTryStar + ) + if maybeError is not None: + self.violations.append(maybeError) + return names + + def __checkForM515(self, node): + """ + Private method to check for pointless comparisons. + + @param node reference to the node to be processed + @type ast.Compare + """ + if isinstance(self.nodeStack[-2], ast.Expr): + self.violations.append((node, "M-515")) + + def __checkForM516(self, node): + """ + Private method to check for raising a literal instead of an exception. + + @param node reference to the node to be processed + @type ast.Raise + """ + if ( + AstUtilities.isNameConstant(node.exc) + or AstUtilities.isNumber(node.exc) + or AstUtilities.isString(node.exc) + ): + self.violations.append((node, "M-516")) + + def __checkForM517(self, node): + """ + Private method to check for use of the evil syntax + 'with assertRaises(Exception): or 'with pytest.raises(Exception):'. + + @param node reference to the node to be processed + @type ast.With + """ + item = node.items[0] + itemContext = item.context_expr + if ( + hasattr(itemContext, "func") + and ( + ( + isinstance(itemContext.func, ast.Attribute) + and ( + itemContext.func.attr == "assertRaises" + or ( + itemContext.func.attr == "raises" + and isinstance(itemContext.func.value, ast.Name) + and itemContext.func.value.id == "pytest" + and "match" not in (kwd.arg for kwd in itemContext.keywords) + ) + ) + ) + or ( + isinstance(itemContext.func, ast.Name) + and itemContext.func.id == "raises" + and isinstance(itemContext.func.ctx, ast.Load) + and "pytest.raises" in self.__M505Imports + and "match" not in (kwd.arg for kwd in itemContext.keywords) + ) + ) + and len(itemContext.args) == 1 + and isinstance(itemContext.args[0], ast.Name) + and itemContext.args[0].id in ("Exception", "BaseException") + and not item.optional_vars + ): + self.violations.append((node, "M-517")) + + def __checkForM518(self, node): + """ + Private method to check for useless expressions. + + @param node reference to the node to be processed + @type ast.FunctionDef + """ + if not isinstance(node, ast.Expr): + return + + if isinstance( + node.value, + (ast.List, ast.Set, ast.Dict, ast.Tuple), + ) or ( + isinstance(node.value, ast.Constant) + and ( + isinstance( + node.value.value, + (int, float, complex, bytes, bool), + ) + or node.value.value is None + ) + ): + self.violations.append((node, "M-518", node.value.__class__.__name__)) + + def __checkForM519(self, node): + """ + Private method to check for use of 'functools.lru_cache' or 'functools.cache'. + + @param node reference to the node to be processed + @type ast.FunctionDef + """ + caches = { + "functools.cache", + "functools.lru_cache", + "cache", + "lru_cache", + } + + if ( + len(node.decorator_list) == 0 + or len(self.contexts) < 2 + or not isinstance(self.contexts[-2].node, ast.ClassDef) + ): + return + + # Preserve decorator order so we can get the lineno from the decorator node + # rather than the function node (this location definition changes in Python 3.8) + resolvedDecorators = ( + ".".join(composeCallPath(decorator)) for decorator in node.decorator_list + ) + for idx, decorator in enumerate(resolvedDecorators): + if decorator in {"classmethod", "staticmethod"}: + return + + if decorator in caches: + self.violations.append((node.decorator_list[idx], "M-519")) + return + + def __checkForM520(self, node): + """ + Private method to check for a loop that modifies its iterable. + + @param node reference to the node to be processed + @type ast.For or ast.AsyncFor + """ + targets = NameFinder() + targets.visit(node.target) + ctrlNames = set(targets.getNames()) + + iterset = M520NameFinder() + iterset.visit(node.iter) + itersetNames = set(iterset.getNames()) + + for name in sorted(ctrlNames): + if name in itersetNames: + n = targets.getNames()[name][0] + self.violations.append((n, "M-520")) + + def __checkForM521(self, node): + """ + Private method to check for use of an f-string as docstring. + + @param node reference to the node to be processed + @type ast.FunctionDef or ast.ClassDef + """ + if ( + node.body + and isinstance(node.body[0], ast.Expr) + and isinstance(node.body[0].value, ast.JoinedStr) + ): + self.violations.append((node.body[0].value, "M-521")) + + def __checkForM522(self, node): + """ + Private method to check for use of an f-string as docstring. + + @param node reference to the node to be processed + @type ast.With + """ + item = node.items[0] + itemContext = item.context_expr + if ( + hasattr(itemContext, "func") + and hasattr(itemContext.func, "value") + and hasattr(itemContext.func.value, "id") + and itemContext.func.value.id == "contextlib" + and hasattr(itemContext.func, "attr") + and itemContext.func.attr == "suppress" + and len(itemContext.args) == 0 + ): + self.violations.append((node, "M-522")) + + def __checkForM523(self, loopNode): + """ + Private method to check that functions (including lambdas) do not use loop + variables. + + @param loopNode reference to the node to be processed + @type ast.For, ast.AsyncFor, ast.While, ast.ListComp, ast.SetComp,ast.DictComp, + or ast.GeneratorExp + """ + safe_functions = [] + suspiciousVariables = [] + for node in ast.walk(loopNode): + # check if function is immediately consumed to avoid false alarm + if isinstance(node, ast.Call): + # check for filter&reduce + if ( + isinstance(node.func, ast.Name) + and node.func.id in ("filter", "reduce", "map") + ) or ( + isinstance(node.func, ast.Attribute) + and node.func.attr == "reduce" + and isinstance(node.func.value, ast.Name) + and node.func.value.id == "functools" + ): + for arg in node.args: + if isinstance(arg, BugBearVisitor.FUNCTION_NODES): + safe_functions.append(arg) + + # check for key= + for keyword in node.keywords: + if keyword.arg == "key" and isinstance( + keyword.value, BugBearVisitor.FUNCTION_NODES + ): + safe_functions.append(keyword.value) + + # mark `return lambda: x` as safe + # does not (currently) check inner lambdas in a returned expression + # e.g. `return (lambda: x, ) + if isinstance(node, ast.Return) and isinstance( + node.value, BugBearVisitor.FUNCTION_NODES + ): + safe_functions.append(node.value) + + # find unsafe functions + if ( + isinstance(node, BugBearVisitor.FUNCTION_NODES) + and node not in safe_functions + ): + argnames = { + arg.arg for arg in ast.walk(node.args) if isinstance(arg, ast.arg) + } + if isinstance(node, ast.Lambda): + bodyNodes = ast.walk(node.body) + else: + bodyNodes = itertools.chain.from_iterable(map(ast.walk, node.body)) + errors = [] + for name in bodyNodes: + if isinstance(name, ast.Name) and name.id not in argnames: + if isinstance(name.ctx, ast.Load): + errors.append((name.lineno, name.col_offset, name.id, name)) + elif isinstance(name.ctx, ast.Store): + argnames.add(name.id) + for err in errors: + if err[2] not in argnames and err not in self.__M523Seen: + self.__M523Seen.add(err) # dedupe across nested loops + suspiciousVariables.append(err) + + if suspiciousVariables: + reassignedInLoop = set(self.__getAssignedNames(loopNode)) + + for err in sorted(suspiciousVariables): + if reassignedInLoop.issuperset(err[2]): + self.violations.append((err[3], "M-523", err[2])) + + def __checkForM524_M527(self, node): + """ + Private method to check for inheritance from abstract classes in abc and lack of + any methods decorated with abstract*. + + @param node reference to the node to be processed + @type ast.ClassDef + """ # __IGNORE_WARNING_D-234r__ + + def isAbcClass(value, name="ABC"): + if isinstance(value, ast.keyword): + return value.arg == "metaclass" and isAbcClass(value.value, "ABCMeta") + + # class foo(ABC) + # class foo(abc.ABC) + return (isinstance(value, ast.Name) and value.id == name) or ( + isinstance(value, ast.Attribute) + and value.attr == name + and isinstance(value.value, ast.Name) + and value.value.id == "abc" + ) + + def isAbstractDecorator(expr): + return (isinstance(expr, ast.Name) and expr.id[:8] == "abstract") or ( + isinstance(expr, ast.Attribute) and expr.attr[:8] == "abstract" + ) + + def isOverload(expr): + return (isinstance(expr, ast.Name) and expr.id == "overload") or ( + isinstance(expr, ast.Attribute) and expr.attr == "overload" + ) + + def emptyBody(body): + def isStrOrEllipsis(node): + return isinstance(node, ast.Constant) and ( + node.value is Ellipsis or isinstance(node.value, str) + ) + + # Function body consist solely of `pass`, `...`, and/or (doc)string literals + return all( + isinstance(stmt, ast.Pass) + or (isinstance(stmt, ast.Expr) and isStrOrEllipsis(stmt.value)) + for stmt in body + ) + + # don't check multiple inheritance + if len(node.bases) + len(node.keywords) > 1: + return + + # only check abstract classes + if not any(map(isAbcClass, (*node.bases, *node.keywords))): + return + + hasMethod = False + hasAbstractMethod = False + + if not any(map(isAbcClass, (*node.bases, *node.keywords))): + return + + for stmt in node.body: + # Ignore abc's that declares a class attribute that must be set + if isinstance(stmt, ast.AnnAssign) and stmt.value is None: + hasAbstractMethod = True + continue + + # only check function defs + if not isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + hasMethod = True + + hasAbstractDecorator = any(map(isAbstractDecorator, stmt.decorator_list)) + + hasAbstractMethod |= hasAbstractDecorator + + if ( + not hasAbstractDecorator + and emptyBody(stmt.body) + and not any(map(isOverload, stmt.decorator_list)) + ): + self.violations.append((stmt, "M-527", stmt.name)) + + if hasMethod and not hasAbstractMethod: + self.violations.append((node, "M-524", node.name)) + + def __checkForM525(self, node): + """ + Private method to check for exceptions being handled multiple times. + + @param node reference to the node to be processed + @type ast.Try + """ + seen = [] + + for handler in node.handlers: + if isinstance(handler.type, (ast.Name, ast.Attribute)): + name = ".".join(composeCallPath(handler.type)) + seen.append(name) + elif isinstance(handler.type, ast.Tuple): + # to avoid checking the same as M514, remove duplicates per except + uniques = set() + for entry in handler.type.elts: + name = ".".join(composeCallPath(entry)) + uniques.add(name) + seen.extend(uniques) + + # sort to have a deterministic output + duplicates = sorted({x for x in seen if seen.count(x) > 1}) + for duplicate in duplicates: + self.violations.append((node, "M-525", duplicate, self.__inTryStar)) + + def __checkForM526(self, node): + """ + Private method to check for Star-arg unpacking after keyword argument. + + @param node reference to the node to be processed + @type ast.Call + """ + if not node.keywords: + return + + starreds = [arg for arg in node.args if isinstance(arg, ast.Starred)] + if not starreds: + return + + firstKeyword = node.keywords[0].value + for starred in starreds: + if (starred.lineno, starred.col_offset) > ( + firstKeyword.lineno, + firstKeyword.col_offset, + ): + self.violations.append((node, "M-526")) + + def __checkForM528(self, node): + """ + Private method to check for warn without stacklevel. + + @param node reference to the node to be processed + @type ast.Call + """ + if ( + isinstance(node.func, ast.Attribute) + and node.func.attr == "warn" + and isinstance(node.func.value, ast.Name) + and node.func.value.id == "warnings" + and not any(kw.arg == "stacklevel" for kw in node.keywords) + and len(node.args) < 3 + and not any(isinstance(a, ast.Starred) for a in node.args) + and not any(kw.arg is None for kw in node.keywords) + ): + self.violations.append((node, "M-528")) + + def __checkForM531(self, loopNode): + """ + Private method to check that 'itertools.groupby' isn't iterated over more than + once. + + A warning is emitted when the generator returned by 'groupby()' is used + more than once inside a loop body or when it's used in a nested loop. + + @param loopNode reference to the node to be processed + @type ast.For or ast.AsyncFor + """ + # for <loop_node.target> in <loop_node.iter>: ... + if isinstance(loopNode.iter, ast.Call): + node = loopNode.iter + if (isinstance(node.func, ast.Name) and node.func.id in ("groupby",)) or ( + isinstance(node.func, ast.Attribute) + and node.func.attr == "groupby" + and isinstance(node.func.value, ast.Name) + and node.func.value.id == "itertools" + ): + # We have an invocation of groupby which is a simple unpacking + if isinstance(loopNode.target, ast.Tuple) and isinstance( + loopNode.target.elts[1], ast.Name + ): + groupName = loopNode.target.elts[1].id + else: + # Ignore any 'groupby()' invocation that isn't unpacked + return + + numUsages = 0 + for node in self.__walkList(loopNode.body): + # Handled nested loops + if isinstance(node, ast.For): + for nestedNode in self.__walkList(node.body): + if ( + isinstance(nestedNode, ast.Name) + and nestedNode.id == groupName + ): + self.violations.append((nestedNode, "M-531")) + + # Handle multiple uses + if isinstance(node, ast.Name) and node.id == groupName: + numUsages += 1 + if numUsages > 1: + self.violations.append((nestedNode, "M-531")) + + def __checkForM532(self, node): + """ + Private method to check for possible unintentional typing annotation. + + @param node reference to the node to be processed + @type ast.AnnAssign + """ + if ( + node.value is None + and hasattr(node.target, "value") + and isinstance(node.target.value, ast.Name) + and ( + isinstance(node.target, ast.Subscript) + or ( + isinstance(node.target, ast.Attribute) + and node.target.value.id != "self" + ) + ) + ): + self.violations.append((node, "M-532")) + + def __checkForM533(self, node): + """ + Private method to check a set for duplicate items. + + @param node reference to the node to be processed + @type ast.Set + """ + seen = set() + for elt in node.elts: + if not isinstance(elt, ast.Constant): + continue + if elt.value in seen: + self.violations.append((node, "M-533", repr(elt.value))) + else: + seen.add(elt.value) + + def __checkForM534(self, node): + """ + Private method to check that re.sub/subn/split arguments flags/count/maxsplit + are passed as keyword arguments. + + @param node reference to the node to be processed + @type ast.Call + """ + if not isinstance(node.func, ast.Attribute): + return + func = node.func + if not isinstance(func.value, ast.Name) or func.value.id != "re": + return + + def check(numArgs, paramName): + if len(node.args) > numArgs: + arg = node.args[numArgs] + self.violations.append((arg, "M-534", func.attr, paramName)) + + if func.attr in ("sub", "subn"): + check(3, "count") + elif func.attr == "split": + check(2, "maxsplit") + + def __checkForM535(self, node): + """ + Private method to check that a static key isn't used in a dict comprehension. + + Record a warning if a likely unchanging key is used - either a constant, + or a variable that isn't coming from the generator expression. + + @param node reference to the node to be processed + @type ast.DictComp + """ + if isinstance(node.key, ast.Constant): + self.violations.append((node, "M-535", node.key.value)) + elif isinstance( + node.key, ast.Name + ) and node.key.id not in self.__getDictCompLoopAndNamedExprVarNames(node): + self.violations.append((node, "M-535", node.key.id)) + + def __checkForM539(self, node): + """ + Private method to check for correct ContextVar usage. + + @param node reference to the node to be processed + @type ast.Call + """ + if not ( + (isinstance(node.func, ast.Name) and node.func.id == "ContextVar") + or ( + isinstance(node.func, ast.Attribute) + and node.func.attr == "ContextVar" + and isinstance(node.func.value, ast.Name) + and node.func.value.id == "contextvars" + ) + ): + return + + # ContextVar only takes one kw currently, but better safe than sorry + for kw in node.keywords: + if kw.arg == "default": + break + else: + return + + visitor = FunctionDefDefaultsVisitor("M-539", "M-539") + visitor.visit(kw.value) + self.violations.extend(visitor.errors) + + def __checkForM540AddNote(self, node): + """ + Private method to check add_note usage. + + @param node reference to the node to be processed + @type ast.Attribute + @return flag + @rtype bool + """ + if ( + node.attr == "add_note" + and isinstance(node.value, ast.Name) + and self.__M540CaughtException + and node.value.id == self.__M540CaughtException.name + ): + self.__M540CaughtException.hasNote = True + return True + + return False + + def __checkForM540Usage(self, node): + """ + Private method to check the usage of exceptions with added note. + + @param node reference to the node to be processed + @type ast.expr or None + """ # noqa: D-234y + + def superwalk(node: ast.AST | list[ast.AST]): + """ + Function to walk an AST node or a list of AST nodes. + + @param node reference to the node or a list of nodes to be processed + @type ast.AST or list[ast.AST] + @yield next node to be processed + @ytype ast.AST + """ + if isinstance(node, list): + for n in node: + yield from ast.walk(n) + else: + yield from ast.walk(node) + + if not self.__M540CaughtException or node is None: + return + + for n in superwalk(node): + if isinstance(n, ast.Name) and n.id == self.__M540CaughtException.name: + self.__M540CaughtException = None + break + + def __checkForM541(self, node): + """ + Private method to check for duplicate key value pairs in a dictionary literal. + + @param node reference to the node to be processed + @type ast.Dict + """ # noqa: D-234r + + def convertToValue(item): + """ + Function to extract the value of a given item. + + @param item node to extract value from + @type ast.Ast + @return value of the node + @rtype Any + """ + if isinstance(item, ast.Constant): + return item.value + elif isinstance(item, ast.Tuple): + return tuple(convertToValue(i) for i in item.elts) + elif isinstance(item, ast.Name): + return M541VariableKeyType(item.id) + else: + return M541UnhandledKeyType() + + keys = [convertToValue(key) for key in node.keys] + keyCounts = Counter(keys) + duplicateKeys = [key for key, count in keyCounts.items() if count > 1] + for key in duplicateKeys: + keyIndices = [i for i, iKey in enumerate(keys) if iKey == key] + seen = set() + for index in keyIndices: + value = convertToValue(node.values[index]) + if value in seen: + keyNode = node.keys[index] + self.violations.append((keyNode, "M-541")) + seen.add(value) + + def __checkForM569(self, node): + """ + Private method to check for changes to a loop's mutable iterable. + + @param node loop node to be checked + @type ast.For + """ + if isinstance(node.iter, ast.Name): + name = self.toNameStr(node.iter) + elif isinstance(node.iter, ast.Attribute): + name = self.toNameStr(node.iter) + else: + return + checker = M569Checker(name, self) + checker.visit(node.body) + for mutation in checker.mutations: + self.violations.append((mutation, "M-569")) + + +class M569Checker(ast.NodeVisitor): + """ + Class traversing a 'for' loop body to check for modifications to a loop's + mutable iterable. + """ + + # https://docs.python.org/3/library/stdtypes.html#mutable-sequence-types + MUTATING_FUNCTIONS = ( + "append", + "sort", + "reverse", + "remove", + "clear", + "extend", + "insert", + "pop", + "popitem", + ) + + def __init__(self, name, bugbear): + """ + Constructor + + @param name name of the iterator + @type str + @param bugbear reference to the bugbear visitor + @type BugBearVisitor + """ + self.__name = name + self.__bb = bugbear + self.mutations = [] + + def visit_Delete(self, node): + """ + Public method handling 'Delete' nodes. + + @param node reference to the node to be processed + @type ast.Delete + """ + for target in node.targets: + if isinstance(target, ast.Subscript): + name = self.__bb.toNameStr(target.value) + elif isinstance(target, (ast.Attribute, ast.Name)): + name = self.__bb.toNameStr(target) + else: + name = "" # fallback + self.generic_visit(target) + + if name == self.__name: + self.mutations.append(node) + + def visit_Call(self, node): + """ + Public method handling 'Call' nodes. + + @param node reference to the node to be processed + @type ast.Call + """ + if isinstance(node.func, ast.Attribute): + name = self.__bb.toNameStr(node.func.value) + functionObject = name + functionName = node.func.attr + + if ( + functionObject == self.__name + and functionName in self.MUTATING_FUNCTIONS + ): + self.mutations.append(node) + + self.generic_visit(node) + + def visit(self, node): + """ + Public method to inspect an ast node. + + Like super-visit but supports iteration over lists. + + @param node AST node to be traversed + @type TYPE + @return reference to the last processed node + @rtype ast.Node + """ + if not isinstance(node, list): + return super().visit(node) + + for elem in node: + super().visit(elem) + return node + + +class NamedExprFinder(ast.NodeVisitor): + """ + Class to extract names defined through an ast.NamedExpr. + """ + + def __init__(self): + """ + Constructor + """ + super().__init__() + + self.__names = {} + + def visit_NamedExpr(self, node: ast.NamedExpr): + """ + Public method handling 'NamedExpr' nodes. + + @param node reference to the node to be processed + @type ast.NamedExpr + """ + self.__names.setdefault(node.target.id, []).append(node.target) + + self.generic_visit(node) + + def visit(self, node): + """ + Public method to traverse a given AST node. + + Like super-visit but supports iteration over lists. + + @param node AST node to be traversed + @type TYPE + @return reference to the last processed node + @rtype ast.Node + """ + if not isinstance(node, list): + super().visit(node) + + for elem in node: + super().visit(elem) + + return node + + def getNames(self): + """ + Public method to return the extracted names and Name nodes. + + @return dictionary containing the names as keys and the list of nodes + @rtype dict + """ + return self.__names + + +class ExceptBaseExceptionVisitor(ast.NodeVisitor): + """ + Class to determine, if a 'BaseException' is re-raised. + """ + + def __init__(self, exceptNode): + """ + Constructor + + @param exceptNode exception node to be inspected + @type ast.ExceptHandler + """ + super().__init__() + self.__root = exceptNode + self.__reRaised = False + + def reRaised(self) -> bool: + """ + Public method to check, if the exception is re-raised. + + @return flag indicating a re-raised exception + @rtype bool + """ + self.visit(self.__root) + return self.__reRaised + + def visit_Raise(self, node): + """ + Public method to handle 'Raise' nodes. + + If we find a corresponding `raise` or `raise e` where e was from + `except BaseException as e:` then we mark re_raised as True and can + stop scanning. + + @param node reference to the node to be processed + @type ast.Raise + """ + if node.exc is None or ( + isinstance(node.exc, ast.Name) and node.exc.id == self.__root.name + ): + self.__reRaised = True + return + + super().generic_visit(node) + + def visit_ExceptHandler(self, node: ast.ExceptHandler): + """ + Public method to handle 'ExceptHandler' nodes. + + @param node reference to the node to be processed + @type ast.ExceptHandler + """ + if node is not self.__root: + return # entered a nested except - stop searching + + super().generic_visit(node) + + +class FunctionDefDefaultsVisitor(ast.NodeVisitor): + """ + Class used by M506, M508 and M539. + """ + + def __init__( + self, + errorCodeCalls, # M506 or M539 + errorCodeLiterals, # M508 or M539 + ): + """ + Constructor + + @param errorCodeCalls error code for ast.Call nodes + @type str + @param errorCodeLiterals error code for literal nodes + @type str + """ + self.__errorCodeCalls = errorCodeCalls + self.__errorCodeLiterals = errorCodeLiterals + for nodeType in BugbearMutableLiterals + BugbearMutableComprehensions: + setattr( + self, f"visit_{nodeType}", self.__visitMutableLiteralOrComprehension + ) + self.errors = [] + self.__argDepth = 0 + + super().__init__() + + def __visitMutableLiteralOrComprehension(self, node): + """ + Private method to flag mutable literals and comprehensions. + + @param node AST node to be processed + @type ast.Dict, ast.List, ast.Set, ast.ListComp, ast.DictComp or ast.SetComp + """ + # Flag M506 if mutable literal/comprehension is not nested. + # We only flag these at the top level of the expression as we + # cannot easily guarantee that nested mutable structures are not + # made immutable by outer operations, so we prefer no false positives. + # e.g. + # >>> def this_is_fine(a=frozenset({"a", "b", "c"})): ... + # + # >>> def this_is_not_fine_but_hard_to_detect(a=(lambda x: x)([1, 2, 3])) + # + # We do still search for cases of B008 within mutable structures though. + if self.__argDepth == 1: + self.errors.append((node, self.__errorCodeCalls)) + + # Check for nested functions. + self.generic_visit(node) + + def visit_Call(self, node): + """ + Public method to process Call nodes. + + @param node AST node to be processed + @type ast.Call + """ + callPath = ".".join(composeCallPath(node.func)) + if callPath in BugbearMutableCalls: + self.errors.append((node, self.__errorCodeCalls)) + self.generic_visit(node) + return + + if callPath in BugbearImmutableCalls: + self.generic_visit(node) + return + + # Check if function call is actually a float infinity/NaN literal + if callPath == "float" and len(node.args) == 1: + try: + value = float(ast.literal_eval(node.args[0])) + except Exception: # secok + pass + else: + if math.isfinite(value): + self.errors.append((node, self.__errorCodeLiterals)) + else: + self.errors.append((node, self.__errorCodeLiterals)) + + # Check for nested functions. + self.generic_visit(node) + + def visit_Lambda(self, node): + """ + Public method to process Lambda nodes. + + @param node AST node to be processed + @type ast.Lambda + """ + # Don't recurse into lambda expressions + # as they are evaluated at call time. + pass + + def visit(self, node): + """ + Public method to traverse an AST node or a list of AST nodes. + + This is an extended method that can also handle a list of AST nodes. + + @param node AST node or list of AST nodes to be processed + @type ast.AST or list of ast.AST + """ + self.__argDepth += 1 + if isinstance(node, list): + for elem in node: + if elem is not None: + super().visit(elem) + else: + super().visit(node) + self.__argDepth -= 1 + + +class NameFinder(ast.NodeVisitor): + """ + Class to extract a name out of a tree of nodes. + """ + + def __init__(self): + """ + Constructor + """ + super().__init__() + + self.__names = {} + + def visit_Name(self, node): + """ + Public method to handle 'Name' nodes. + + @param node reference to the node to be processed + @type ast.Name + """ + self.__names.setdefault(node.id, []).append(node) + + def visit(self, node): + """ + Public method to traverse a given AST node. + + @param node AST node to be traversed + @type ast.Node + @return reference to the last processed node + @rtype ast.Node + """ + if isinstance(node, list): + for elem in node: + super().visit(elem) + return node + else: + return super().visit(node) + + def getNames(self): + """ + Public method to return the extracted names and Name nodes. + + @return dictionary containing the names as keys and the list of nodes + @rtype dict + """ + return self.__names + + +class M520NameFinder(NameFinder): + """ + Class to extract a name out of a tree of nodes ignoring names defined within the + local scope of a comprehension. + """ + + def visit_GeneratorExp(self, node): + """ + Public method to handle a generator expressions. + + @param node reference to the node to be processed + @type ast.GeneratorExp + """ + self.visit(node.generators) + + def visit_ListComp(self, node): + """ + Public method to handle a list comprehension. + + @param node reference to the node to be processed + @type TYPE + """ + self.visit(node.generators) + + def visit_DictComp(self, node): + """ + Public method to handle a dictionary comprehension. + + @param node reference to the node to be processed + @type TYPE + """ + self.visit(node.generators) + + def visit_comprehension(self, node): + """ + Public method to handle the 'for' of a comprehension. + + @param node reference to the node to be processed + @type ast.comprehension + """ + self.visit(node.iter) + + def visit_Lambda(self, node): + """ + Public method to handle a Lambda function. + + @param node reference to the node to be processed + @type ast.Lambda + """ + self.visit(node.body) + for lambdaArg in node.args.args: + self.getNames().pop(lambdaArg.arg, None)