eric6/Plugins/CheckerPlugins/CodeStyleChecker/MiscellaneousChecker.py

Wed, 26 Jun 2019 19:41:11 +0200

author
Detlev Offenbach <detlev@die-offenbachs.de>
date
Wed, 26 Jun 2019 19:41:11 +0200
changeset 7040
f89952e5fc11
parent 7021
2894aa889a4e
child 7042
2be5b245e1b8
permissions
-rw-r--r--

Code Style Checker: added check for commented code that should be removed.

# -*- coding: utf-8 -*-

# Copyright (c) 2015 - 2019 Detlev Offenbach <detlev@die-offenbachs.de>
#

"""
Module implementing a checker for miscellaneous checks.
"""

import sys
import ast
import re
import itertools
from string import Formatter
from collections import defaultdict


def composeCallPath(node):
    """
    Generator function to assemble the call path of a given node.
    
    The attribute chain is unwound from the outermost attribute down to its
    base. The base name is only included, if it is a plain name (e.g. for
    'a.b.c' the components 'a', 'b' and 'c' are produced, for 'f().x' just
    'x').
    
    @param node node to assemble call path for
    @type ast.Node
    @yield call path components
    @ytype str
    """
    components = []
    current = node
    while isinstance(current, ast.Attribute):
        components.append(current.attr)
        current = current.value
    if isinstance(current, ast.Name):
        components.append(current.id)
    
    # components were collected outermost first; emit them base first
    for component in reversed(components):
        yield component


