Plugins/CheckerPlugins/CodeStyleChecker/MiscellaneousChecker.py

changeset 6183
29384109306c
parent 6182
f293e95b914d
child 6184
789e88d94899
--- a/Plugins/CheckerPlugins/CodeStyleChecker/MiscellaneousChecker.py	Sun Mar 11 16:11:31 2018 +0100
+++ b/Plugins/CheckerPlugins/CodeStyleChecker/MiscellaneousChecker.py	Sun Mar 11 19:38:33 2018 +0100
@@ -14,6 +14,23 @@
 from string import Formatter
 
 
+def composeCallPath(node):
+    """
+    Generator function to assemble the call path of a given node.
+    
+    @param node node to assemble call path for
+    @type ast.Node
+    @return call path components
+    @rtype str
+    """
+    if isinstance(node, ast.Attribute):
+        for v in composeCallPath(node.value):
+            yield v
+        yield node.attr
+    elif isinstance(node, ast.Name):
+        yield node.id
+
+
 class MiscellaneousChecker(object):
     """
     Class implementing a checker for miscellaneous checks.
@@ -28,6 +45,9 @@
         
         "M201",
         
+        "M501", "M502", "M503", "M504", "M505", 
+        "M511", "M512", "M513", 
+        
         "M601",
         "M611", "M612", "M613",
         "M621", "M622", "M623", "M624", "M625",
@@ -110,6 +130,7 @@
             (self.__checkFormatString, ("M611", "M612", "M613",
                                         "M621", "M622", "M623", "M624", "M625",
                                         "M631", "M632")),
+            (self.__checkBugBear, ("M501", "M502", "M503", "M504", "M505", "M511", "M512", "M513",)),
             (self.__checkLogging, ("M651", "M652", "M653", "M654", "M655")),
             (self.__checkFuture, ("M701", "M702")),
             (self.__checkGettext, ("M711",)),
@@ -600,12 +621,25 @@
         """
         Private method to check for use of mutable types as default arguments.
         """
-        mutableTypes = [
+        mutableTypes = (
             ast.Call,
             ast.Dict,
             ast.List,
             ast.Set,
-        ]
+        )
+        mutableCalls = (
+            "Counter",
+            "OrderedDict",
+            "collections.Counter",
+            "collections.OrderedDict",
+            "collections.defaultdict",
+            "collections.deque",
+            "defaultdict",
+            "deque",
+            "dict",
+            "list",
+            "set",
+        )
         
         for node in ast.walk(self.__tree):
             if isinstance(node, ast.FunctionDef):
@@ -614,11 +648,19 @@
                            for mutableType in mutableTypes):
                         typeName = type(default).__name__
                         if isinstance(default, ast.Call):
-                            errorCode = "M822"
+                            callPath = '.'.join(composeCallPath(default.func))
+                            if callPath in mutableCalls:
+                                self.__error(default.lineno - 1,
+                                             default.col_offset,
+                                             "M823", callPath + "()")
+                            else:
+                                self.__error(default.lineno - 1,
+                                             default.col_offset,
+                                             "M822", typeName)
                         else:
-                            errorCode = "M821"
-                        self.__error(default.lineno - 1, default.col_offset,
-                                     errorCode, typeName)
+                            self.__error(default.lineno - 1,
+                                         default.col_offset,
+                                         "M821", typeName)
     
     def __dictShouldBeChecked(self, node):
         """
@@ -667,6 +709,15 @@
                any(name.asname == '_' for name in node.names):
                 self.__error(node.lineno - 1, node.col_offset, "M711",
                              node.names[0].name)
+    
+    def __checkBugBear(self):
+        """
+        Private method to bugbear checks.
+        """
+        visitor = BugBearVisitor()
+        visitor.visit(self.__tree)
+        for node, reason in visitor.violations:
+            self.__error(node.lineno - 1, node.col_offset, reason)
 
 
 class TextVisitor(ast.NodeVisitor):
@@ -979,5 +1030,138 @@
                         
                         super(LoggingVisitor, self).generic_visit(node)
 
