eric6/Plugins/CheckerPlugins/CodeStyleChecker/Simplify/SimplifyNodeVisitor.py

changeset 8191
9125da0c227e
parent 8189
17df5c8df8c1
child 8192
e1157bd8b4c2
diff -r 17df5c8df8c1 -r 9125da0c227e eric6/Plugins/CheckerPlugins/CodeStyleChecker/Simplify/SimplifyNodeVisitor.py
--- a/eric6/Plugins/CheckerPlugins/CodeStyleChecker/Simplify/SimplifyNodeVisitor.py	Thu Apr 01 19:48:36 2021 +0200
+++ b/eric6/Plugins/CheckerPlugins/CodeStyleChecker/Simplify/SimplifyNodeVisitor.py	Fri Apr 02 15:35:44 2021 +0200
@@ -9,6 +9,7 @@
 
 import ast
 import collections
+import itertools
 
 try:
     from ast import unparse
@@ -99,6 +100,8 @@
         self.__check103(node)
         self.__check106(node)
         self.__check108(node)
+        self.__check114(node)
+        self.__check116(node)
         
         self.generic_visit(node)
     
@@ -111,6 +114,7 @@
         """
         self.__check104(node)
         self.__check110_111(node)
+        self.__check113(node)
         
         self.generic_visit(node)
     
@@ -126,6 +130,17 @@
         
         self.generic_visit(node)
     
+    def visit_Call(self, node):
+        """
+        Public method to process a Call node.
+        
+        @param node reference to the Call node
+        @type ast.Call
+        """
+        self.__check115(node)
+        
+        self.generic_visit(node)
+    
     #############################################################
     ## Helper methods for the various checkers below
     #############################################################
@@ -155,6 +170,91 @@
         
         return [name for name, count in counter.items() if count > 1]
     
+    def __isConstantIncrease(self, expression):
+        """
+        Private method check the given expression for being a constant
+        increase.
+        
+        @param expression reference to the expression node
+        @type ast.AugAssign
+        @return flag indicating a constant increase
+        @rtype bool
+        """
+        return (
+            isinstance(expression.op, ast.Add) and
+            isinstance(expression.value, (ast.Constant, ast.Num))
+        )
+    
+    def __getIfBodyPairs(self, node):
+        """
+        Private method to extract a list of pairs of test and body for an
+        If node.
+        
+        @param node reference to the If node to be processed
+        @type ast.If
+        @return list of pairs of test and body
+        @rtype list of tuples of (ast.expr, [ast.stmt])
+        """
+        pairs = [(node.test, node.body)]
+        orelse = node.orelse
+        while (
+            isinstance(orelse, list) and
+            len(orelse) == 1 and
+            isinstance(orelse[0], ast.If)
+        ):
+            pairs.append((orelse[0].test, orelse[0].body))
+            orelse = orelse[0].orelse
+        return pairs
+    
+    def __isSameBody(self, body1, body2):
+        """
+        Private method check, if the given bodies are equivalent.
+        
+        @param body1 list of statements of the first body
+        @type list of ast.stmt
+        @param body2 list of statements of the second body
+        @type list of ast.stmt
+        @return flag indicating identical bodies
+        @rtype bool
+        """
+        if len(body1) != len(body2):
+            return False
+        for a, b in zip(body1, body2):
+            try:
+                statementEqual = self.__isStatementEqual(a, b)
+            except RecursionError:  # maximum recursion depth
+                statementEqual = False
+            if not statementEqual:
+                return False
+        return True
+    
+    def __isStatementEqual(self, a: ast.stmt, b: ast.stmt) -> bool:
+        """
+        Private method to check, if two statements are equal.
+        
+        @param a reference to the first statement
+        @type ast.stmt
+        @param b reference to the second statement
+        @type ast.stmt
+        @return flag indicating if the two statements are equal
+        @rtype bool
+        """
+        if type(a) is not type(b):
+            return False
+        
+        if isinstance(a, ast.AST):
+            for k, v in vars(a).items():
+                if k in ("lineno", "col_offset", "ctx", "end_lineno",
+                         "parent"):
+                    continue
+                if not self.__isStatementEqual(v, getattr(b, k)):
+                    return False
+            return True
+        elif isinstance(a, list):
+            return all(itertools.starmap(self.__isStatementEqual, zip(a, b)))
+        else:
+            return a == b
+    
     #############################################################
     ## Methods to check for possible code simplifications below
     #############################################################
@@ -429,12 +529,12 @@
         #         return False
         # return True
         if (
-            len(node.body) == 1
-            and isinstance(node.body[0], ast.If)
-            and len(node.body[0].body) == 1
-            and isinstance(node.body[0].body[0], ast.Return)
-            and isinstance(node.body[0].body[0].value, BOOL_CONST_TYPES)
-            and hasattr(node.body[0].body[0].value, "value")
+            len(node.body) == 1 and
+            isinstance(node.body[0], ast.If) and
+            len(node.body[0].body) == 1 and
+            isinstance(node.body[0].body[0], ast.Return) and
+            isinstance(node.body[0].body[0].value, BOOL_CONST_TYPES) and
+            hasattr(node.body[0].body[0].value, "value")
         ):
             check = unparse(node.body[0].test)
             target = unparse(node.target)
@@ -451,7 +551,7 @@
     
     def __check112(self, node):
         """