class MiscellaneousChecker(object):
    """
    Class implementing a checker for miscellaneous checks.
    """
    # all message codes this checker may report (must include every code
    # passed to the error reporting method, e.g. 'M823' emitted by the
    # mutable default argument check)
    Codes = [
        "M101", "M102",
        "M111", "M112",
        "M131", "M132",
        
        "M191", "M192", "M193", "M194",
        "M195", "M196", "M197", "M198",
        
        "M201",
        
        "M501", "M502", "M503", "M504", "M505", "M506", "M507", "M508",
        "M509",
        "M511", "M512", "M513",
        "M521", "M522", "M523", "M524",
        
        "M601",
        "M611", "M612", "M613",
        "M621", "M622", "M623", "M624", "M625",
        "M631", "M632",
        "M651", "M652", "M653", "M654", "M655",
        
        "M701", "M702",
        "M711",
        
        "M801",
        "M811",
        "M821", "M822", "M823",
        "M831", "M832", "M833", "M834",
        
        "M891",
        
        "M901",
    ]
    
    # formatter instance used to parse format strings into their fields
    Formatter = Formatter()
    # splits a format field name into its base (group 1) and an optional
    # trailing attribute access or index lookup (group 2)
    FormatFieldRegex = re.compile(r'^((?:\s|.)*?)(\..*|\[.*\])?$')
    
    # built-in names excluded from the list of names checked for shadowing
    BuiltinsWhiteList = [
        "__name__",
        "__doc__",
        "credits",
    ]

    def __init__(self, source, filename, select, ignore, expected, repeat,
                 args):
        """
        Constructor
        
        @param source source code to be checked
        @type list of str
        @param filename name of the source file
        @type str
        @param select list of selected codes
        @type list of str
        @param ignore list of codes to be ignored
        @type list of str
        @param expected list of expected codes
        @type list of str
        @param repeat flag indicating to report each occurrence of a code
        @type bool
        @param args dictionary of arguments for the miscellaneous checks
        @type dict
        """
        self.__select = tuple(select)
        self.__ignore = ('',) if select else tuple(ignore)
        self.__expected = expected[:]
        self.__repeat = repeat
        self.__filename = filename
        self.__source = source[:]
        self.__args = args
        
        self.__pep3101FormatRegex = re.compile(
            r'^(?:[^\'"]*[\'"][^\'"]*[\'"])*\s*%|^\s*%')
        
        if sys.version_info >= (3, 0):
            import builtins
            self.__builtins = [b for b in dir(builtins)
                               if b not in self.BuiltinsWhiteList]
        else:
            import __builtin__
            self.__builtins = [b for b in dir(__builtin__)
                               if b not in self.BuiltinsWhiteList]

        # statistics counters
        self.counters = {}
        
        # collection of detected errors
        self.errors = []
        
        checkersWithCodes = [
            (self.__checkCoding, ("M101", "M102")),
            (self.__checkCopyright, ("M111", "M112")),
            (self.__checkBuiltins, ("M131", "M132")),
            (self.__checkComprehensions, ("M191", "M192", "M193", "M194",
                                          "M195", "M196", "M197", "M198")),
            (self.__checkDictWithSortedKeys, ("M201",)),
            (self.__checkPep3101, ("M601",)),
            (self.__checkFormatString, ("M611", "M612", "M613",
                                        "M621", "M622", "M623", "M624", "M625",
                                        "M631", "M632")),
            (self.__checkBugBear, ("M501", "M502", "M503", "M504", "M505",
                                   "M506", "M507", "M508", "M509",
                                   "M511", "M512", "M513",
                                   "M521", "M522", "M523", "M524")),
            (self.__checkLogging, ("M651", "M652", "M653", "M654", "M655")),
            (self.__checkFuture, ("M701", "M702")),
            (self.__checkGettext, ("M711",)),
            (self.__checkPrintStatements, ("M801",)),
            (self.__checkTuple, ("M811", )),
            (self.__checkMutableDefault, ("M821", "M822")),
            (self.__checkReturn, ("M831", "M832", "M833", "M834")),
            (self.__checkCommentedCode, ("M891")),
        ]
        
        self.__defaultArgs = {
            "BuiltinsChecker": {
                "chr": ["unichr", ],
                "str": ["unicode", ],
            },
            "CodingChecker": 'latin-1, utf-8',
            "CopyrightChecker": {
                "Author": "",
                "MinFilesize": 0,
            },
            "CommentedCodeChecker": {
                "Aggressive": False,
            }
        }
        
        self.__checkers = []
        for checker, codes in checkersWithCodes:
            if any(not (code and self.__ignoreCode(code))
                    for code in codes):
                self.__checkers.append(checker)
    
    def __ignoreCode(self, code):
        """
        Private method to check if the message code should be ignored.

        @param code message code to check for
        @type str
        @return flag indicating to ignore the given code
        @rtype bool
        """
        return (code.startswith(self.__ignore) and
                not code.startswith(self.__select))
    
    def __error(self, lineNumber, offset, code, *args):
        """
        Private method to record an issue.
        
        @param lineNumber line number of the issue
        @type int
        @param offset position within line of the issue
        @type int
        @param code message code
        @type str
        @param args arguments for the message
        @type list
        """
        if self.__ignoreCode(code):
            return
        
        if code in self.counters:
            self.counters[code] += 1
        else:
            self.counters[code] = 1
        
        # Don't care about expected codes
        if code in self.__expected:
            return
        
        if code and (self.counters[code] == 1 or self.__repeat):
            # record the issue with one based line number
            self.errors.append(
                (self.__filename, lineNumber + 1, offset, (code, args)))
    
    def __reportInvalidSyntax(self):
        """
        Private method to report a syntax error.
        """
        exc_type, exc = sys.exc_info()[:2]
        if len(exc.args) > 1:
            offset = exc.args[1]
            if len(offset) > 2:
                offset = offset[1:3]
        else:
            offset = (1, 0)
        self.__error(offset[0] - 1, offset[1] or 0,
                     'M901', exc_type.__name__, exc.args[0])
    
    def run(self):
        """
        Public method to check the given source against miscellaneous
        conditions.
        """
        if not self.__filename:
            # don't do anything, if essential data is missing
            return
        
        if not self.__checkers:
            # don't do anything, if no codes were selected
            return
        
        source = "".join(self.__source)
        # Check type for py2: if not str it's unicode
        if sys.version_info[0] == 2:
            try:
                source = source.encode('utf-8')
            except UnicodeError:
                pass
        try:
            self.__tree = compile(source, self.__filename, 'exec',
                                  ast.PyCF_ONLY_AST)
        except (SyntaxError, TypeError):
            self.__reportInvalidSyntax()
            return
        
        for check in self.__checkers:
            check()
    
    def __getCoding(self):
        """
        Private method to get the defined coding of the source.
        
        @return tuple containing the line number and the coding
        @rtype tuple of int and str
        """
        for lineno, line in enumerate(self.__source[:5]):
            matched = re.search(r'coding[:=]\s*([-\w_.]+)',
                                line, re.IGNORECASE)
            if matched:
                return lineno, matched.group(1)
        else:
            return 0, ""
    
    def __checkCoding(self):
        """
        Private method to check the presence of a coding line and valid
        encodings.
        """
        if len(self.__source) == 0:
            return
        
        encodings = [e.lower().strip()
                     for e in self.__args.get(
                     "CodingChecker", self.__defaultArgs["CodingChecker"])
                     .split(",")]
        lineno, coding = self.__getCoding()
        if coding:
            if coding.lower() not in encodings:
                self.__error(lineno, 0, "M102", coding)
        else:
            self.__error(0, 0, "M101")
    
    def __checkCopyright(self):
        """
        Private method to check the presence of a copyright statement.
        """
        source = "".join(self.__source)
        copyrightArgs = self.__args.get(
            "CopyrightChecker", self.__defaultArgs["CopyrightChecker"])
        copyrightMinFileSize = copyrightArgs.get(
            "MinFilesize",
            self.__defaultArgs["CopyrightChecker"]["MinFilesize"])
        copyrightAuthor = copyrightArgs.get(
            "Author",
            self.__defaultArgs["CopyrightChecker"]["Author"])
        copyrightRegexStr = \
            r"Copyright\s+(\(C\)\s+)?(\d{{4}}\s+-\s+)?\d{{4}}\s+{author}"
        
        tocheck = max(1024, copyrightMinFileSize)
        topOfSource = source[:tocheck]
        if len(topOfSource) < copyrightMinFileSize:
            return

        copyrightRe = re.compile(copyrightRegexStr.format(author=r".*"),
                                 re.IGNORECASE)
        if not copyrightRe.search(topOfSource):
            self.__error(0, 0, "M111")
            return
        
        if copyrightAuthor:
            copyrightAuthorRe = re.compile(
                copyrightRegexStr.format(author=copyrightAuthor),
                re.IGNORECASE)
            if not copyrightAuthorRe.search(topOfSource):
                self.__error(0, 0, "M112")
    
    def __checkCommentedCode(self):
        """
        Private method to check for commented code.
        """
        from eradicate import commented_out_code_line_numbers
        
        source = "".join(self.__source)
        commentedCodeCheckerArgs = self.__args.get(
            "CommentedCodeChecker", self.__defaultArgs["CommentedCodeChecker"])
        aggressive = commentedCodeCheckerArgs.get(
            "Aggressive",
            self.__defaultArgs["CommentedCodeChecker"]["Aggressive"])
        for markedLine in commented_out_code_line_numbers(
                source, aggressive=aggressive):
            self.__error(markedLine - 1, 0, "M891")
    
    def __checkPrintStatements(self):
        """
        Private method to check for print statements.
        """
        for node in ast.walk(self.__tree):
            if (isinstance(node, ast.Call) and
                getattr(node.func, 'id', None) == 'print') or \
               (hasattr(ast, 'Print') and isinstance(node, ast.Print)):
                self.__error(node.lineno - 1, node.col_offset, "M801")
    
    def __checkTuple(self):
        """
        Private method to check for one element tuples.
        """
        for node in ast.walk(self.__tree):
            if isinstance(node, ast.Tuple) and \
                    len(node.elts) == 1:
                self.__error(node.lineno - 1, node.col_offset, "M811")
    
    def __checkFuture(self):
        """
        Private method to check the __future__ imports.
        """
        expectedImports = {
            i.strip()
            for i in self.__args.get("FutureChecker", "").split(",")
            if bool(i.strip())}
        if len(expectedImports) == 0:
            # nothing to check for; disabling the check
            return
        
        imports = set()
        node = None
        hasCode = False
        
        for node in ast.walk(self.__tree):
            if (isinstance(node, ast.ImportFrom) and
                    node.module == '__future__'):
                imports |= {name.name for name in node.names}
            elif isinstance(node, ast.Expr):
                if not isinstance(node.value, ast.Str):
                    hasCode = True
                    break
            elif not isinstance(node, (ast.Module, ast.Str)):
                hasCode = True
                break

        if isinstance(node, ast.Module) or not hasCode:
            return

        if not (imports >= expectedImports):
            if imports:
                self.__error(node.lineno - 1, node.col_offset, "M701",
                             ", ".join(expectedImports), ", ".join(imports))
            else:
                self.__error(node.lineno - 1, node.col_offset, "M702",
                             ", ".join(expectedImports))
    
    def __checkPep3101(self):
        """
        Private method to check for old style string formatting.
        
        Every source line matched by the '%' formatting pattern is reported
        as M601 at the column of the first '%' of the line, together with
        the detected conversion (e.g. '%s'; just '%' if no conversion type
        character was found).
        """
        for lineno, line in enumerate(self.__source):
            match = self.__pep3101FormatRegex.search(line)
            if match:
                lineLen = len(line)
                # the first '%' of the line is taken as the start of the
                # conversion specifier
                pos = line.find('%')
                formatPos = pos
                formatter = '%'
                # skip a mapping key like '%(name)s'
                # NOTE(review): line[pos + 1] raises IndexError if the first
                # '%' is the last character of the line; a missing ')' makes
                # find() return -1, so the scan below inspects the last
                # character and then restarts at the line start - confirm
                # these cases cannot occur for compilable sources
                if line[pos + 1] == "(":
                    pos = line.find(")", pos)
                c = line[pos]
                # advance to the first conversion type character
                while c not in "diouxXeEfFgGcrs":
                    pos += 1
                    if pos >= lineLen:
                        break
                    c = line[pos]
                if c in "diouxXeEfFgGcrs":
                    formatter += c
                self.__error(lineno, formatPos, "M601", formatter)
    
    def __checkFormatString(self):
        """
        Private method to check string format strings.
        
        All str/bytes literals collected by TextVisitor are parsed for
        format fields. Implicit ('{}') fields are reported as M611 (literal
        used in a '.format()' call), M612 (docstring) or M613 (any other
        literal). For literals used in a '.format()' call the referenced
        field numbers and names are compared with the call's arguments:
        M621 (field number without positional argument), M622 (field name
        without keyword argument), M623 (**kwargs but no named fields),
        M624 (*args but no numbered fields), M625 (implicit and explicit
        fields mixed), M631 (unused positional argument) and M632 (unused
        keyword argument).
        """
        coding = self.__getCoding()[1]
        if not coding:
            # default to utf-8
            coding = "utf-8"
        
        visitor = TextVisitor()
        visitor.visit(self.__tree)
        for node in visitor.nodes:
            text = node.s
            # bytes literals are decoded with the source coding; literals
            # that cannot be decoded are skipped
            if sys.version_info[0] > 2 and isinstance(text, bytes):
                try:
                    text = text.decode(coding)
                except UnicodeDecodeError:
                    continue
            fields, implicit, explicit = self.__getFields(text)
            if implicit:
                if node in visitor.calls:
                    self.__error(node.lineno - 1, node.col_offset, "M611")
                else:
                    if node.is_docstring:
                        self.__error(node.lineno - 1, node.col_offset, "M612")
                    else:
                        self.__error(node.lineno - 1, node.col_offset, "M613")
            
            if node in visitor.calls:
                # strArgs is True for the 'str.format(literal, ...)' form,
                # where the literal itself is the first positional argument
                call, strArgs = visitor.calls[node]
                
                numbers = set()
                names = set()
                # Determine which fields require a keyword and which an arg
                for name in fields:
                    fieldMatch = self.FormatFieldRegex.match(name)
                    try:
                        number = int(fieldMatch.group(1))
                    except ValueError:
                        number = -1
                    # negative numbers are considered keywords
                    if number >= 0:
                        numbers.add(number)
                    else:
                        names.add(fieldMatch.group(1))
                
                keywords = {keyword.arg for keyword in call.keywords}
                numArgs = len(call.args)
                if strArgs:
                    numArgs -= 1
                # before Python 3.5 *args/**kwargs are separate call
                # attributes; later they appear as ast.Starred arguments and
                # as keywords without a name
                if sys.version_info < (3, 5):
                    hasKwArgs = bool(call.kwargs)
                    hasStarArgs = bool(call.starargs)
                else:
                    hasKwArgs = any(kw.arg is None for kw in call.keywords)
                    hasStarArgs = sum(1 for arg in call.args
                                      if isinstance(arg, ast.Starred))
                    
                    if hasKwArgs:
                        keywords.discard(None)
                    if hasStarArgs:
                        numArgs -= 1
                
                # if starargs or kwargs is not None, it can't count the
                # parameters but at least check if the args are used
                if hasKwArgs:
                    if not names:
                        # No names but kwargs
                        self.__error(call.lineno - 1, call.col_offset, "M623")
                if hasStarArgs:
                    if not numbers:
                        # No numbers but args
                        self.__error(call.lineno - 1, call.col_offset, "M624")
                
                if not hasKwArgs and not hasStarArgs:
                    # can actually verify numbers and names
                    for number in sorted(numbers):
                        if number >= numArgs:
                            self.__error(call.lineno - 1, call.col_offset,
                                         "M621", number)
                    
                    for name in sorted(names):
                        if name not in keywords:
                            self.__error(call.lineno - 1, call.col_offset,
                                         "M622", name)
                
                for arg in range(numArgs):
                    if arg not in numbers:
                        self.__error(call.lineno - 1, call.col_offset, "M631",
                                     arg)
                
                for keyword in keywords:
                    if keyword not in names:
                        self.__error(call.lineno - 1, call.col_offset, "M632",
                                     keyword)
                
                if implicit and explicit:
                    self.__error(call.lineno - 1, call.col_offset, "M625")
    
    def __getFields(self, string):
        """
        Private method to extract the format field information.
        
        @param string format string to be parsed
        @type str
        @return format field information as a tuple with fields, implicit
            field definitions present and explicit field definitions present
        @rtype tuple of set of str, bool, bool
        """
        fields = set()
        cnt = itertools.count()
        implicit = False
        explicit = False
        try:
            for _literal, field, spec, conv in self.Formatter.parse(string):
                if field is not None and (conv is None or conv in 'rsa'):
                    if not field:
                        field = str(next(cnt))
                        implicit = True
                    else:
                        explicit = True
                    fields.add(field)
                    fields.update(parsedSpec[1]
                                  for parsedSpec in self.Formatter.parse(spec)
                                  if parsedSpec[1] is not None)
        except ValueError:
            return set(), False, False
        else:
            return fields, implicit, explicit
    
    def __checkBuiltins(self):
        """
        Private method to check, if built-ins are shadowed.
        """
        functionDefs = [ast.FunctionDef]
        try:
            functionDefs.append(ast.AsyncFunctionDef)
        except AttributeError:
            pass
        
        ignoreBuiltinAssignments = self.__args.get(
            "BuiltinsChecker", self.__defaultArgs["BuiltinsChecker"])
        
        for node in ast.walk(self.__tree):
            if isinstance(node, ast.Assign):
                # assign statement
                for element in node.targets:
                    if isinstance(element, ast.Name) and \
                       element.id in self.__builtins:
                        value = node.value
                        if isinstance(value, ast.Name) and \
                           element.id in ignoreBuiltinAssignments and \
                           value.id in ignoreBuiltinAssignments[element.id]:
                            # ignore compatibility assignments
                            continue
                        self.__error(element.lineno - 1, element.col_offset,
                                     "M131", element.id)
                    elif isinstance(element, (ast.Tuple, ast.List)):
                        for tupleElement in element.elts:
                            if isinstance(tupleElement, ast.Name) and \
                               tupleElement.id in self.__builtins:
                                self.__error(tupleElement.lineno - 1,
                                             tupleElement.col_offset,
                                             "M131", tupleElement.id)
            elif isinstance(node, ast.For):
                # for loop
                target = node.target
                if isinstance(target, ast.Name) and \
                   target.id in self.__builtins:
                    self.__error(target.lineno - 1, target.col_offset,
                                 "M131", target.id)
                elif isinstance(target, (ast.Tuple, ast.List)):
                    for element in target.elts:
                        if isinstance(element, ast.Name) and \
                           element.id in self.__builtins:
                            self.__error(element.lineno - 1,
                                         element.col_offset,
                                         "M131", element.id)
            elif any(isinstance(node, functionDef)
                     for functionDef in functionDefs):
                # (asynchronous) function definition
                if sys.version_info >= (3, 0):
                    for arg in node.args.args:
                        if isinstance(arg, ast.arg) and \
                           arg.arg in self.__builtins:
                            self.__error(arg.lineno - 1, arg.col_offset,
                                         "M132", arg.arg)
                else:
                    for arg in node.args.args:
                        if isinstance(arg, ast.Name) and \
                           arg.id in self.__builtins:
                            self.__error(arg.lineno - 1, arg.col_offset,
                                         "M132", arg.id)
    
    def __checkComprehensions(self):
        """
        Private method to check some comprehension related things.
        """
        for node in ast.walk(self.__tree):
            if (isinstance(node, ast.Call) and
               len(node.args) == 1 and
               isinstance(node.func, ast.Name)):
                if (isinstance(node.args[0], ast.GeneratorExp) and
                        node.func.id in ('list', 'set', 'dict')):
                    errorCode = {
                        "dict": "M193",
                        "list": "M191",
                        "set": "M192",
                    }[node.func.id]
                    self.__error(node.lineno - 1, node.col_offset, errorCode)

                elif (isinstance(node.args[0], ast.ListComp) and
                      node.func.id in ('set', 'dict')):
                    errorCode = {
                        'dict': 'M195',
                        'set': 'M194',
                    }[node.func.id]
                    self.__error(node.lineno - 1, node.col_offset, errorCode)

                elif (isinstance(node.args[0], ast.List) and
                      node.func.id in ('set', 'dict')):
                    errorCode = {
                        'dict': 'M197',
                        'set': 'M196',
                    }[node.func.id]
                    self.__error(node.lineno - 1, node.col_offset, errorCode)

                elif (isinstance(node.args[0], ast.ListComp) and
                      node.func.id in ('all', 'any', 'frozenset', 'max', 'min',
                                       'sorted', 'sum', 'tuple',)):
                    self.__error(node.lineno - 1, node.col_offset, "M198",
                                 node.func.id)
    
    def __checkMutableDefault(self):
        """
        Private method to check for use of mutable types as default arguments.
        """
        mutableTypes = (
            ast.Call,
            ast.Dict,
            ast.List,
            ast.Set,
        )
        mutableCalls = (
            "Counter",
            "OrderedDict",
            "collections.Counter",
            "collections.OrderedDict",
            "collections.defaultdict",
            "collections.deque",
            "defaultdict",
            "deque",
            "dict",
            "list",
            "set",
        )
        immutableCalls = (
            "tuple",
            "frozenset",
        )
        functionDefs = [ast.FunctionDef]
        try:
            functionDefs.append(ast.AsyncFunctionDef)
        except AttributeError:
            pass
        
        for node in ast.walk(self.__tree):
            if any(isinstance(node, functionDef)
                   for functionDef in functionDefs):
                for default in node.args.defaults + node.args.kw_defaults:
                    if any(isinstance(default, mutableType)
                           for mutableType in mutableTypes):
                        typeName = type(default).__name__
                        if isinstance(default, ast.Call):
                            callPath = '.'.join(composeCallPath(default.func))
                            if callPath in mutableCalls:
                                self.__error(default.lineno - 1,
                                             default.col_offset,
                                             "M823", callPath + "()")
                            elif callPath not in immutableCalls:
                                self.__error(default.lineno - 1,
                                             default.col_offset,
                                             "M822", typeName)
                        else:
                            self.__error(default.lineno - 1,
                                         default.col_offset,
                                         "M821", typeName)
    
    def __dictShouldBeChecked(self, node):
        """
        Private function to test, if the node should be checked.
        
        @param node reference to the AST node
        @return flag indicating to check the node
        @rtype bool
        """
        if not all(isinstance(key, ast.Str) for key in node.keys):
            return False
        
        if "__IGNORE_WARNING__" in self.__source[node.lineno - 1] or \
           "__IGNORE_WARNING_M201__" in self.__source[node.lineno - 1]:
            return False
        
        lineNumbers = [key.lineno for key in node.keys]
        return len(lineNumbers) == len(set(lineNumbers))
    
    def __checkDictWithSortedKeys(self):
        """
        Private method to check, if dictionary keys appear in sorted order.
        """
        for node in ast.walk(self.__tree):
            if isinstance(node, ast.Dict) and self.__dictShouldBeChecked(node):
                for key1, key2 in zip(node.keys, node.keys[1:]):
                    if key2.s < key1.s:
                        self.__error(key2.lineno - 1, key2.col_offset,
                                     "M201", key2.s, key1.s)
    
    def __checkLogging(self):
        """
        Private method to check logging statements.
        
        The tree is traversed by a LoggingVisitor; each recorded violation
        (a node and its message code) is reported.
        """
        loggingVisitor = LoggingVisitor()
        loggingVisitor.visit(self.__tree)
        for offendingNode, messageCode in loggingVisitor.violations:
            self.__error(offendingNode.lineno - 1, offendingNode.col_offset,
                         messageCode)
    
    def __checkGettext(self):
        """
        Private method to check the 'gettext' import statement.
        """
        for node in ast.walk(self.__tree):
            if isinstance(node, ast.ImportFrom) and \
               any(name.asname == '_' for name in node.names):
                self.__error(node.lineno - 1, node.col_offset, "M711",
                             node.names[0].name)
    
    def __checkBugBear(self):
        """
        Private method for bugbear checks.
        
        The tree is traversed by a BugBearVisitor. Each violation holds the
        offending node, the message code and optional message arguments.
        """
        bugBearVisitor = BugBearVisitor()
        bugBearVisitor.visit(self.__tree)
        for violation in bugBearVisitor.violations:
            offendingNode, messageCode = violation[:2]
            self.__error(offendingNode.lineno - 1, offendingNode.col_offset,
                         messageCode, *violation[2:])
    
    def __checkReturn(self):
        """
        Private method to check return statements.
        
        The tree is traversed by a ReturnVisitor. The first two entries of
        each violation are the offending node and the message code.
        """
        returnVisitor = ReturnVisitor()
        returnVisitor.visit(self.__tree)
        for violation in returnVisitor.violations:
            offendingNode, messageCode = violation[0], violation[1]
            self.__error(offendingNode.lineno - 1, offendingNode.col_offset,
                         messageCode)


