src/eric7/Plugins/CheckerPlugins/CodeStyleChecker/Miscellaneous/MiscellaneousChecker.py

branch
eric7
changeset 10999
c3cf24fe9113
parent 10997
d470b58626d2
child 11000
f8371a2dd08f
diff -r 6d7bddfde5fe -r c3cf24fe9113 src/eric7/Plugins/CheckerPlugins/CodeStyleChecker/Miscellaneous/MiscellaneousChecker.py
--- a/src/eric7/Plugins/CheckerPlugins/CodeStyleChecker/Miscellaneous/MiscellaneousChecker.py	Tue Oct 22 16:50:36 2024 +0200
+++ b/src/eric7/Plugins/CheckerPlugins/CodeStyleChecker/Miscellaneous/MiscellaneousChecker.py	Tue Oct 22 17:22:53 2024 +0200
@@ -12,11 +12,13 @@
 import contextlib
 import copy
 import itertools
+import math
 import re
 import sys
 import tokenize
 
 from collections import defaultdict, namedtuple
+from dataclasses import dataclass
 from keyword import iskeyword
 from string import Formatter
 
@@ -39,6 +41,36 @@
 from .MiscellaneousDefaults import MiscellaneousCheckerDefaultArgs
 
 
+BugbearMutableLiterals = ("Dict", "List", "Set")
+BugbearMutableComprehensions = ("ListComp", "DictComp", "SetComp")
+BugbearMutableCalls = (
+    "Counter",
+    "OrderedDict",
+    "collections.Counter",
+    "collections.OrderedDict",
+    "collections.defaultdict",
+    "collections.deque",
+    "defaultdict",
+    "deque",
+    "dict",
+    "list",
+    "set",
+)
+BugbearImmutableCalls = (
+    "tuple",
+    "frozenset",
+    "types.MappingProxyType",
+    "MappingProxyType",
+    "re.compile",
+    "operator.attrgetter",
+    "operator.itemgetter",
+    "operator.methodcaller",
+    "attrgetter",
+    "itemgetter",
+    "methodcaller",
+)
+
+
 def composeCallPath(node):
     """
     Generator function to assemble the call path of a given node.
@@ -51,6 +83,8 @@
     if isinstance(node, ast.Attribute):
         yield from composeCallPath(node.value)
         yield node.attr
+    elif isinstance(node, ast.Call):
+        yield from composeCallPath(node.func)
     elif isinstance(node, ast.Name):
         yield node.id
 
@@ -140,7 +174,9 @@
         "M503",
         "M504",
         "M505",
+        "M506",
         "M507",
+        "M508",
         "M509",
         "M510",
         "M511",
@@ -170,6 +206,8 @@
         "M535",
         "M536",
         "M537",
+        "M539",
+        "M540",
         ## Bugbear, opininonated
         "M569",
         ## Bugbear++
@@ -196,9 +234,6 @@
         "M801",
         ## one element tuple
         "M811",
-        ## Mutable Defaults
-        "M821",
-        "M822",
         ## return statements
         "M831",
         "M832",
@@ -352,7 +387,9 @@
                     "M503",
                     "M504",
                     "M505",
+                    "M506",
                     "M507",
+                    "M508",
                     "M509",
                     "M510",
                     "M511",
@@ -382,6 +419,8 @@
                     "M535",
                     "M536",
                     "M537",
+                    "M539",
+                    "M540",
                     "M569",
                     "M581",
                     "M582",
@@ -407,7 +446,6 @@
             (self.__checkGettext, ("M711",)),
             (self.__checkPrintStatements, ("M801",)),
             (self.__checkTuple, ("M811",)),
-            (self.__checkMutableDefault, ("M821", "M822")),
             (self.__checkReturn, ("M831", "M832", "M833", "M834")),
             (self.__checkLineContinuation, ("M841",)),
             (self.__checkImplicitStringConcat, ("M851", "M852")),
@@ -1186,80 +1224,6 @@
                         compType[node.__class__],
                     )
 
-    def __checkMutableDefault(self):
-        """
-        Private method to check for use of mutable types as default arguments.
-        """
-        mutableTypes = (
-            ast.Call,
-            ast.Dict,
-            ast.List,
-            ast.Set,
-            ast.DictComp,
-            ast.ListComp,
-            ast.SetComp,
-        )
-        mutableCalls = (
-            "Counter",
-            "OrderedDict",
-            "collections.Counter",
-            "collections.OrderedDict",
-            "collections.defaultdict",
-            "collections.deque",
-            "defaultdict",
-            "deque",
-            "dict",
-            "list",
-            "set",
-        )
-        immutableCalls = (
-            "tuple",
-            "frozenset",
-            "types.MappingProxyType",
-            "MappingProxyType",
-            "re.compile",
-            "operator.attrgetter",
-            "operator.itemgetter",
-            "operator.methodcaller",
-            "attrgetter",
-            "itemgetter",
-            "methodcaller",
-        )
-        functionDefs = [ast.FunctionDef]
-        with contextlib.suppress(AttributeError):
-            functionDefs.append(ast.AsyncFunctionDef)
-
-        for node in ast.walk(self.__tree):
-            if any(isinstance(node, functionDef) for functionDef in functionDefs):
-                defaults = node.args.defaults[:]
-                with contextlib.suppress(AttributeError):
-                    defaults += node.args.kw_defaults[:]
-                for default in defaults:
-                    if any(
-                        isinstance(default, mutableType) for mutableType in mutableTypes
-                    ):
-                        typeName = type(default).__name__
-                        if isinstance(default, ast.Call):
-                            callPath = ".".join(composeCallPath(default.func))
-                            if callPath in mutableCalls:
-                                self.__error(
-                                    default.lineno - 1,
-                                    default.col_offset,
-                                    "M823",
-                                    callPath + "()",
-                                )
-                            elif callPath not in immutableCalls:
-                                self.__error(
-                                    default.lineno - 1,
-                                    default.col_offset,
-                                    "M822",
-                                    typeName,
-                                )
-                        else:
-                            self.__error(
-                                default.lineno - 1, default.col_offset, "M821", typeName
-                            )
-
     def __dictShouldBeChecked(self, node):
         """
         Private function to test, if the node should be checked.
