eric6/Plugins/CheckerPlugins/CodeStyleChecker/MiscellaneousChecker.py

changeset 6942
2602857055c5
parent 6889
334257ef9435
child 7021
2894aa889a4e
diff -r f99d60d6b59b -r 2602857055c5 eric6/Plugins/CheckerPlugins/CodeStyleChecker/MiscellaneousChecker.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/eric6/Plugins/CheckerPlugins/CodeStyleChecker/MiscellaneousChecker.py	Sun Apr 14 15:09:21 2019 +0200
@@ -0,0 +1,1606 @@
+# -*- coding: utf-8 -*-
+
+# Copyright (c) 2015 - 2019 Detlev Offenbach <detlev@die-offenbachs.de>
+#
+
+"""
+Module implementing a checker for miscellaneous checks.
+"""
+
+import sys
+import ast
+import re
+import itertools
+from string import Formatter
+from collections import defaultdict
+
+
+def composeCallPath(node):
+    """
+    Generator function to assemble the call path of a given node.
+    
+    @param node node to assemble call path for
+    @type ast.Node
+    @return call path components
+    @rtype str
+    """
+    if isinstance(node, ast.Attribute):
+        for v in composeCallPath(node.value):
+            yield v
+        yield node.attr
+    elif isinstance(node, ast.Name):
+        yield node.id
+
+
+class MiscellaneousChecker(object):
+    """
+    Class implementing a checker for miscellaneous checks.
+    """
+    Codes = [
+        "M101", "M102",
+        "M111", "M112",
+        "M131", "M132",
+        
+        "M191", "M192", "M193", "M194",
+        "M195", "M196", "M197", "M198",
+        
+        "M201",
+        
+        "M501", "M502", "M503", "M504", "M505", "M506", "M507",
+        "M511", "M512", "M513", "M514",
+        
+        "M601",
+        "M611", "M612", "M613",
+        "M621", "M622", "M623", "M624", "M625",
+        "M631", "M632",
+        "M651", "M652", "M653", "M654", "M655",
+        
+        "M701", "M702",
+        "M711",
+        
+        "M801",
+        "M811",
+        "M821", "M822",
+        "M831", "M832", "M833", "M834",
+        
+        "M901",
+    ]
+    
+    Formatter = Formatter()
+    FormatFieldRegex = re.compile(r'^((?:\s|.)*?)(\..*|\[.*\])?$')
+    
+    BuiltinsWhiteList = [
+        "__name__",
+        "__doc__",
+        "credits",
+    ]
+
+    def __init__(self, source, filename, select, ignore, expected, repeat,
+                 args):
+        """
+        Constructor
+        
+        @param source source code to be checked
+        @type list of str
+        @param filename name of the source file
+        @type str
+        @param select list of selected codes
+        @type list of str
+        @param ignore list of codes to be ignored
+        @type list of str
+        @param expected list of expected codes
+        @type list of str
+        @param repeat flag indicating to report each occurrence of a code
+        @type bool
+        @param args dictionary of arguments for the miscellaneous checks
+        @type dict
+        """
+        self.__select = tuple(select)
+        self.__ignore = ('',) if select else tuple(ignore)
+        self.__expected = expected[:]
+        self.__repeat = repeat
+        self.__filename = filename
+        self.__source = source[:]
+        self.__args = args
+        
+        self.__pep3101FormatRegex = re.compile(
+            r'^(?:[^\'"]*[\'"][^\'"]*[\'"])*\s*%|^\s*%')
+        
+        if sys.version_info >= (3, 0):
+            import builtins
+            self.__builtins = [b for b in dir(builtins)
+                               if b not in self.BuiltinsWhiteList]
+        else:
+            import __builtin__
+            self.__builtins = [b for b in dir(__builtin__)
+                               if b not in self.BuiltinsWhiteList]
+
+        # statistics counters
+        self.counters = {}
+        
+        # collection of detected errors
+        self.errors = []
+        
+        checkersWithCodes = [
+            (self.__checkCoding, ("M101", "M102")),
+            (self.__checkCopyright, ("M111", "M112")),
+            (self.__checkBuiltins, ("M131", "M132")),
+            (self.__checkComprehensions, ("M191", "M192", "M193", "M194",
+                                          "M195", "M196", "M197", "M198")),
+            (self.__checkDictWithSortedKeys, ("M201",)),
+            (self.__checkPep3101, ("M601",)),
+            (self.__checkFormatString, ("M611", "M612", "M613",
+                                        "M621", "M622", "M623", "M624", "M625",
+                                        "M631", "M632")),
+            (self.__checkBugBear, ("M501", "M502", "M503", "M504", "M505",
+                                   "M506", "M507",
+                                   "M511", "M512", "M513", "M514")),
+            (self.__checkLogging, ("M651", "M652", "M653", "M654", "M655")),
+            (self.__checkFuture, ("M701", "M702")),
+            (self.__checkGettext, ("M711",)),
+            (self.__checkPrintStatements, ("M801",)),
+            (self.__checkTuple, ("M811", )),
+            (self.__checkMutableDefault, ("M821", "M822")),
+            (self.__checkReturn, ("M831", "M832", "M833", "M834")),
+        ]
+        
+        self.__defaultArgs = {
+            "BuiltinsChecker": {
+                "chr": ["unichr", ],
+                "str": ["unicode", ],
+            },
+            "CodingChecker": 'latin-1, utf-8',
+            "CopyrightChecker": {
+                "Author": "",
+                "MinFilesize": 0,
+            },
+        }
+        
+        self.__checkers = []
+        for checker, codes in checkersWithCodes:
+            if any(not (code and self.__ignoreCode(code))
+                    for code in codes):
+                self.__checkers.append(checker)
+    
+    def __ignoreCode(self, code):
+        """
+        Private method to check if the message code should be ignored.
+
+        @param code message code to check for
+        @type str
+        @return flag indicating to ignore the given code
+        @rtype bool
+        """
+        return (code.startswith(self.__ignore) and
+                not code.startswith(self.__select))
+    
+    def __error(self, lineNumber, offset, code, *args):
+        """
+        Private method to record an issue.
+        
+        @param lineNumber line number of the issue
+        @type int
+        @param offset position within line of the issue
+        @type int
+        @param code message code
+        @type str
+        @param args arguments for the message
+        @type list
+        """
+        if self.__ignoreCode(code):
+            return
+        
+        if code in self.counters:
+            self.counters[code] += 1
+        else:
+            self.counters[code] = 1
+        
+        # Don't care about expected codes
+        if code in self.__expected:
+            return
+        
+        if code and (self.counters[code] == 1 or self.__repeat):
+            # record the issue with one based line number
+            self.errors.append(
+                (self.__filename, lineNumber + 1, offset, (code, args)))
+    
+    def __reportInvalidSyntax(self):
+        """
+        Private method to report a syntax error.
+        """
+        exc_type, exc = sys.exc_info()[:2]
+        if len(exc.args) > 1:
+            offset = exc.args[1]
+            if len(offset) > 2:
+                offset = offset[1:3]
+        else:
+            offset = (1, 0)
+        self.__error(offset[0] - 1, offset[1] or 0,
+                     'M901', exc_type.__name__, exc.args[0])
+    
+    def run(self):
+        """
+        Public method to check the given source against miscellaneous
+        conditions.
+        """
+        if not self.__filename:
+            # don't do anything, if essential data is missing
+            return
+        
+        if not self.__checkers:
+            # don't do anything, if no codes were selected
+            return
+        
+        source = "".join(self.__source)
+        # Check type for py2: if not str it's unicode
+        if sys.version_info[0] == 2:
+            try:
+                source = source.encode('utf-8')
+            except UnicodeError:
+                pass
+        try:
+            self.__tree = compile(source, self.__filename, 'exec',
+                                  ast.PyCF_ONLY_AST)
+        except (SyntaxError, TypeError):
+            self.__reportInvalidSyntax()
+            return
+        
+        for check in self.__checkers:
+            check()
+    
+    def __getCoding(self):
+        """
+        Private method to get the defined coding of the source.
+        
+        @return tuple containing the line number and the coding
+        @rtype tuple of int and str
+        """
+        for lineno, line in enumerate(self.__source[:5]):
+            matched = re.search(r'coding[:=]\s*([-\w_.]+)',
+                                line, re.IGNORECASE)
+            if matched:
+                return lineno, matched.group(1)
+        else:
+            return 0, ""
+    
+    def __checkCoding(self):
+        """
+        Private method to check the presence of a coding line and valid
+        encodings.
+        """
+        if len(self.__source) == 0:
+            return
+        
+        encodings = [e.lower().strip()
+                     for e in self.__args.get(
+                     "CodingChecker", self.__defaultArgs["CodingChecker"])
+                     .split(",")]
+        lineno, coding = self.__getCoding()
+        if coding:
+            if coding.lower() not in encodings:
+                self.__error(lineno, 0, "M102", coding)
+        else:
+            self.__error(0, 0, "M101")
+    
+    def __checkCopyright(self):
+        """
+        Private method to check the presence of a copyright statement.
+        """
+        source = "".join(self.__source)
+        copyrightArgs = self.__args.get(
+            "CopyrightChecker", self.__defaultArgs["CopyrightChecker"])
+        copyrightMinFileSize = copyrightArgs.get(
+            "MinFilesize",
+            self.__defaultArgs["CopyrightChecker"]["MinFilesize"])
+        copyrightAuthor = copyrightArgs.get(
+            "Author",
+            self.__defaultArgs["CopyrightChecker"]["Author"])
+        copyrightRegexStr = \
+            r"Copyright\s+(\(C\)\s+)?(\d{{4}}\s+-\s+)?\d{{4}}\s+{author}"
+        
+        tocheck = max(1024, copyrightMinFileSize)
+        topOfSource = source[:tocheck]
+        if len(topOfSource) < copyrightMinFileSize:
+            return
+
+        copyrightRe = re.compile(copyrightRegexStr.format(author=r".*"),
+                                 re.IGNORECASE)
+        if not copyrightRe.search(topOfSource):
+            self.__error(0, 0, "M111")
+            return
+        
+        if copyrightAuthor:
+            copyrightAuthorRe = re.compile(
+                copyrightRegexStr.format(author=copyrightAuthor),
+                re.IGNORECASE)
+            if not copyrightAuthorRe.search(topOfSource):
+                self.__error(0, 0, "M112")
+    
+    def __checkPrintStatements(self):
+        """
+        Private method to check for print statements.
+        """
+        for node in ast.walk(self.__tree):
+            if (isinstance(node, ast.Call) and
+                getattr(node.func, 'id', None) == 'print') or \
+               (hasattr(ast, 'Print') and isinstance(node, ast.Print)):
+                self.__error(node.lineno - 1, node.col_offset, "M801")
+    
+    def __checkTuple(self):
+        """
+        Private method to check for one element tuples.
+        """
+        for node in ast.walk(self.__tree):
+            if isinstance(node, ast.Tuple) and \
+                    len(node.elts) == 1:
+                self.__error(node.lineno - 1, node.col_offset, "M811")
+    
+    def __checkFuture(self):
+        """
+        Private method to check the __future__ imports.
+        """
+        expectedImports = {
+            i.strip()
+            for i in self.__args.get("FutureChecker", "").split(",")
+            if bool(i.strip())}
+        if len(expectedImports) == 0:
+            # nothing to check for; disabling the check
+            return
+        
+        imports = set()
+        node = None
+        hasCode = False
+        
+        for node in ast.walk(self.__tree):
+            if (isinstance(node, ast.ImportFrom) and
+                    node.module == '__future__'):
+                imports |= {name.name for name in node.names}
+            elif isinstance(node, ast.Expr):
+                if not isinstance(node.value, ast.Str):
+                    hasCode = True
+                    break
+            elif not isinstance(node, (ast.Module, ast.Str)):
+                hasCode = True
+                break
+
+        if isinstance(node, ast.Module) or not hasCode:
+            return
+
+        if not (imports >= expectedImports):
+            if imports:
+                self.__error(node.lineno - 1, node.col_offset, "M701",
+                             ", ".join(expectedImports), ", ".join(imports))
+            else:
+                self.__error(node.lineno - 1, node.col_offset, "M702",
+                             ", ".join(expectedImports))
+    
+    def __checkPep3101(self):
+        """
+        Private method to check for old style string formatting.
+        """
+        for lineno, line in enumerate(self.__source):
+            match = self.__pep3101FormatRegex.search(line)
+            if match:
+                lineLen = len(line)
+                pos = line.find('%')
+                formatPos = pos
+                formatter = '%'
+                if line[pos + 1] == "(":
+                    pos = line.find(")", pos)
+                c = line[pos]
+                while c not in "diouxXeEfFgGcrs":
+                    pos += 1
+                    if pos >= lineLen:
+                        break
+                    c = line[pos]
+                if c in "diouxXeEfFgGcrs":
+                    formatter += c
+                self.__error(lineno, formatPos, "M601", formatter)
+    
+    def __checkFormatString(self):
+        """
+        Private method to check string format strings.
+        """
+        coding = self.__getCoding()[1]
+        if not coding:
+            # default to utf-8
+            coding = "utf-8"
+        
+        visitor = TextVisitor()
+        visitor.visit(self.__tree)
+        for node in visitor.nodes:
+            text = node.s
+            if sys.version_info[0] > 2 and isinstance(text, bytes):
+                try:
+                    text = text.decode(coding)
+                except UnicodeDecodeError:
+                    continue
+            fields, implicit, explicit = self.__getFields(text)
+            if implicit:
+                if node in visitor.calls:
+                    self.__error(node.lineno - 1, node.col_offset, "M611")
+                else:
+                    if node.is_docstring:
+                        self.__error(node.lineno - 1, node.col_offset, "M612")
+                    else:
+                        self.__error(node.lineno - 1, node.col_offset, "M613")
+            
+            if node in visitor.calls:
+                call, strArgs = visitor.calls[node]
+                
+                numbers = set()
+                names = set()
+                # Determine which fields require a keyword and which an arg
+                for name in fields:
+                    fieldMatch = self.FormatFieldRegex.match(name)
+                    try:
+                        number = int(fieldMatch.group(1))
+                    except ValueError:
+                        number = -1
+                    # negative numbers are considered keywords
+                    if number >= 0:
+                        numbers.add(number)
+                    else:
+                        names.add(fieldMatch.group(1))
+                
+                keywords = {keyword.arg for keyword in call.keywords}
+                numArgs = len(call.args)
+                if strArgs:
+                    numArgs -= 1
+                if sys.version_info < (3, 5):
+                    hasKwArgs = bool(call.kwargs)
+                    hasStarArgs = bool(call.starargs)
+                else:
+                    hasKwArgs = any(kw.arg is None for kw in call.keywords)
+                    hasStarArgs = sum(1 for arg in call.args
+                                      if isinstance(arg, ast.Starred))
+                    
+                    if hasKwArgs:
+                        keywords.discard(None)
+                    if hasStarArgs:
+                        numArgs -= 1
+                
+                # if starargs or kwargs is not None, it can't count the
+                # parameters but at least check if the args are used
+                if hasKwArgs:
+                    if not names:
+                        # No names but kwargs
+                        self.__error(call.lineno - 1, call.col_offset, "M623")
+                if hasStarArgs:
+                    if not numbers:
+                        # No numbers but args
+                        self.__error(call.lineno - 1, call.col_offset, "M624")
+                
+                if not hasKwArgs and not hasStarArgs:
+                    # can actually verify numbers and names
+                    for number in sorted(numbers):
+                        if number >= numArgs:
+                            self.__error(call.lineno - 1, call.col_offset,
+                                         "M621", number)
+                    
+                    for name in sorted(names):
+                        if name not in keywords:
+                            self.__error(call.lineno - 1, call.col_offset,
+                                         "M622", name)
+                
+                for arg in range(numArgs):
+                    if arg not in numbers:
+                        self.__error(call.lineno - 1, call.col_offset, "M631",
+                                     arg)
+                
+                for keyword in keywords:
+                    if keyword not in names:
+                        self.__error(call.lineno - 1, call.col_offset, "M632",
+                                     keyword)
+                
+                if implicit and explicit:
+                    self.__error(call.lineno - 1, call.col_offset, "M625")
+    
+    def __getFields(self, string):
+        """
+        Private method to extract the format field information.
+        
+        @param string format string to be parsed
+        @type str
+        @return format field information as a tuple with fields, implicit
+            field definitions present and explicit field definitions present
+        @rtype tuple of set of str, bool, bool
+        """
+        fields = set()
+        cnt = itertools.count()
+        implicit = False
+        explicit = False
+        try:
+            for _literal, field, spec, conv in self.Formatter.parse(string):
+                if field is not None and (conv is None or conv in 'rsa'):
+                    if not field:
+                        field = str(next(cnt))
+                        implicit = True
+                    else:
+                        explicit = True
+                    fields.add(field)
+                    fields.update(parsedSpec[1]
+                                  for parsedSpec in self.Formatter.parse(spec)
+                                  if parsedSpec[1] is not None)
+        except ValueError:
+            return set(), False, False
+        else:
+            return fields, implicit, explicit
+    
+    def __checkBuiltins(self):
+        """
+        Private method to check, if built-ins are shadowed.
+        """
+        functionDefs = [ast.FunctionDef]
+        try:
+            functionDefs.append(ast.AsyncFunctionDef)
+        except AttributeError:
+            pass
+        
+        ignoreBuiltinAssignments = self.__args.get(
+            "BuiltinsChecker", self.__defaultArgs["BuiltinsChecker"])
+        
+        for node in ast.walk(self.__tree):
+            if isinstance(node, ast.Assign):
+                # assign statement
+                for element in node.targets:
+                    if isinstance(element, ast.Name) and \
+                       element.id in self.__builtins:
+                        value = node.value
+                        if isinstance(value, ast.Name) and \
+                           element.id in ignoreBuiltinAssignments and \
+                           value.id in ignoreBuiltinAssignments[element.id]:
+                            # ignore compatibility assignments
+                            continue
+                        self.__error(element.lineno - 1, element.col_offset,
+                                     "M131", element.id)
+                    elif isinstance(element, (ast.Tuple, ast.List)):
+                        for tupleElement in element.elts:
+                            if isinstance(tupleElement, ast.Name) and \
+                               tupleElement.id in self.__builtins:
+                                self.__error(tupleElement.lineno - 1,
+                                             tupleElement.col_offset,
+                                             "M131", tupleElement.id)
+            elif isinstance(node, ast.For):
+                # for loop
+                target = node.target
+                if isinstance(target, ast.Name) and \
+                   target.id in self.__builtins:
+                    self.__error(target.lineno - 1, target.col_offset,
+                                 "M131", target.id)
+                elif isinstance(target, (ast.Tuple, ast.List)):
+                    for element in target.elts:
+                        if isinstance(element, ast.Name) and \
+                           element.id in self.__builtins:
+                            self.__error(element.lineno - 1,
+                                         element.col_offset,
+                                         "M131", element.id)
+            elif any(isinstance(node, functionDef)
+                     for functionDef in functionDefs):
+                # (asynchronous) function definition
+                if sys.version_info >= (3, 0):
+                    for arg in node.args.args:
+                        if isinstance(arg, ast.arg) and \
+                           arg.arg in self.__builtins:
+                            self.__error(arg.lineno - 1, arg.col_offset,
+                                         "M132", arg.arg)
+                else:
+                    for arg in node.args.args:
+                        if isinstance(arg, ast.Name) and \
+                           arg.id in self.__builtins:
+                            self.__error(arg.lineno - 1, arg.col_offset,
+                                         "M132", arg.id)
+    
+    def __checkComprehensions(self):
+        """
+        Private method to check some comprehension related things.
+        """
+        for node in ast.walk(self.__tree):
+            if (isinstance(node, ast.Call) and
+               len(node.args) == 1 and
+               isinstance(node.func, ast.Name)):
+                if (isinstance(node.args[0], ast.GeneratorExp) and
+                        node.func.id in ('list', 'set', 'dict')):
+                    errorCode = {
+                        "dict": "M193",
+                        "list": "M191",
+                        "set": "M192",
+                    }[node.func.id]
+                    self.__error(node.lineno - 1, node.col_offset, errorCode)
+
+                elif (isinstance(node.args[0], ast.ListComp) and
+                      node.func.id in ('set', 'dict')):
+                    errorCode = {
+                        'dict': 'M195',
+                        'set': 'M194',
+                    }[node.func.id]
+                    self.__error(node.lineno - 1, node.col_offset, errorCode)
+
+                elif (isinstance(node.args[0], ast.List) and
+                      node.func.id in ('set', 'dict')):
+                    errorCode = {
+                        'dict': 'M197',
+                        'set': 'M196',
+                    }[node.func.id]
+                    self.__error(node.lineno - 1, node.col_offset, errorCode)
+
+                elif (isinstance(node.args[0], ast.ListComp) and
+                      node.func.id in ('all', 'any', 'frozenset', 'max', 'min',
+                                       'sorted', 'sum', 'tuple',)):
+                    self.__error(node.lineno - 1, node.col_offset, "M198",
+                                 node.func.id)
+    
+    def __checkMutableDefault(self):
+        """
+        Private method to check for use of mutable types as default arguments.
+        """
+        mutableTypes = (
+            ast.Call,
+            ast.Dict,
+            ast.List,
+            ast.Set,
+        )
+        mutableCalls = (
+            "Counter",
+            "OrderedDict",
+            "collections.Counter",
+            "collections.OrderedDict",
+            "collections.defaultdict",
+            "collections.deque",
+            "defaultdict",
+            "deque",
+            "dict",
+            "list",
+            "set",
+        )
+        functionDefs = [ast.FunctionDef]
+        try:
+            functionDefs.append(ast.AsyncFunctionDef)
+        except AttributeError:
+            pass
+        
+        for node in ast.walk(self.__tree):
+            if any(isinstance(node, functionDef)
+                   for functionDef in functionDefs):
+                for default in node.args.defaults:
+                    if any(isinstance(default, mutableType)
+                           for mutableType in mutableTypes):
+                        typeName = type(default).__name__
+                        if isinstance(default, ast.Call):
+                            callPath = '.'.join(composeCallPath(default.func))
+                            if callPath in mutableCalls:
+                                self.__error(default.lineno - 1,
+                                             default.col_offset,
+                                             "M823", callPath + "()")
+                            else:
+                                self.__error(default.lineno - 1,
+                                             default.col_offset,
+                                             "M822", typeName)
+                        else:
+                            self.__error(default.lineno - 1,
+                                         default.col_offset,
+                                         "M821", typeName)
+    
+    def __dictShouldBeChecked(self, node):
+        """
+        Private function to test, if the node should be checked.
+        
+        @param node reference to the AST node
+        @return flag indicating to check the node
+        @rtype bool
+        """
+        if not all(isinstance(key, ast.Str) for key in node.keys):
+            return False
+        
+        if "__IGNORE_WARNING__" in self.__source[node.lineno - 1] or \
+           "__IGNORE_WARNING_M201__" in self.__source[node.lineno - 1]:
+            return False
+        
+        lineNumbers = [key.lineno for key in node.keys]
+        return len(lineNumbers) == len(set(lineNumbers))
+    
+    def __checkDictWithSortedKeys(self):
+        """
+        Private method to check, if dictionary keys appear in sorted order.
+        """
+        for node in ast.walk(self.__tree):
+            if isinstance(node, ast.Dict) and self.__dictShouldBeChecked(node):
+                for key1, key2 in zip(node.keys, node.keys[1:]):
+                    if key2.s < key1.s:
+                        self.__error(key2.lineno - 1, key2.col_offset,
+                                     "M201", key2.s, key1.s)
+    
+    def __checkLogging(self):
+        """
+        Private method to check logging statements.
+        """
+        visitor = LoggingVisitor()
+        visitor.visit(self.__tree)
+        for node, reason in visitor.violations:
+            self.__error(node.lineno - 1, node.col_offset, reason)
+    
+    def __checkGettext(self):
+        """
+        Private method to check the 'gettext' import statement.
+        """
+        for node in ast.walk(self.__tree):
+            if isinstance(node, ast.ImportFrom) and \
+               any(name.asname == '_' for name in node.names):
+                self.__error(node.lineno - 1, node.col_offset, "M711",
+                             node.names[0].name)
+    
+    def __checkBugBear(self):
+        """
+        Private method to bugbear checks.
+        """
+        visitor = BugBearVisitor()
+        visitor.visit(self.__tree)
+        for violation in visitor.violations:
+            node = violation[0]
+            reason = violation[1]
+            params = violation[2:]
+            self.__error(node.lineno - 1, node.col_offset, reason, *params)
+    
+    def __checkReturn(self):
+        """
+        Private method to check return statements.
+        """
+        visitor = ReturnVisitor()
+        visitor.visit(self.__tree)
+        for violation in visitor.violations:
+            node = violation[0]
+            reason = violation[1]
+            self.__error(node.lineno - 1, node.col_offset, reason)
+
+
+class TextVisitor(ast.NodeVisitor):
+    """
+    Class implementing a node visitor for bytes and str instances.
+
+    It tries to detect docstrings as string of the first expression of each
+    module, class or function.
+    """
+    # modelled after the string format flake8 extension
+    
+    def __init__(self):
+        """
+        Constructor
+        """
+        super(TextVisitor, self).__init__()
+        self.nodes = []
+        self.calls = {}
+
+    def __addNode(self, node):
+        """
+        Private method to add a node to our list of nodes.
+        
+        @param node reference to the node to add
+        @type ast.AST
+        """
+        if not hasattr(node, 'is_docstring'):
+            node.is_docstring = False
+        self.nodes.append(node)
+
+    def __isBaseString(self, node):
+        """
+        Private method to determine, if a node is a base string node.
+        
+        @param node reference to the node to check
+        @type ast.AST
+        @return flag indicating a base string
+        @rtype bool
+        """
+        typ = (ast.Str,)
+        if sys.version_info[0] > 2:
+            typ += (ast.Bytes,)
+        return isinstance(node, typ)
+
+    def visit_Str(self, node):
+        """
+        Public method to record a string node.
+        
+        @param node reference to the string node
+        @type ast.Str
+        """
+        self.__addNode(node)
+
+    def visit_Bytes(self, node):
+        """
+        Public method to record a bytes node.
+        
+        @param node reference to the bytes node
+        @type ast.Bytes
+        """
+        self.__addNode(node)
+
+    def __visitDefinition(self, node):
+        """
+        Private method handling class and function definitions.
+        
+        @param node reference to the node to handle
+        @type ast.FunctionDef, ast.AsyncFunctionDef or ast.ClassDef
+        """
+        # Manually traverse class or function definition
+        # * Handle decorators normally
+        # * Use special check for body content
+        # * Don't handle the rest (e.g. bases)
+        for decorator in node.decorator_list:
+            self.visit(decorator)
+        self.__visitBody(node)
+
+    def __visitBody(self, node):
+        """
+        Private method to traverse the body of the node manually.
+
+        If the first node is an expression which contains a string or bytes it
+        marks that as a docstring.
+        
+        @param node reference to the node to traverse
+        @type ast.AST
+        """
+        if (node.body and isinstance(node.body[0], ast.Expr) and
+                self.__isBaseString(node.body[0].value)):
+            node.body[0].value.is_docstring = True
+
+        for subnode in node.body:
+            self.visit(subnode)
+
+    def visit_Module(self, node):
+        """
+        Public method to handle a module.
+        
+        @param node reference to the node to handle
+        @type ast.Module
+        """
+        self.__visitBody(node)
+
+    def visit_ClassDef(self, node):
+        """
+        Public method to handle a class definition.
+        
+        @param node reference to the node to handle
+        @type ast.ClassDef
+        """
+        # Skipped nodes: ('name', 'bases', 'keywords', 'starargs', 'kwargs')
+        self.__visitDefinition(node)
+
+    def visit_FunctionDef(self, node):
+        """
+        Public method to handle a function definition.
+        
+        @param node reference to the node to handle
+        @type ast.FunctionDef
+        """
+        # Skipped nodes: ('name', 'args', 'returns')
+        self.__visitDefinition(node)
+
+    def visit_AsyncFunctionDef(self, node):
+        """
+        Public method to handle an asynchronous function definition.
+        
+        @param node reference to the node to handle
+        @type ast.AsyncFunctionDef
+        """
+        # Skipped nodes: ('name', 'args', 'returns')
+        self.__visitDefinition(node)
+
+    def visit_Call(self, node):
+        """
+        Public method to handle a function call.
+        
+        @param node reference to the node to handle
+        @type ast.Call
+        """
+        if (isinstance(node.func, ast.Attribute) and
+                node.func.attr == 'format'):
+            if self.__isBaseString(node.func.value):
+                self.calls[node.func.value] = (node, False)
+            elif (isinstance(node.func.value, ast.Name) and
+                    node.func.value.id == 'str' and node.args and
+                    self.__isBaseString(node.args[0])):
+                self.calls[node.args[0]] = (node, True)
+        super(TextVisitor, self).generic_visit(node)
+
+
+class LoggingVisitor(ast.NodeVisitor):
+    """
+    Class implementing a node visitor to check logging statements.
+    """
+    LoggingLevels = {
+        "debug",
+        "critical",
+        "error",
+        "info",
+        "warn",
+        "warning",
+    }
+    
+    def __init__(self):
+        """
+        Constructor
+        """
+        super(LoggingVisitor, self).__init__()
+        
+        self.__currentLoggingCall = None
+        self.__currentLoggingArgument = None
+        self.__currentLoggingLevel = None
+        self.__currentExtraKeyword = None
+        self.violations = []
+
+    def __withinLoggingStatement(self):
+        """
+        Private method to check, if we are inside a logging statement.
+        
+        @return flag indicating we are inside a logging statement
+        @rtype bool
+        """
+        return self.__currentLoggingCall is not None
+
+    def __withinLoggingArgument(self):
+        """
+        Private method to check, if we are inside a logging argument.
+        
+        @return flag indicating we are inside a logging argument
+        @rtype bool
+        """
+        return self.__currentLoggingArgument is not None
+
+    def __withinExtraKeyword(self, node):
+        """
+        Private method to check, if we are inside the extra keyword.
+        
+        @param node reference to the node to be checked
+        @type ast.keyword
+        @return flag indicating we are inside the extra keyword
+        @rtype bool
+        """
+        return self.__currentExtraKeyword is not None and \
+            self.__currentExtraKeyword != node
+    
+    def __detectLoggingLevel(self, node):
+        """
+        Private method to decide whether an AST Call is a logging call.
+        
+        @param node reference to the node to be processed
+        @type ast.Call
+        @return logging level
+        @rtype str or None
+        """
+        try:
+            if node.func.value.id == "warnings":
+                return None
+            
+            if node.func.attr in LoggingVisitor.LoggingLevels:
+                return node.func.attr
+        except AttributeError:
+            pass
+        
+        return None
+
+    def __isFormatCall(self, node):
+        """
+        Private method to check if a function call uses format.
+
+        @param node reference to the node to be processed
+        @type ast.Call
+        @return flag indicating the function call uses format
+        @rtype bool
+        """
+        try:
+            return node.func.attr == "format"
+        except AttributeError:
+            return False
+    
+    def visit_Call(self, node):
+        """
+        Public method to handle a function call.
+
+        Every logging statement and string format is expected to be a function
+        call.
+        
+        @param node reference to the node to be processed
+        @type ast.Call
+        """
+        # we are in a logging statement
+        if self.__withinLoggingStatement():
+            if self.__withinLoggingArgument() and self.__isFormatCall(node):
+                self.violations.append((node, "M651"))
+                super(LoggingVisitor, self).generic_visit(node)
+                return
+        
+        loggingLevel = self.__detectLoggingLevel(node)
+        
+        if loggingLevel and self.__currentLoggingLevel is None:
+            self.__currentLoggingLevel = loggingLevel
+        
+        # we are in some other statement
+        if loggingLevel is None:
+            super(LoggingVisitor, self).generic_visit(node)
+            return
+        
+        # we are entering a new logging statement
+        self.__currentLoggingCall = node
+        
+        if loggingLevel == "warn":
+            self.violations.append((node, "M655"))
+        
+        for index, child in enumerate(ast.iter_child_nodes(node)):
+            if index == 1:
+                self.__currentLoggingArgument = child
+            if index > 1 and isinstance(child, ast.keyword) and \
+               child.arg == "extra":
+                self.__currentExtraKeyword = child
+            
+            super(LoggingVisitor, self).visit(child)
+            
+            self.__currentLoggingArgument = None
+            self.__currentExtraKeyword = None
+        
+        self.__currentLoggingCall = None
+        self.__currentLoggingLevel = None
+    
+    def visit_BinOp(self, node):
+        """
+        Public method to handle binary operations while processing the first
+        logging argument.
+        
+        @param node reference to the node to be processed
+        @type ast.BinOp
+        """
+        if self.__withinLoggingStatement() and self.__withinLoggingArgument():
+            # handle percent format
+            if isinstance(node.op, ast.Mod):
+                self.violations.append((node, "M652"))
+            
+            # handle string concat
+            if isinstance(node.op, ast.Add):
+                self.violations.append((node, "M653"))
+        
+        super(LoggingVisitor, self).generic_visit(node)
+    
+    def visit_JoinedStr(self, node):
+        """
+        Public method to handle f-string arguments.
+        
+        @param node reference to the node to be processed
+        @type ast.JoinedStr
+        """
+        if sys.version_info >= (3, 6):
+            if self.__withinLoggingStatement():
+                if any(isinstance(i, ast.FormattedValue) for i in node.values):
+                    if self.__withinLoggingArgument():
+                        self.violations.append((node, "M654"))
+                        
+                        super(LoggingVisitor, self).generic_visit(node)
+
+
+class BugBearVisitor(ast.NodeVisitor):
+    """
+    Class implementing a node visitor to check for various topics.
+    """
+    #
+    # This class was implemented along the BugBear flake8 extension (v 18.2.0).
+    # Original: Copyright (c) 2016 Ɓukasz Langa
+    #
+    
+    NodeWindowSize = 4
+    
+    def __init__(self):
+        """
+        Constructor
+        """
+        super(BugBearVisitor, self).__init__()
+        
+        self.__nodeStack = []
+        self.__nodeWindow = []
+        self.violations = []
+    
+    def visit(self, node):
+        """
+        Public method to traverse a given AST node.
+        
+        @param node AST node to be traversed
+        @type ast.Node
+        """
+        self.__nodeStack.append(node)
+        self.__nodeWindow.append(node)
+        self.__nodeWindow = \
+            self.__nodeWindow[-BugBearVisitor.NodeWindowSize:]
+        
+        super(BugBearVisitor, self).visit(node)
+        
+        self.__nodeStack.pop()
+    
+    def visit_UAdd(self, node):
+        """
+        Public method to handle unary additions.
+        
+        @param node reference to the node to be processed
+        @type ast.UAdd
+        """
+        trailingNodes = list(map(type, self.__nodeWindow[-4:]))
+        if trailingNodes == [ast.UnaryOp, ast.UAdd, ast.UnaryOp, ast.UAdd]:
+            originator = self.__nodeWindow[-4]
+            self.violations.append((originator, "M501"))
+        
+        self.generic_visit(node)
+    
+    def visit_Call(self, node):
+        """
+        Public method to handle a function call.
+        
+        @param node reference to the node to be processed
+        @type ast.Call
+        """
+        if sys.version_info >= (3, 0):
+            validPaths = ("six", "future.utils", "builtins")
+            methodsDict = {
+                "M511": ("iterkeys", "itervalues", "iteritems", "iterlists"),
+                "M512": ("viewkeys", "viewvalues", "viewitems", "viewlists"),
+                "M513": ("next",),
+            }
+        else:
+            validPaths = ()
+            methodsDict = {}
+        
+        if isinstance(node.func, ast.Attribute):
+            for code, methods in methodsDict.items():
+                if node.func.attr in methods:
+                    callPath = ".".join(composeCallPath(node.func.value))
+                    if callPath not in validPaths:
+                        self.violations.append((node, code))
+                    break
+            else:
+                self.__checkForM502(node)
+        else:
+            try:
+                if (
+                    node.func.id in ("getattr", "hasattr") and
+                    node.args[1].s == "__call__"
+                ):
+                    self.violations.append((node, "M503"))
+            except (AttributeError, IndexError):
+                pass
+
+            self.generic_visit(node)
+    
+    def visit_Attribute(self, node):
+        """
+        Public method to handle attributes.
+        
+        @param node reference to the node to be processed
+        @type ast.Attribute
+        """
+        callPath = list(composeCallPath(node))
+        
+        if '.'.join(callPath) == 'sys.maxint' and sys.version_info >= (3, 0):
+            self.violations.append((node, "M504"))
+        
+        elif len(callPath) == 2 and callPath[1] == 'message' and \
+                sys.version_info >= (2, 6):
+            name = callPath[0]
+            for elem in reversed(self.__nodeStack[:-1]):
+                if isinstance(elem, ast.ExceptHandler) and elem.name == name:
+                    self.violations.append((node, "M505"))
+                    break
+    
+    def visit_Assign(self, node):
+        """
+        Public method to handle assignments.
+        
+        @param node reference to the node to be processed
+        @type ast.Assign
+        """
+        if isinstance(self.__nodeStack[-2], ast.ClassDef):
+            # By using 'hasattr' below we're ignoring starred arguments, slices
+            # and tuples for simplicity.
+            assignTargets = {t.id for t in node.targets if hasattr(t, 'id')}
+            if '__metaclass__' in assignTargets and sys.version_info >= (3, 0):
+                self.violations.append((node, "M514"))
+        
+        elif len(node.targets) == 1:
+            target = node.targets[0]
+            if isinstance(target, ast.Attribute) and \
+               isinstance(target.value, ast.Name):
+                if (target.value.id, target.attr) == ('os', 'environ'):
+                    self.violations.append((node, "M506"))
+        
+        self.generic_visit(node)
+    
+    def visit_For(self, node):
+        """
+        Public method to handle 'for' statements.
+        
+        @param node reference to the node to be processed
+        @type ast.For
+        """
+        self.__checkForM507(node)
+        
+        self.generic_visit(node)
+    
+    def __checkForM502(self, node):
+        """
+        Private method to check the use of *strip().
+        
+        @param node reference to the node to be processed
+        @type ast.Call
+        """
+        if node.func.attr not in ("lstrip", "rstrip", "strip"):
+            return          # method name doesn't match
+        
+        if len(node.args) != 1 or not isinstance(node.args[0], ast.Str):
+            return          # used arguments don't match the builtin strip
+        
+        s = node.args[0].s
+        if len(s) == 1:
+            return          # stripping just one character
+        
+        if len(s) == len(set(s)):
+            return          # no characters appear more than once
+
+        self.violations.append((node, "M502"))
+    
+    def __checkForM507(self, node):
+        """
+        Private method to check for unused loop variables.
+        
+        @param node reference to the node to be processed
+        @type ast.For
+        """
+        targets = NameFinder()
+        targets.visit(node.target)
+        ctrlNames = set(filter(lambda s: not s.startswith('_'),
+                               targets.getNames()))
+        body = NameFinder()
+        for expr in node.body:
+            body.visit(expr)
+        usedNames = set(body.getNames())
+        for name in sorted(ctrlNames - usedNames):
+            n = targets.getNames()[name][0]
+            self.violations.append((n, "M507", name))
+
+
+class NameFinder(ast.NodeVisitor):
+    """
+    Class to extract a name out of a tree of nodes.
+    """
+    def __init__(self):
+        """
+        Constructor
+        """
+        super(NameFinder, self).__init__()
+        
+        self.__names = {}
+
+    def visit_Name(self, node):
+        """
+        Public method to handle 'Name' nodes.
+        
+        @param node reference to the node to be processed
+        @type ast.Name
+        """
+        self.__names.setdefault(node.id, []).append(node)
+
+    def visit(self, node):
+        """
+        Public method to traverse a given AST node.
+        
+        @param node AST node to be traversed
+        @type ast.Node
+        """
+        if isinstance(node, list):
+            for elem in node:
+                super(NameFinder, self).visit(elem)
+        else:
+            super(NameFinder, self).visit(node)
+    
+    def getNames(self):
+        """
+        Public method to return the extracted names and Name nodes.
+        
+        @return dictionary containing the names as keys and the list of nodes
+        @rtype dict
+        """
+        return self.__names
+
+
+class ReturnVisitor(ast.NodeVisitor):
+    """
+    Class implementing a node visitor to check return statements.
+    """
+    Assigns = 'assigns'
+    Refs = 'refs'
+    Returns = 'returns'
+    
+    def __init__(self):
+        """
+        Constructor
+        """
+        super(ReturnVisitor, self).__init__()
+        
+        self.__stack = []
+        self.violations = []
+    
+    @property
+    def assigns(self):
+        """
+        Public method to get the Assign nodes.
+        
+        @return dictionary containing the node name as key and line number
+            as value
+        @rtype dict
+        """
+        return self.__stack[-1][ReturnVisitor.Assigns]
+    
+    @property
+    def refs(self):
+        """
+        Public method to get the References nodes.
+        
+        @return dictionary containing the node name as key and line number
+            as value
+        @rtype dict
+        """
+        return self.__stack[-1][ReturnVisitor.Refs]
+    
+    @property
+    def returns(self):
+        """
+        Public method to get the Return nodes.
+        
+        @return dictionary containing the node name as key and line number
+            as value
+        @rtype dict
+        """
+        return self.__stack[-1][ReturnVisitor.Returns]
+    
+    def __visitWithStack(self, node):
+        """
+        Private method to traverse a given function node using a stack.
+        
+        @param node AST node to be traversed
+        @type ast.FunctionDef or ast.AsyncFunctionDef
+        """
+        self.__stack.append({
+            ReturnVisitor.Assigns: defaultdict(list),
+            ReturnVisitor.Refs: defaultdict(list),
+            ReturnVisitor.Returns: []
+        })
+        
+        self.generic_visit(node)
+        self.__checkFunction(node)
+        self.__stack.pop()
+    
+    def visit_FunctionDef(self, node):
+        """
+        Public method to handle a function definition.
+        
+        @param node reference to the node to handle
+        @type ast.FunctionDef
+        """
+        self.__visitWithStack(node)
+    
+    def visit_AsyncFunctionDef(self, node):
+        """
+        Public method to handle a function definition.
+        
+        @param node reference to the node to handle
+        @type ast.AsyncFunctionDef
+        """
+        self.__visitWithStack(node)
+    
+    def visit_Return(self, node):
+        """
+        Public method to handle a return node.
+        
+        @param node reference to the node to handle
+        @type ast.Return
+        """
+        self.returns.append(node)
+        self.generic_visit(node)
+    
+    def visit_Assign(self, node):
+        """
+        Public method to handle an assign node.
+        
+        @param node reference to the node to handle
+        @type ast.Assign
+        """
+        if not self.__stack:
+            return
+        
+        for target in node.targets:
+            self.__visitAssignTarget(target)
+        self.generic_visit(node.value)
+    
+    def visit_Name(self, node):
+        """
+        Public method to handle a name node.
+        
+        @param node reference to the node to handle
+        @type ast.Name
+        """
+        if self.__stack:
+            self.refs[node.id].append(node.lineno)
+    
+    def __visitAssignTarget(self, node):
+        """
+        Private method to handle an assign target node.
+        
+        @param node reference to the node to handle
+        @type ast.AST
+        """
+        if isinstance(node, ast.Tuple):
+            for elt in node.elts:
+                self.__visitAssignTarget(elt)
+            return
+        
+        if isinstance(node, ast.Name):
+            self.assigns[node.id].append(node.lineno)
+            return
+        
+        self.generic_visit(node)
+    
+    def __checkFunction(self, node):
+        """
+        Private method to check a function definition node.
+        
+        @param node reference to the node to check
+        @type ast.AsyncFunctionDef or ast.FunctionDef
+        """
+        if not self.returns or not node.body:
+            return
+        
+        if len(node.body) == 1 and isinstance(node.body[-1], ast.Return):
+            # skip functions that consist of `return None` only
+            return
+        
+        if not self.__resultExists():
+            self.__checkUnnecessaryReturnNone()
+            return
+        
+        self.__checkImplicitReturnValue()
+        self.__checkImplicitReturn(node.body[-1])
+        
+        for n in self.returns:
+            if n.value:
+                self.__checkUnnecessaryAssign(n.value)
+    
+    def __isNone(self, node):
+        """
+        Private method to check, if a node value is None.
+        
+        @param node reference to the node to check
+        @type ast.AST
+        @return flag indicating the node contains a None value
+        """
+        try:
+            return isinstance(node, ast.NameConstant) and node.value is None
+        except AttributeError:
+            # try Py2
+            return isinstance(node, ast.Name) and node.id == "None"
+    
+    def __resultExists(self):
+        """
+        Private method to check the existance of a return result.
+        
+        @return flag indicating the existence of a return result
+        @rtype bool
+        """
+        for node in self.returns:
+            value = node.value
+            if value and not self.__isNone(value):
+                return True
+        
+        return False
+    
+    def __checkImplicitReturnValue(self):
+        """
+        Private method to check for implicit return values.
+        """
+        for node in self.returns:
+            if not node.value:
+                self.violations.append((node, "M832"))
+    
+    def __checkUnnecessaryReturnNone(self):
+        """
+        Private method to check for an unnecessary 'return None' statement.
+        """
+        for node in self.returns:
+            if self.__isNone(node.value):
+                self.violations.append((node, "M831"))
+    
+    def __checkImplicitReturn(self, node):
+        """
+        Private method to check for an implicit return statement.
+        
+        @param node reference to the node to check
+        @type ast.AST
+        """
+        if isinstance(node, ast.If):
+            if not node.body or not node.orelse:
+                self.violations.append((node, "M833"))
+                return
+            
+            self.__checkImplicitReturn(node.body[-1])
+            self.__checkImplicitReturn(node.orelse[-1])
+            return
+        
+        if isinstance(node, ast.For) and node.orelse:
+            self.__checkImplicitReturn(node.orelse[-1])
+            return
+        
+        if isinstance(node, ast.With):
+            self.__checkImplicitReturn(node.body[-1])
+            return
+        
+        try:
+            okNodes = (ast.Return, ast.Raise, ast.While, ast.Try)
+        except AttributeError:
+            # Py2
+            okNodes = (ast.Return, ast.Raise, ast.While)
+        if not isinstance(node, okNodes):
+            self.violations.append((node, "M833"))
+    
+    def __checkUnnecessaryAssign(self, node):
+        """
+        Private method to check for an unnecessary assign statement.
+        
+        @param node reference to the node to check
+        @type ast.AST
+        """
+        if not isinstance(node, ast.Name):
+            return
+        
+        varname = node.id
+        returnLineno = node.lineno
+        
+        if varname not in self.assigns:
+            return
+        
+        if varname not in self.refs:
+            self.violations.append((node, "M834"))
+            return
+        
+        if self.__hasRefsBeforeNextAssign(varname, returnLineno):
+            return
+        
+        self.violations.append((node, "M834"))
+
+    def __hasRefsBeforeNextAssign(self, varname, returnLineno):
+        """
+        Private method to check for references before a following assign
+        statement.
+        
+        @param varname variable name to check for
+        @type str
+        @param returnLineno line number of the return statement
+        @type int
+        @return flag indicating the existence of references
+        @rtype bool
+        """
+        beforeAssign = 0
+        afterAssign = None
+        
+        for lineno in sorted(self.assigns[varname]):
+            if lineno > returnLineno:
+                afterAssign = lineno
+                break
+            
+            if lineno <= returnLineno:
+                beforeAssign = lineno
+        
+        for lineno in self.refs[varname]:
+            if lineno == returnLineno:
+                continue
+            
+            if afterAssign:
+                if beforeAssign < lineno <= afterAssign:
+                    return True
+            
+            elif beforeAssign < lineno:
+                return True
+        
+        return False
+#
+# eflag: noqa = M702

eric ide

mercurial