class TextVisitor(ast.NodeVisitor):
    """
    Class implementing a node visitor for bytes and str instances.

    It tries to detect docstrings as string of the first expression of each
    module, class or function.
    """
    # modelled after the string format flake8 extension
    
    def __init__(self):
        """
        Constructor
        """
        super(TextVisitor, self).__init__()
        # all visited string/bytes nodes in visiting order
        self.nodes = []
        # maps a format string node to a tuple of the '.format()' call node
        # and a flag telling, if it was called as 'str.format(literal, ...)'
        self.calls = {}

    def __addNode(self, node):
        """
        Private method to add a node to our list of nodes.
        
        Nodes not marked as docstring before get a False docstring flag.
        
        @param node reference to the node to add
        @type ast.AST
        """
        node.is_docstring = getattr(node, 'is_docstring', False)
        self.nodes.append(node)

    def __isBaseString(self, node):
        """
        Private method to determine, if a node is a base string node.
        
        @param node reference to the node to check
        @type ast.AST
        @return flag indicating a base string
        @rtype bool
        """
        if sys.version_info[0] > 2:
            stringTypes = (ast.Str, ast.Bytes)
        else:
            stringTypes = (ast.Str,)
        return isinstance(node, stringTypes)

    def visit_Str(self, node):
        """
        Public method to record a string node.
        
        @param node reference to the string node
        @type ast.Str
        """
        self.__addNode(node)

    def visit_Bytes(self, node):
        """
        Public method to record a bytes node.
        
        @param node reference to the bytes node
        @type ast.Bytes
        """
        self.__addNode(node)

    def __visitDefinition(self, node):
        """
        Private method handling class and function definitions.
        
        Only the decorators and the body are traversed; the remaining parts
        (e.g. bases or arguments) are skipped.
        
        @param node reference to the node to handle
        @type ast.FunctionDef, ast.AsyncFunctionDef or ast.ClassDef
        """
        for decoratorNode in node.decorator_list:
            self.visit(decoratorNode)
        self.__visitBody(node)

    def __visitBody(self, node):
        """
        Private method to traverse the body of the node manually.

        A leading expression statement holding a string or bytes literal is
        marked as a docstring before the body is visited.
        
        @param node reference to the node to traverse
        @type ast.AST
        """
        statements = node.body
        if statements:
            leading = statements[0]
            if (isinstance(leading, ast.Expr) and
                    self.__isBaseString(leading.value)):
                leading.value.is_docstring = True

        for statement in statements:
            self.visit(statement)

    def visit_Module(self, node):
        """
        Public method to handle a module.
        
        @param node reference to the node to handle
        @type ast.Module
        """
        self.__visitBody(node)

    def visit_ClassDef(self, node):
        """
        Public method to handle a class definition.
        
        @param node reference to the node to handle
        @type ast.ClassDef
        """
        # name, bases, keywords, starargs and kwargs are not traversed
        self.__visitDefinition(node)

    def visit_FunctionDef(self, node):
        """
        Public method to handle a function definition.
        
        @param node reference to the node to handle
        @type ast.FunctionDef
        """
        # name, args and returns are not traversed
        self.__visitDefinition(node)

    def visit_AsyncFunctionDef(self, node):
        """
        Public method to handle an asynchronous function definition.
        
        @param node reference to the node to handle
        @type ast.AsyncFunctionDef
        """
        # name, args and returns are not traversed
        self.__visitDefinition(node)

    def visit_Call(self, node):
        """
        Public method to handle a function call.
        
        Calls of the forms '"literal".format(...)' and
        'str.format("literal", ...)' are recorded in the calls mapping
        before the call's children are traversed.
        
        @param node reference to the node to handle
        @type ast.Call
        """
        function = node.func
        if isinstance(function, ast.Attribute) and function.attr == 'format':
            receiver = function.value
            if self.__isBaseString(receiver):
                self.calls[receiver] = (node, False)
            elif (isinstance(receiver, ast.Name) and
                    receiver.id == 'str' and node.args and
                    self.__isBaseString(node.args[0])):
                self.calls[node.args[0]] = (node, True)
        super(TextVisitor, self).generic_visit(node)