@@ -1676,13 +1640,21 @@
 #######################################################################
 ## BugBearVisitor
 ##
-## adapted from: flake8-bugbear v24.4.26
+## adapted from: flake8-bugbear v24.8.19
 ##
 ## Original: Copyright (c) 2016 Ɓukasz Langa
 #######################################################################
 
 BugBearContext = namedtuple("BugBearContext", ["node", "stack"])
 
+@dataclass
+class M540CaughtException:
+    """
+    Class to hold the data for a caught exception.
+    """
+    name : str
+    hasNote: bool
+
 
 class BugBearVisitor(ast.NodeVisitor):
     """
@@ -1721,6 +1693,7 @@
 
         self.__M523Seen = set()
         self.__M505Imports = set()
+        self.__M540CaughtException = None
 
     @property
     def nodeStack(self):
@@ -1755,23 +1728,6 @@
             is not None
         )
 
-    def __composeCallPath(self, node):
-        """
-        Private method get the individual elements of the call path of a node.
-
-        @param node reference to the node
-        @type ast.Node
-        @yield one element of the call path
-        @ytype ast.Node
-        """
-        if isinstance(node, ast.Attribute):
-            yield from self.__composeCallPath(node.value)
-            yield node.attr
-        elif isinstance(node, ast.Call):
-            yield from self.__composeCallPath(node.func)
-        elif isinstance(node, ast.Name):
-            yield node.id
-
     def toNameStr(self, node):
         """
         Public method to turn Name and Attribute nodes to strings, handling any