-        Public method to check for non-capitalized calls to environment
+        Private method to check for non-capitalized calls to environment
         variables.
         
         @param node reference to the AST node to be checked
@@ -460,17 +560,17 @@
         # os.environ["foo"]
         # os.environ.get("bar")
         isIndexCall = (
-            isinstance(node.value, ast.Subscript)
-            and isinstance(node.value.value, ast.Attribute)
-            and isinstance(node.value.value.value, ast.Name)
-            and node.value.value.value.id == "os"
-            and node.value.value.attr == "environ"
-            and (
+            isinstance(node.value, ast.Subscript) and
+            isinstance(node.value.value, ast.Attribute) and
+            isinstance(node.value.value.value, ast.Name) and
+            node.value.value.value.id == "os" and
+            node.value.value.attr == "environ" and
+            (
                 (
-                    isinstance(node.value.slice, ast.Index)
-                    and isinstance(node.value.slice.value, STR_TYPES)
-                )
-                or isinstance(node.value.slice, ast.Constant)
+                    isinstance(node.value.slice, ast.Index) and
+                    isinstance(node.value.slice.value, STR_TYPES)
+                ) or
+                isinstance(node.value.slice, ast.Constant)
             )
         )
         if isIndexCall:
@@ -491,15 +591,15 @@
             hasChange = envName != envName.upper()
 
         isGetCall = (
-            isinstance(node.value, ast.Call)
-            and isinstance(node.value.func, ast.Attribute)
-            and isinstance(node.value.func.value, ast.Attribute)
-            and isinstance(node.value.func.value.value, ast.Name)
-            and node.value.func.value.value.id == "os"
-            and node.value.func.value.attr == "environ"
-            and node.value.func.attr == "get"
-            and len(node.value.args) in [1, 2]
-            and isinstance(node.value.args[0], STR_TYPES)
+            isinstance(node.value, ast.Call) and
+            isinstance(node.value.func, ast.Attribute) and
+            isinstance(node.value.func.value, ast.Attribute) and
+            isinstance(node.value.func.value.value, ast.Name) and
+            node.value.func.value.value.id == "os" and
+            node.value.func.value.attr == "environ" and
+            node.value.func.attr == "get" and
+            len(node.value.args) in [1, 2] and
+            isinstance(node.value.args[0], STR_TYPES)
         )
         if isGetCall:
             call = node.value
@@ -529,6 +629,164 @@
         
         self.__error(node.lineno - 1, node.col_offset, "Y112", expected,
                      original)