class LoggingVisitor(ast.NodeVisitor):
    """
    Class implementing a node visitor to check logging statements.
    """
    # Method names of logger objects treated as logging calls. All of them
    # take the message as their first positional argument.
    LoggingLevels = {
        "debug",
        "critical",
        "error",
        "exception",
        "info",
        "warn",
        "warning",
    }
    
    def __init__(self):
        """
        Constructor
        """
        super(LoggingVisitor, self).__init__()
        
        # state of the logging call currently being visited (None outside)
        self.__currentLoggingCall = None
        self.__currentLoggingArgument = None
        self.__currentLoggingLevel = None
        self.__currentExtraKeyword = None
        # list of tuples (node, message code)
        self.violations = []

    def __withinLoggingStatement(self):
        """
        Private method to check, if we are inside a logging statement.
        
        @return flag indicating we are inside a logging statement
        @rtype bool
        """
        return self.__currentLoggingCall is not None

    def __withinLoggingArgument(self):
        """
        Private method to check, if we are inside a logging argument.
        
        @return flag indicating we are inside a logging argument
        @rtype bool
        """
        return self.__currentLoggingArgument is not None

    def __withinExtraKeyword(self, node):
        """
        Private method to check, if we are inside the extra keyword.
        
        @param node reference to the node to be checked
        @type ast.keyword
        @return flag indicating we are inside the extra keyword
        @rtype bool
        """
        return self.__currentExtraKeyword is not None and \
            self.__currentExtraKeyword != node
    
    def __detectLoggingLevel(self, node):
        """
        Private method to decide whether an AST Call is a logging call.
        
        The decision is heuristic: a call of a method named like one of the
        entries of 'LoggingLevels' is taken as a logging call, unless it is
        called on a plain name 'warnings' (i.e. 'warnings.warn()').
        
        @param node reference to the node to be processed
        @type ast.Call
        @return logging level
        @rtype str or None
        """
        try:
            # The receiver may be an arbitrary expression (e.g. 'self.logger'
            # or 'logging.getLogger()') without an 'id' attribute; only a
            # plain name can refer to the 'warnings' module.
            if getattr(node.func.value, "id", None) == "warnings":
                return None
            
            if node.func.attr in LoggingVisitor.LoggingLevels:
                return node.func.attr
        except AttributeError:
            # the called object is not an attribute (e.g. 'print(...)')
            pass
        
        return None

    def __isFormatCall(self, node):
        """
        Private method to check if a function call uses format.

        @param node reference to the node to be processed
        @type ast.Call
        @return flag indicating the function call uses format
        @rtype bool
        """
        try:
            return node.func.attr == "format"
        except AttributeError:
            return False
    
    def visit_Call(self, node):
        """
        Public method to handle a function call.

        Every logging statement and string format is expected to be a function
        call.
        
        @param node reference to the node to be processed
        @type ast.Call
        """
        # we are in a logging statement
        if self.__withinLoggingStatement():
            if self.__withinLoggingArgument() and self.__isFormatCall(node):
                self.violations.append((node, "M651"))
                super(LoggingVisitor, self).generic_visit(node)
                return
        
        loggingLevel = self.__detectLoggingLevel(node)
        
        if loggingLevel and self.__currentLoggingLevel is None:
            self.__currentLoggingLevel = loggingLevel
        
        # we are in some other statement
        if loggingLevel is None:
            super(LoggingVisitor, self).generic_visit(node)
            return
        
        # we are entering a new logging statement
        self.__currentLoggingCall = node
        
        if loggingLevel == "warn":
            self.violations.append((node, "M655"))
        
        # child 0 is the called function, child 1 the one following it
        # (normally the message argument)
        for index, child in enumerate(ast.iter_child_nodes(node)):
            if index == 1:
                self.__currentLoggingArgument = child
            if index > 1 and isinstance(child, ast.keyword) and \
               child.arg == "extra":
                self.__currentExtraKeyword = child
            
            super(LoggingVisitor, self).visit(child)
            
            self.__currentLoggingArgument = None
            self.__currentExtraKeyword = None
        
        self.__currentLoggingCall = None
        self.__currentLoggingLevel = None
    
    def visit_BinOp(self, node):
        """
        Public method to handle binary operations while processing the first
        logging argument.
        
        @param node reference to the node to be processed
        @type ast.BinOp
        """
        if self.__withinLoggingStatement() and self.__withinLoggingArgument():
            # handle percent format
            if isinstance(node.op, ast.Mod):
                self.violations.append((node, "M652"))
            
            # handle string concat
            if isinstance(node.op, ast.Add):
                self.violations.append((node, "M653"))
        
        super(LoggingVisitor, self).generic_visit(node)
    
    def visit_JoinedStr(self, node):
        """
        Public method to handle f-string arguments.
        
        An f-string containing placeholders used as the logging argument
        is reported. The sub-expressions of every f-string are visited, so
        that logging calls nested inside them are checked as well.
        
        @param node reference to the node to be processed
        @type ast.JoinedStr
        """
        if (
            sys.version_info >= (3, 6) and
            self.__withinLoggingStatement() and
            self.__withinLoggingArgument() and
            any(isinstance(i, ast.FormattedValue) for i in node.values)
        ):
            self.violations.append((node, "M654"))
        
        super(LoggingVisitor, self).generic_visit(node)