@@ -2077,47 +2033,32 @@
         """
         if node.type is None:
             # bare except is handled by pycodestyle already
-            pass
-
+            self.generic_visit(node)
+            return
+
+        oldM540CaughtException = self.__M540CaughtException
+        if node.name is None:
+            self.__M540CaughtException = None
         else:
-            handlers = self.__flattenExcepthandler(node.type)
-            names = []
-            badHandlers = []
-            ignoredHandlers = []
-            for handler in handlers:
-                if isinstance(handler, (ast.Name, ast.Attribute)):
-                    name = self.toNameStr(handler)
-                    if name is None:
-                        ignoredHandlers.append(handler)
-                    else:
-                        names.append(name)
-                elif isinstance(handler, (ast.Call, ast.Starred)):
-                    ignoredHandlers.append(handler)
-                else:
-                    badHandlers.append(handler)
-            if badHandlers:
-                self.violations.append((node, "M530"))
-            if len(names) == 0 and not badHandlers and not ignoredHandlers:
-                self.violations.append((node, "M529"))
-            elif (
-                len(names) == 1
-                and not badHandlers
-                and not ignoredHandlers
-                and isinstance(node.type, ast.Tuple)
-            ):
-                self.violations.append((node, "M513", *names))
-            else:
-                maybeError = self.__checkRedundantExcepthandlers(names, node)
-                if maybeError is not None:
-                    self.violations.append(maybeError)
-            if (
-                "BaseException" in names
-                and not ExceptBaseExceptionVisitor(node).reRaised()
-            ):
-                self.violations.append((node, "M536"))
+            self.__M540CaughtException = M540CaughtException(node.name, False)
+
+        names = self.__checkForM513_M529_M530(node)
+
+        if (
+            "BaseException" in names
+            and not ExceptBaseExceptionVisitor(node).reRaised()
+        ):
+            self.violations.append((node, "M536"))
 
         self.generic_visit(node)
 
+        if (
+            self.__M540CaughtException is not None
+            and self.__M540CaughtException.hasNote
+        ):
+            self.violations.append((node, "M540"))
+        self.__M540CaughtException = oldM540CaughtException
+
     def visit_UAdd(self, node):
         """
         Public method to handle unary additions.
@@ -2139,8 +2080,11 @@
         @param node reference to the node to be processed
         @type ast.Call
         """
+        isM540AddNote = False
+
         if isinstance(node.func, ast.Attribute):
             self.__checkForM505(node)
+            isM540AddNote = self.__checkForM540AddNote(node.func)
         else:
             with contextlib.suppress(AttributeError, IndexError):
                 # bad super() call
@@ -2180,9 +2124,21 @@
 
         self.__checkForM528(node)
         self.__checkForM534(node)
+        self.__checkForM539(node)
+
+        # no need for copying, if used in nested calls it will be set to None
+        currentM540CaughtException = self.__M540CaughtException
+        if not isM540AddNote:
+            self.__checkForM540Usage(node.args)
+            self.__checkForM540Usage(node.keywords)
 
         self.generic_visit(node)
 
+        if isM540AddNote:
+            # Avoid nested calls within the parameter list using the variable itself.
+            # e.g. `e.add_note(str(e))`
+            self.__M540CaughtException = currentM540CaughtException
+
     def visit_Module(self, node):
         """
         Public method to handle a module node.
@@ -2199,6 +2155,7 @@
         @param node reference to the node to be processed
         @type ast.Assign
         """
+        self.__checkForM540Usage(node.value)
         if len(node.targets) == 1:
             target = node.targets[0]
             if (
@@ -2310,6 +2267,17 @@
 
         self.generic_visit(node)
 
+    def visit_AsyncFunctionDef(self, node):
+        """
+        Public method to handle async function definitions.
+
+        @param node reference to the node to be processed
+        @type ast.AsyncFunctionDef
+        """
+        self.__checkForM506_M508(node)
+
+        self.generic_visit(node)
+
     def visit_FunctionDef(self, node):
         """
         Public method to handle function definitions.
@@ -2317,6 +2285,7 @@
         @param node reference to the node to be processed
         @type ast.FunctionDef
         """
+        self.__checkForM506_M508(node)
         self.__checkForM519(node)
         self.__checkForM521(node)
 
@@ -2330,7 +2299,7 @@
         @type ast.ClassDef
         """
         self.__checkForM521(node)