+    
+    def __check113(self, node):
+        """
+        Private method to check for loops in which "enumerate" should be
+        used.
+        
+        @param node reference to the AST node to be checked
+        @type ast.For
+        """
+        # idx = 0
+        # for el in iterable:
+        #     ...
+        #     idx += 1
+        variableCandidates = []
+        for expression in node.body:
+            if (
+                isinstance(expression, ast.AugAssign) and
+                self.__isConstantIncrease(expression) and
+                isinstance(expression.target, ast.Name)
+            ):
+                variableCandidates.append(expression.target)
+
+        for candidate in variableCandidates:
+            self.__error(candidate.lineno - 1, candidate.col_offset, "Y113",
+                         unparse(candidate))
+    
+    def __check114(self, node):
+        """
+        Private method to check for alternative if clauses with identical
+        bodies.
+        
+        @param node reference to the AST node to be checked
+        @type ast.If
+        """
+        # if a:
+        #     b
+        # elif c:
+        #     b
+        ifBodyPairs = self.__getIfBodyPairs(node)
+        errorPairs = []
+        for ifbody1, ifbody2 in itertools.combinations(ifBodyPairs, 2):
+            if self.__isSameBody(ifbody1[1], ifbody2[1]):
+                errorPairs.append((ifbody1, ifbody2))
+        for ifbody1, ifbody2 in errorPairs:
+            self.__error(ifbody1[0].lineno - 1, ifbody1[0].col_offset, "Y114",
+                         unparse(ifbody1[0]), unparse(ifbody2[0]))
+    
+    def __check115(self, node):
+        """
+        Private method to to check for places where open() is called without
+        a context handler.
+        
+        @param node reference to the AST node to be checked
+        @type ast.Call
+        """
+        # f = open(...)
+        #. ..  # (do something with f)
+        # f.close()
+        if (
+            isinstance(node.func, ast.Name) and
+            node.func.id == "open" and
+            not isinstance(node.parent, ast.withitem)
+        ):
+            self.__error(node.lineno - 1, node.col_offset, "Y115")
+    
+    def __check116(self, node):
+        """
+        Private method to check for places with 3 or more consecutive
+        if-statements with direct returns.
+        
+        * Each if-statement must be a check for equality with the
+          same variable
+        * Each if-statement must just have a "return"
+        * Else must also just have a return
+        
+        @param node reference to the AST node to be checked
+        @type ast.If
+        """
+        # if a == "foo":
+        #     return "bar"
+        # elif a == "bar":
+        #     return "baz"
+        # elif a == "boo":
+        #     return "ooh"
+        # else:
+        #    return 42
+        if (
+            isinstance(node.test, ast.Compare) and
+            isinstance(node.test.left, ast.Name) and
+            len(node.test.ops) == 1 and
+            isinstance(node.test.ops[0], ast.Eq) and
+            len(node.test.comparators) == 1 and
+            isinstance(node.test.comparators[0], AST_CONST_TYPES) and
+            len(node.body) == 1 and
+            isinstance(node.body[0], ast.Return) and
+            len(node.orelse) == 1 and
+            isinstance(node.orelse[0], ast.If)
+        ):
+            variable = node.test.left
+            child = node.orelse[0]
+            elseValue = None
+            if isinstance(node.test.comparators[0], ast.Str):
+                keyValuePairs = {
+                    node.test.comparators[0].s:
+                        unparse(node.body[0].value).strip("'")
+                }
+            elif isinstance(node.test.comparators[0], ast.Num):
+                keyValuePairs = {
+                    node.test.comparators[0].n:
+                        unparse(node.body[0].value).strip("'")
+                }
+            else:
+                keyValuePairs = {
+                    node.test.comparators[0].value:
+                        unparse(node.body[0].value).strip("'")
+                }
+            while child:
+                if not (
+                    isinstance(child.test, ast.Compare) and
+                    isinstance(child.test.left, ast.Name) and
+                    child.test.left.id == variable.id and
+                    len(child.test.ops) == 1 and
+                    isinstance(child.test.ops[0], ast.Eq) and
+                    len(child.test.comparators) == 1 and
+                    isinstance(child.test.comparators[0], AST_CONST_TYPES) and
+                    len(child.body) == 1 and
+                    isinstance(child.body[0], ast.Return) and
+                    len(child.orelse) <= 1
+                ):
+                    return
+                
+                if isinstance(child.test.comparators[0], ast.Str):
+                    key = child.test.comparators[0].s
+                elif isinstance(child.test.comparators[0], ast.Num):
+                    key = child.test.comparators[0].n
+                else:
+                    key = child.test.comparators[0].value
+                keyValuePairs[key] = unparse(child.body[0].value).strip("'")
+                if len(child.orelse) == 1:
+                    if isinstance(child.orelse[0], ast.If):
+                        child = child.orelse[0]
+                    elif isinstance(child.orelse[0], ast.Return):
+                        elseValue = unparse(child.orelse[0].value)
+                        child = None
+                    else:
+                        return
+                else:
+                    child = None
+            
+            if len(keyValuePairs) < 3:
+                return
+            
+            if elseValue:
+                ret = f"{keyValuePairs}.get({variable.id}, {elseValue})"
+            else:
+                ret = f"{keyValuePairs}.get({variable.id})"
+            
+            self.__error(node.lineno - 1, node.col_offset, "Y116", ret)
 
 #
 # eflag: noqa = M891

eric ide

mercurial