class BugBearVisitor(ast.NodeVisitor):
    """
    Class implementing a node visitor to check for various topics.
    """
    #
    # This class was implemented along the BugBear flake8 extension (v 19.3.0).
    # Original: Copyright (c) 2016 Ɓukasz Langa
    #
    
    NodeWindowSize = 4
    
    def __init__(self):
        """
        Constructor
        """
        super(BugBearVisitor, self).__init__()
        
        # chain of nodes from the root to the node currently visited
        self.__nodeStack = []
        # the last 'NodeWindowSize' nodes visited (in visiting order)
        self.__nodeWindow = []
        # list of tuples (node, message code[, message argument])
        self.violations = []
    
    def visit(self, node):
        """
        Public method to traverse a given AST node.
        
        @param node AST node to be traversed
        @type ast.Node
        """
        self.__nodeStack.append(node)
        self.__nodeWindow.append(node)
        self.__nodeWindow = \
            self.__nodeWindow[-BugBearVisitor.NodeWindowSize:]
        
        super(BugBearVisitor, self).visit(node)
        
        self.__nodeStack.pop()
    
    def visit_UAdd(self, node):
        """
        Public method to handle unary additions.
        
        @param node reference to the node to be processed
        @type ast.UAdd
        """
        # '++x' shows up as UnaryOp, UAdd, UnaryOp, UAdd in visiting order
        trailingNodes = list(map(type, self.__nodeWindow[-4:]))
        if trailingNodes == [ast.UnaryOp, ast.UAdd, ast.UnaryOp, ast.UAdd]:
            originator = self.__nodeWindow[-4]
            self.violations.append((originator, "M501"))
        
        self.generic_visit(node)
    
    def visit_Call(self, node):
        """
        Public method to handle a function call.
        
        The children of every call (called expression and arguments) are
        visited afterwards, independent of the kind of call.
        
        @param node reference to the node to be processed
        @type ast.Call
        """
        if sys.version_info >= (3, 0):
            validPaths = ("six", "future.utils", "builtins")
            methodsDict = {
                "M521": ("iterkeys", "itervalues", "iteritems", "iterlists"),
                "M522": ("viewkeys", "viewvalues", "viewitems", "viewlists"),
                "M523": ("next",),
            }
        else:
            validPaths = ()
            methodsDict = {}
        
        if isinstance(node.func, ast.Attribute):
            for code, methods in methodsDict.items():
                if node.func.attr in methods:
                    callPath = ".".join(composeCallPath(node.func.value))
                    if callPath not in validPaths:
                        self.violations.append((node, code))
                    break
            else:
                self.__checkForM502(node)
        else:
            try:
                # bad super() call
                if isinstance(node.func, ast.Name) and node.func.id == "super":
                    args = node.args
                    if (
                        len(args) == 2 and
                        isinstance(args[0], ast.Attribute) and
                        isinstance(args[0].value, ast.Name) and
                        args[0].value.id == 'self' and
                        args[0].attr == '__class__'
                    ):
                        self.violations.append((node, "M509"))
                
                # getattr()/hasattr() used to test for '__call__'
                if (
                    node.func.id in ("getattr", "hasattr") and
                    node.args[1].s == "__call__"
                ):
                    self.violations.append((node, "M511"))
                # getattr()/setattr() with a constant attribute name
                if (
                    node.func.id == "getattr" and
                    len(node.args) == 2 and
                    isinstance(node.args[1], ast.Str)
                ):
                    self.violations.append((node, "M512"))
                elif (
                    node.func.id == "setattr" and
                    len(node.args) == 3 and
                    isinstance(node.args[1], ast.Str)
                ):
                    self.violations.append((node, "M513"))
            except (AttributeError, IndexError):
                # not a plain name call or too few arguments
                pass
        
        # Visit the children for attribute calls as well; otherwise calls
        # like 'foo.bar(sys.maxint)' would hide their arguments.
        self.generic_visit(node)
    
    def visit_Attribute(self, node):
        """
        Public method to handle attributes.
        
        @param node reference to the node to be processed
        @type ast.Attribute
        """
        callPath = list(composeCallPath(node))
        
        if '.'.join(callPath) == 'sys.maxint' and sys.version_info >= (3, 0):
            self.violations.append((node, "M504"))
        
        elif len(callPath) == 2 and callPath[1] == 'message' and \
                sys.version_info >= (2, 6):
            name = callPath[0]
            for elem in reversed(self.__nodeStack[:-1]):
                if isinstance(elem, ast.ExceptHandler) and elem.name == name:
                    self.violations.append((node, "M505"))
                    break
        
        # descend into the value, e.g. to find 'sys.maxint' in
        # 'sys.maxint.real'
        self.generic_visit(node)
    
    def visit_Assign(self, node):
        """
        Public method to handle assignments.
        
        @param node reference to the node to be processed
        @type ast.Assign
        """
        if isinstance(self.__nodeStack[-2], ast.ClassDef):
            # By using 'hasattr' below we're ignoring starred arguments, slices
            # and tuples for simplicity.
            assignTargets = {t.id for t in node.targets if hasattr(t, 'id')}
            if '__metaclass__' in assignTargets and sys.version_info >= (3, 0):
                self.violations.append((node, "M524"))
        
        elif len(node.targets) == 1:
            target = node.targets[0]
            if isinstance(target, ast.Attribute) and \
               isinstance(target.value, ast.Name):
                if (target.value.id, target.attr) == ('os', 'environ'):
                    self.violations.append((node, "M506"))
        
        self.generic_visit(node)
    
    def visit_For(self, node):
        """
        Public method to handle 'for' statements.
        
        @param node reference to the node to be processed
        @type ast.For
        """
        self.__checkForM507(node)
        
        self.generic_visit(node)
    
    def visit_Assert(self, node):
        """
        Public method to handle 'assert' statements.
        
        @param node reference to the node to be processed
        @type ast.Assert
        """
        if isinstance(node.test, ast.NameConstant) and \
           node.test.value is False:
            self.violations.append((node, "M503"))
        
        self.generic_visit(node)
    
    def visit_JoinedStr(self, node):
        """
        Public method to handle f-string arguments.
        
        Note: the children are deliberately not visited, because a format
        specification is itself a JoinedStr without placeholders and would
        be reported wrongly.
        
        @param node reference to the node to be processed
        @type ast.JoinedStr
        """
        if sys.version_info >= (3, 6):
            for value in node.values:
                if isinstance(value, ast.FormattedValue):
                    return
            
            self.violations.append((node, "M508"))
    
    def __checkForM502(self, node):
        """
        Private method to check the use of *strip().
        
        @param node reference to the node to be processed
        @type ast.Call
        """
        if node.func.attr not in ("lstrip", "rstrip", "strip"):
            return          # method name doesn't match
        
        if len(node.args) != 1 or not isinstance(node.args[0], ast.Str):
            return          # used arguments don't match the builtin strip
        
        s = node.args[0].s
        if len(s) == 1:
            return          # stripping just one character
        
        if len(s) == len(set(s)):
            return          # no characters appear more than once

        self.violations.append((node, "M502"))
    
    def __checkForM507(self, node):
        """
        Private method to check for unused loop variables.
        
        @param node reference to the node to be processed
        @type ast.For
        """
        targets = NameFinder()
        targets.visit(node.target)
        ctrlNames = set(filter(lambda s: not s.startswith('_'),
                               targets.getNames()))
        body = NameFinder()
        for expr in node.body:
            body.visit(expr)
        usedNames = set(body.getNames())
        for name in sorted(ctrlNames - usedNames):
            n = targets.getNames()[name][0]
            self.violations.append((n, "M507", name))