-        self.__checkForM524AndM527(node)
+        self.__checkForM524_M527(node)
 
         self.generic_visit(node)
 
@@ -2364,6 +2333,11 @@
         @param node reference to the node to be processed
         @type ast.Raise
         """
+        if node.exc is None:
+            self.__M540CaughtException = None
+        else:
+            self.__checkForM540Usage(node.exc)
+            self.__checkForM540Usage(node.cause)
         self.__checkForM516(node)
 
         self.generic_visit(node)
@@ -2401,6 +2375,7 @@
         @type ast.AnnAssign
         """
         self.__checkForM532(node)
+        self.__checkForM540Usage(node.value)
 
         self.generic_visit(node)
 
@@ -2448,7 +2423,7 @@
         elif isinstance(node, ast.ImportFrom):
             for name in node.names:
                 self.__M505Imports.add(f"{node.module}.{name.name or name.asname}")
-        elif isinstance(node, ast.Call):
+        elif isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
             if node.func.attr not in ("lstrip", "rstrip", "strip"):
                 return  # method name doesn't match
 
@@ -2470,6 +2445,17 @@
 
             self.violations.append((node, "M505"))
 
+    def __checkForM506_M508(self, node):
+        """
+        Private method to check the use of mutable literals, comprehensions and calls.
+
+        @param node reference to the node to be processed
+        @type ast.AsyncFunctionDef or ast.FunctionDef
+        """
+        visitor = FunctionDefDefaultsVisitor("M506", "M508")
+        visitor.visit(node.args.defaults + node.args.kw_defaults)
+        self.violations.extend(visitor.errors)
+
     def __checkForM507(self, node):
         """
         Private method to check for unused loop variables.
@@ -2512,6 +2498,46 @@
         for child in node.finalbody:
             _loop(child, (ast.Return, ast.Continue, ast.Break))
 
+    def __checkForM513_M529_M530(self, node):
+        """
+        Private method to check various exception handler situations.
+
+        @param node reference to the node to be processed
+        @type ast.ExceptHandler
+        """
+        handlers = self.__flattenExcepthandler(node.type)
+        names = []
+        badHandlers = []
+        ignoredHandlers = []
+
+        for handler in handlers:
+            if isinstance(handler, (ast.Name, ast.Attribute)):
+                name = self.toNameStr(handler)
+                if name is None:
+                    ignoredHandlers.append(handler)
+                else:
+                    names.append(name)
+            elif isinstance(handler, (ast.Call, ast.Starred)):
+                ignoredHandlers.append(handler)
+            else:
+                badHandlers.append(handler)
+        if badHandlers:
+            self.violations.append((node, "M530"))
+        if len(names) == 0 and not badHandlers and not ignoredHandlers:
+            self.violations.append((node, "M529"))
+        elif (
+            len(names) == 1
+            and not badHandlers
+            and not ignoredHandlers
+            and isinstance(node.type, ast.Tuple)
+        ):
+            self.violations.append((node, "M513", *names))
+        else:
+            maybeError = self.__checkRedundantExcepthandlers(names, node)
+            if maybeError is not None:
+                self.violations.append(maybeError)
+        return names
+
     def __checkForM515(self, node):
         """
         Private method to check for pointless comparisons.
@@ -2625,7 +2651,7 @@
         # Preserve decorator order so we can get the lineno from the decorator node
         # rather than the function node (this location definition changes in Python 3.8)
         resolvedDecorators = (
-            ".".join(self.__composeCallPath(decorator))
+            ".".join(composeCallPath(decorator))
             for decorator in node.decorator_list
         )
         for idx, decorator in enumerate(resolvedDecorators):
@@ -2764,7 +2790,7 @@
             if reassignedInLoop.issuperset(err[2]):
                 self.violations.append((err[3], "M523", err[2]))
 