+
+class BugBearVisitor(ast.NodeVisitor):
+    """
+    Class implementing a node visitor to check for various topics.
+    """
+    #
+    # This class was implemented along the BugBear flake8 extension (v 18.2.0).
+    # Original: Copyright (c) 2016 Ɓukasz Langa
+    #
+    
+    NodeWindowSize = 4
+    
+    def __init__(self):
+        """
+        Constructor
+        """
+        super(BugBearVisitor, self).__init__()
+        
+        self.__nodeStack = []
+        self.__nodeWindow = []
+        self.violations = []
+    
+    def visit(self, node):
+        """
+        Public method to traverse a given AST node.
+        
+        @param node AST node to be traversed
+        @type ast.Node
+        """
+        self.__nodeStack.append(node)
+        self.__nodeWindow.append(node)
+        self.__nodeWindow = \
+            self.__nodeWindow[-BugBearVisitor.NodeWindowSize:]
+        
+        super(BugBearVisitor, self).visit(node)
+        
+        self.__nodeStack.pop()
+    
+    def visit_UAdd(self, node):
+        """
+        Public method to handle unary additions.
+        
+        @param node reference to the node to be processed
+        @type ast.UAdd
+        """
+        trailingNodes = list(map(type, self.__nodeWindow[-4:]))
+        if trailingNodes == [ast.UnaryOp, ast.UAdd, ast.UnaryOp, ast.UAdd]:
+            originator = self.__nodeWindow[-4]
+            self.violations.append((originator, "M501"))
+        
+        self.generic_visit(node)
+    
+    def visit_Call(self, node):
+        """
+        Public method to handle a function call.
+        
+        @param node reference to the node to be processed
+        @type ast.Call
+        """
+        if sys.version_info >= (3, 0):
+            validPaths = ("six", "future.utils", "builtins")
+            methodsDict = {
+                "M511": ("iterkeys", "itervalues", "iteritems", "iterlists"),
+                "M512": ("viewkeys", "viewvalues", "viewitems", "viewlists"),
+                "M513": ("next",),
+            }
+        else:
+            validPaths = ()
+            methodsDict = {}
+        
+        if isinstance(node.func, ast.Attribute):
+            for code, methods in methodsDict.items():
+                if node.func.attr in methods:
+                    callPath = ".".join(composeCallPath(node.func.value))
+                    if callPath not in validPaths:
+                        self.violations.append((node, code))
+                    break
+            else:
+                self.__checkForM502(node)
+        else:
+            try:
+                if (
+                    node.func.id in ("getattr", "hasattr") and
+                    node.args[1].s == "__call__"
+                ):
+                    self.violations.append((node, "M503"))
+            except (AttributeError, IndexError):
+                pass
+
+            self.generic_visit(node)
+    
+    def visit_Attribute(self, node):
+        """
+        Public method to handle attributes.
+        
+        @param node reference to the node to be processed
+        @type ast.Attribute
+        """
+        callPath = list(composeCallPath(node))
+        
+        if '.'.join(callPath) == 'sys.maxint' and sys.version_info >= (3, 0):
+            self.violations.append((node, "M504"))
+        
+        elif len(callPath) == 2 and callPath[1] == 'message' and \
+                sys.version_info >= (2, 6):
+            name = callPath[0]
+            for elem in reversed(self.__nodeStack[:-1]):
+                if isinstance(elem, ast.ExceptHandler) and elem.name == name:
+                    self.violations.append((node, "M505"))
+                    break
+    
+    def __checkForM502(self, node):
+        """
+        Private method to check the use of *strip().
+        
+        @param node reference to the node to be processed
+        @type ast.Call
+        """
+        if node.func.attr not in ("lstrip", "rstrip", "strip"):
+            return          # method name doesn't match
+        
+        if len(node.args) != 1 or not isinstance(node.args[0], ast.Str):
+            return          # used arguments don't match the builtin strip
+        
+        s = node.args[0].s
+        if len(s) == 1:
+            return          # stripping just one character
+        
+        if len(s) == len(set(s)):
+            return          # no characters appear more than once
+
+        self.violations.append((node, "M502"))
+
 #
 # eflag: noqa = M702

eric ide

mercurial