class NameFinder(ast.NodeVisitor):
    """
    Class collecting all 'Name' nodes of a tree of nodes, grouped by name.
    """
    def __init__(self):
        """
        Constructor
        """
        super(NameFinder, self).__init__()
        
        # mapping of an identifier to the list of its Name nodes
        self.__names = {}

    def visit_Name(self, node):
        """
        Public method to record a 'Name' node.
        
        @param node reference to the node to be recorded
        @type ast.Name
        """
        if node.id not in self.__names:
            self.__names[node.id] = []
        self.__names[node.id].append(node)

    def visit(self, node):
        """
        Public method to traverse an AST node or a list of AST nodes.
        
        @param node AST node or list of AST nodes to be traversed
        @type ast.Node or list of ast.Node
        """
        nodes = node if isinstance(node, list) else [node]
        for elem in nodes:
            super(NameFinder, self).visit(elem)
    
    def getNames(self):
        """
        Public method to get the collected names.
        
        @return dictionary mapping each name to the list of its Name nodes
        @rtype dict
        """
        return self.__names


class ReturnVisitor(ast.NodeVisitor):
    """
    Class implementing a node visitor to check return statements.
    """
    # keys of the per function bookkeeping dictionaries kept on the stack
    Assigns = 'assigns'
    Refs = 'refs'
    Returns = 'returns'
    
    def __init__(self):
        """
        Constructor
        """
        super(ReturnVisitor, self).__init__()
        
        # one bookkeeping entry per function currently being visited
        self.__stack = []
        # list of tuples (node, message code)
        self.violations = []
    
    @property
    def assigns(self):
        """
        Public method to get the assignments of the current function.
        
        @return dictionary containing the assigned name as key and the list
            of line numbers of its assignments as value
        @rtype dict
        """
        return self.__stack[-1][ReturnVisitor.Assigns]
    
    @property
    def refs(self):
        """
        Public method to get the name references of the current function.
        
        @return dictionary containing the referenced name as key and the list
            of line numbers of its references as value
        @rtype dict
        """
        return self.__stack[-1][ReturnVisitor.Refs]
    
    @property
    def returns(self):
        """
        Public method to get the Return nodes of the current function.
        
        @return list of Return nodes
        @rtype list of ast.Return
        """
        return self.__stack[-1][ReturnVisitor.Returns]
    
    def __visitWithStack(self, node):
        """
        Private method to traverse a given function node using a stack.
        
        @param node AST node to be traversed
        @type ast.FunctionDef or ast.AsyncFunctionDef
        """
        self.__stack.append({
            ReturnVisitor.Assigns: defaultdict(list),
            ReturnVisitor.Refs: defaultdict(list),
            ReturnVisitor.Returns: []
        })
        
        self.generic_visit(node)
        self.__checkFunction(node)
        self.__stack.pop()
    
    def visit_FunctionDef(self, node):
        """
        Public method to handle a function definition.
        
        @param node reference to the node to handle
        @type ast.FunctionDef
        """
        self.__visitWithStack(node)
    
    def visit_AsyncFunctionDef(self, node):
        """
        Public method to handle a function definition.
        
        @param node reference to the node to handle
        @type ast.AsyncFunctionDef
        """
        self.__visitWithStack(node)
    
    def visit_Return(self, node):
        """
        Public method to handle a return node.
        
        @param node reference to the node to handle
        @type ast.Return
        """
        self.returns.append(node)
        self.generic_visit(node)
    
    def visit_Assign(self, node):
        """
        Public method to handle an assign node.
        
        @param node reference to the node to handle
        @type ast.Assign
        """
        if not self.__stack:
            # assignments outside of any function are not tracked
            return
        
        for target in node.targets:
            self.__visitAssignTarget(target)
        # Visit the value node itself (not only its children), so that a
        # plain name on the right hand side (e.g. 'y = x') gets recorded as
        # a reference.
        self.visit(node.value)
    
    def visit_Name(self, node):
        """
        Public method to handle a name node.
        
        @param node reference to the node to handle
        @type ast.Name
        """
        if self.__stack:
            self.refs[node.id].append(node.lineno)
    
    def __visitAssignTarget(self, node):
        """
        Private method to handle an assign target node.
        
        @param node reference to the node to handle
        @type ast.AST
        """
        if isinstance(node, ast.Tuple):
            for elt in node.elts:
                self.__visitAssignTarget(elt)
            return
        
        if isinstance(node, ast.Name):
            self.assigns[node.id].append(node.lineno)
            return
        
        self.generic_visit(node)
    
    def __checkFunction(self, node):
        """
        Private method to check a function definition node.
        
        @param node reference to the node to check
        @type ast.AsyncFunctionDef or ast.FunctionDef
        """
        if not self.returns or not node.body:
            return
        
        if len(node.body) == 1 and isinstance(node.body[-1], ast.Return):
            # skip functions that consist of a single return statement only
            return
        
        if not self.__resultExists():
            self.__checkUnnecessaryReturnNone()
            return
        
        self.__checkImplicitReturnValue()
        self.__checkImplicitReturn(node.body[-1])
        
        for n in self.returns:
            if n.value:
                self.__checkUnnecessaryAssign(n.value)
    
    def __isNone(self, node):
        """
        Private method to check, if a node value is None.
        
        @param node reference to the node to check
        @type ast.AST
        @return flag indicating the node contains a None value
        @rtype bool
        """
        if isinstance(node, ast.Name):
            # Python 2 parses 'None' as an ordinary name
            return node.id == "None"
        
        # Python 3 produces 'NameConstant' (< 3.8) or 'Constant' (>= 3.8)
        # nodes. The class name is compared to avoid accessing
        # 'ast.NameConstant', which is deprecated and removed in Python 3.14.
        return (
            type(node).__name__ in ("Constant", "NameConstant") and
            node.value is None
        )
    
    def __resultExists(self):
        """
        Private method to check the existance of a return result.
        
        @return flag indicating the existence of a return result
        @rtype bool
        """
        for node in self.returns:
            value = node.value
            if value and not self.__isNone(value):
                return True
        
        return False
    
    def __checkImplicitReturnValue(self):
        """
        Private method to check for implicit return values.
        """
        for node in self.returns:
            if not node.value:
                self.violations.append((node, "M832"))
    
    def __checkUnnecessaryReturnNone(self):
        """
        Private method to check for an unnecessary 'return None' statement.
        """
        for node in self.returns:
            if self.__isNone(node.value):
                self.violations.append((node, "M831"))
    
    def __checkImplicitReturn(self, node):
        """
        Private method to check for an implicit return statement.
        
        @param node reference to the node to check
        @type ast.AST
        """
        if isinstance(node, ast.If):
            if not node.body or not node.orelse:
                self.violations.append((node, "M833"))
                return
            
            self.__checkImplicitReturn(node.body[-1])
            self.__checkImplicitReturn(node.orelse[-1])
            return
        
        if isinstance(node, ast.For) and node.orelse:
            self.__checkImplicitReturn(node.orelse[-1])
            return
        
        if isinstance(node, ast.With):
            self.__checkImplicitReturn(node.body[-1])
            return
        
        try:
            okNodes = (ast.Return, ast.Raise, ast.While, ast.Try)
        except AttributeError:
            # Py2
            okNodes = (ast.Return, ast.Raise, ast.While)
        if not isinstance(node, okNodes):
            self.violations.append((node, "M833"))
    
    def __checkUnnecessaryAssign(self, node):
        """
        Private method to check for an unnecessary assign statement.
        
        @param node reference to the node to check
        @type ast.AST
        """
        if not isinstance(node, ast.Name):
            return
        
        varname = node.id
        returnLineno = node.lineno
        
        if varname not in self.assigns:
            return
        
        if varname not in self.refs:
            self.violations.append((node, "M834"))
            return
        
        if self.__hasRefsBeforeNextAssign(varname, returnLineno):
            return
        
        self.violations.append((node, "M834"))

    def __hasRefsBeforeNextAssign(self, varname, returnLineno):
        """
        Private method to check for references before a following assign
        statement.
        
        @param varname variable name to check for
        @type str
        @param returnLineno line number of the return statement
        @type int
        @return flag indicating the existence of references
        @rtype bool
        """
        beforeAssign = 0
        afterAssign = None
        
        for lineno in sorted(self.assigns[varname]):
            if lineno > returnLineno:
                afterAssign = lineno
                break
            
            if lineno <= returnLineno:
                beforeAssign = lineno
        
        for lineno in self.refs[varname]:
            if lineno == returnLineno:
                continue
            
            if afterAssign:
                if beforeAssign < lineno <= afterAssign:
                    return True
            
            elif beforeAssign < lineno:
                return True
        
        return False
#
# eflag: noqa = M702

eric ide

mercurial