-    def __checkForM524AndM527(self, node):
+    def __checkForM524_M527(self, node):
         """
         Private method to check for inheritance from abstract classes in abc and lack of
         any methods decorated with abstract*.
@@ -2859,13 +2885,13 @@
 
         for handler in node.handlers:
             if isinstance(handler.type, (ast.Name, ast.Attribute)):
-                name = ".".join(self.__composeCallPath(handler.type))
+                name = ".".join(composeCallPath(handler.type))
                 seen.append(name)
             elif isinstance(handler.type, ast.Tuple):
                 # to avoid checking the same as M514, remove duplicates per except
                 uniques = set()
                 for entry in handler.type.elts:
-                    name = ".".join(self.__composeCallPath(entry))
+                    name = ".".join(composeCallPath(entry))
                     uniques.add(name)
                 seen.extend(uniques)
 
@@ -3006,16 +3032,18 @@
         """
         if not isinstance(node.func, ast.Attribute):
             return
-        if not isinstance(node.func.value, ast.Name) or node.func.value.id != "re":
+        func = node.func
+        if not isinstance(func.value, ast.Name) or func.value.id != "re":
             return
 
         def check(numArgs, paramName):
             if len(node.args) > numArgs:
-                self.violations.append((node, "M534", node.func.attr, paramName))
-
-        if node.func.attr in ("sub", "subn"):
+                arg = node.args[numArgs]
+                self.violations.append((arg, "M534", func.attr, paramName))
+
+        if func.attr in ("sub", "subn"):
             check(3, "count")
-        elif node.func.attr == "split":
+        elif func.attr == "split":
             check(2, "maxsplit")
 
     def __checkForM535(self, node):
@@ -3035,6 +3063,85 @@
         ) and node.key.id not in self.__getDictCompLoopAndNamedExprVarNames(node):
             self.violations.append((node, "M535", node.key.id))
 
+    def __checkForM539(self, node):
+        """
+        Private method to check for correct ContextVar usage.
+
+        @param node reference to the node to be processed
+        @type ast.Call
+        """
+        if not (
+            (isinstance(node.func, ast.Name) and node.func.id == "ContextVar")
+            or (
+                isinstance(node.func, ast.Attribute)
+                and node.func.attr == "ContextVar"
+                and isinstance(node.func.value, ast.Name)
+                and node.func.value.id == "contextvars"
+            )
+        ):
+            return
+
+        # ContextVar only takes one kw currently, but better safe than sorry
+        for kw in node.keywords:
+            if kw.arg == "default":
+                break
+        else:
+            return
+
+        visitor = FunctionDefDefaultsVisitor("M539", "M539")
+        visitor.visit(kw.value)
+        self.violations.extend(visitor.errors)
+
+    def __checkForM540AddNote(self, node):
+        """
+        Private method to check add_note usage.
+
+        @param node reference to the node to be processed
+        @type ast.Attribute
+        @return flag
+        @rtype bool
+        """
+        if (
+            node.attr == "add_note"
+            and isinstance(node.value, ast.Name)
+            and self.__M540CaughtException
+            and node.value.id == self.__M540CaughtException.name
+        ):
+            self.__M540CaughtException.hasNote = True
+            return True
+
+        return False
+
+    def __checkForM540Usage(self, node):
+        """
+        Private method to check the usage of exceptions with added note.
+
+        @param node reference to the node to be processed
+        @type ast.expr or None
+        """
+        def superwalk(node: ast.AST | list[ast.AST]):
+            """
+            Function to walk an AST node or a list of AST nodes.
+
+            @param node reference to the node or a list of nodes to be processed
+            @type ast.AST or list[ast.AST]
+            @yield next node to be processed
+            @ytype ast.AST
+            """
+            if isinstance(node, list):
+                for n in node:
+                    yield from ast.walk(n)
+            else:
+                yield from ast.walk(node)
+
+        if not self.__M540CaughtException or node is None:
+            return
+
+        for n in superwalk(node):
+            if isinstance(n, ast.Name) and n.id == self.__M540CaughtException.name:
+                self.__M540CaughtException = None
+                break
+
     def __checkForM569(self, node):
         """
         Private method to check for changes to a loop's mutable iterable.
@@ -3303,6 +3410,120 @@
         return self.__names
 
 
+class FunctionDefDefaultsVisitor(ast.NodeVisitor):
+    """
+    Class used by M506, M508 and M539.
+    """
+
+    def __init__(
+        self,
+        errorCodeCalls,  # M506 or M539
+        errorCodeLiterals,  # M508 or M539
+    ):
+        """
+        Constructor
+
+        @param errorCodeCalls error code for ast.Call nodes
+        @type str
+        @param errorCodeLiterals error code for literal nodes
+        @type str
+        """
+        self.__errorCodeCalls = errorCodeCalls
+        self.__errorCodeLiterals = errorCodeLiterals
+        for nodeType in BugbearMutableLiterals + BugbearMutableComprehensions:
+            setattr(
+                self, f"visit_{nodeType}", self.__visitMutableLiteralOrComprehension
+            )
+        self.errors = []
+        self.__argDepth = 0
+
+        super().__init__()
+
+    def __visitMutableLiteralOrComprehension(self, node):
+        """
+        Private method to flag mutable literals and comprehensions.
+
+        @param node AST node to be processed
+        @type ast.Dict, ast.List, ast.Set, ast.ListComp, ast.DictComp or ast.SetComp
+        """
+        # Flag M506 if mutable literal/comprehension is not nested.
+        # We only flag these at the top level of the expression as we
+        # cannot easily guarantee that nested mutable structures are not
+        # made immutable by outer operations, so we prefer no false positives.
+        # e.g.
+        # >>> def this_is_fine(a=frozenset({"a", "b", "c"})): ...
+        #
+        # >>> def this_is_not_fine_but_hard_to_detect(a=(lambda x: x)([1, 2, 3]))
+        #
+        # We do still search for cases of B008 within mutable structures though.
+        if self.__argDepth == 1:
+            self.errors.append((node, self.__errorCodeCalls))
+
+        # Check for nested functions.
+        self.generic_visit(node)
+
+    def visit_Call(self, node):
+        """
+        Public method to process Call nodes.
+
+        @param node AST node to be processed
+        @type ast.Call
+        """
+        callPath = ".".join(composeCallPath(node.func))
+        if callPath in BugbearMutableCalls:
+            self.errors.append((node, self.__errorCodeCalls))
+            self.generic_visit(node)
+            return
+
+        if callPath in BugbearImmutableCalls:
+            self.generic_visit(node)
+            return
+
+        # Check if function call is actually a float infinity/NaN literal
+        if callPath == "float" and len(node.args) == 1:
+            try:
+                value = float(ast.literal_eval(node.args[0]))
+            except Exception:
+                pass
+            else:
+                if math.isfinite(value):
+                    self.errors.append((node, self.__errorCodeLiterals))
+        else:
+            self.errors.append((node, self.__errorCodeLiterals))
+
+        # Check for nested functions.
+        self.generic_visit(node)
+
+    def visit_Lambda(self, node):
+        """
+        Public method to process Lambda nodes.
+
+        @param node AST node to be processed
+        @type ast.Lambda
+        """
+        # Don't recurse into lambda expressions
+        # as they are evaluated at call time.
+        pass
+
+    def visit(self, node):
+        """
+        Public method to traverse an AST node or a list of AST nodes.
+
+        This is an extended method that can also handle a list of AST nodes.
+
+        @param node AST node or list of AST nodes to be processed
+        @type ast.AST or list of ast.AST
+        """
+        self.__argDepth += 1
+        if isinstance(node, list):
+            for elem in node:
+                if elem is not None:
+                    super().visit(elem)
+        else:
+            super().visit(node)
+        self.__argDepth -= 1
+
+
 class M520NameFinder(NameFinder):
     """
     Class to extract a name out of a tree of nodes ignoring names defined within the

eric ide

mercurial