eric6/Plugins/CheckerPlugins/CodeStyleChecker/Simplify/SimplifyNodeVisitor.py

Thu, 01 Apr 2021 19:48:36 +0200

author
Detlev Offenbach <detlev@die-offenbachs.de>
date
Thu, 01 Apr 2021 19:48:36 +0200
changeset 8189
17df5c8df8c1
child 8191
9125da0c227e
permissions
-rw-r--r--

Code Style Checker
- continued to implement checkers for potential code simplifications

# -*- coding: utf-8 -*-

# Copyright (c) 2021 Detlev Offenbach <detlev@die-offenbachs.de>
#

"""
Module implementing a node visitor checking for code that could be simplified.
"""

import ast
import collections

try:
    from ast import unparse
except ImportError:
    # Python < 3.9
    from .ast_unparse import unparse

######################################################################
## The following code is derived from the flake8-simplify package.
##
## Original License:
##
## MIT License
##
## Copyright (c) 2020 Martin Thoma
##
## Permission is hereby granted, free of charge, to any person obtaining a copy
## of this software and associated documentation files (the "Software"), to
## deal in the Software without restriction, including without limitation the
## rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
## sell copies of the Software, and to permit persons to whom the Software is
## furnished to do so, subject to the following conditions:
##
## The above copyright notice and this permission notice shall be included in
## all copies or substantial portions of the Software.
##
## THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
## IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
## FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
## AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
## LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
## FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
## IN THE SOFTWARE.
######################################################################

# ast.NameConstant, ast.Str and ast.Num are legacy node classes (needed for
# Python 3.6 / 3.7, deprecated aliases of ast.Constant since Python 3.8) that
# newer interpreters no longer provide. Only the classes actually offered by
# the running interpreter are collected, so that importing this module does
# not fail with an AttributeError.
BOOL_CONST_TYPES = tuple(
    getattr(ast, typeName)
    for typeName in ("Constant", "NameConstant")
    if hasattr(ast, typeName)
)
AST_CONST_TYPES = tuple(
    getattr(ast, typeName)
    for typeName in ("Constant", "NameConstant", "Str", "Num")
    if hasattr(ast, typeName)
)
STR_TYPES = tuple(
    getattr(ast, typeName)
    for typeName in ("Constant", "Str")
    if hasattr(ast, typeName)
)


class SimplifyNodeVisitor(ast.NodeVisitor):
    """
    Class to traverse the AST node tree and check for code that can be
    simplified.
    """
    def __init__(self, errorCallback):
        """
        Constructor
        
        @param errorCallback callback function to register an error
        @type func
        """
        super(SimplifyNodeVisitor, self).__init__()
        
        self.__error = errorCallback
    
    def visit_Expr(self, node):
        """
        Public method to process an Expr node.
        
        @param node reference to the Expr node
        @type ast.Expr
        """
        self.__check112(node)
        
        self.generic_visit(node)
    
    def visit_BoolOp(self, node):
        """
        Public method to process a BoolOp node.
        
        @param node reference to the BoolOp node
        @type ast.BoolOp
        """
        self.__check101(node)
        self.__check109(node)
        
        self.generic_visit(node)
    
    def visit_If(self, node):
        """
        Public method to process an If node.
        
        @param node reference to the If node
        @type ast.If
        """
        self.__check102(node)
        self.__check103(node)
        self.__check106(node)
        self.__check108(node)
        
        self.generic_visit(node)
    
    def visit_For(self, node):
        """
        Public method to process a For node.
        
        @param node reference to the For node
        @type ast.For
        """
        self.__check104(node)
        self.__check110_111(node)
        
        self.generic_visit(node)
    
    def visit_Try(self, node):
        """
        Public method to process a Try node.
        
        @param node reference to the Try node
        @type ast.Try
        """
        self.__check105(node)
        self.__check107(node)
        
        self.generic_visit(node)
    
    #############################################################
    ## Helper methods for the various checkers below
    #############################################################
    
    def __getDuplicatedIsinstanceCall(self, node):
        """
        Private method to get a list of isinstance arguments which could
        be combined.
        
        @param node reference to the AST node to be inspected
        @type ast.BoolOp
        @return list of variable names of duplicated isinstance calls
        @rtype list of str
        """
        counter = collections.defaultdict(int)
        
        for call in node.values:
            # Ensure this is a call of the built-in isinstance() function.
            if not isinstance(call, ast.Call) or len(call.args) != 2:
                continue
            functionName = unparse(call.func)
            if functionName != "isinstance":
                continue
            
            arg0Name = unparse(call.args[0])
            counter[arg0Name] += 1
        
        return [name for name, count in counter.items() if count > 1]
    
    #############################################################
    ## Methods to check for possible code simplifications below
    #############################################################
    
    def __check101(self, node):
        """
        Private method to check for duplicate isinstance() calls.
        
        @param node reference to the AST node to be checked
        @type ast.BoolOp
        """
        if isinstance(node.op, ast.Or):
            for variable in self.__getDuplicatedIsinstanceCall(node):
                self.__error(node.lineno - 1, node.col_offset, "Y101",
                             variable)
    
    def __check102(self, node):
        """
        Private method to check for nested if statements without else blocks.
        
        @param node reference to the AST node to be checked
        @type ast.If
        """
        # ## Pattern 1
        # if a: <---
        #     if b: <---
        #         c
        isPattern1 = (
            node.orelse == [] and
            len(node.body) == 1 and
            isinstance(node.body[0], ast.If) and
            node.body[0].orelse == []
        )
        # ## Pattern 2
        # if a: < irrelvant for here
        #     pass
        # elif b:  <--- this is treated like a nested block
        #     if c: <---
        #         d
        if isPattern1:
            self.__error(node.lineno - 1, node.col_offset, "Y102")
    
    def __check103(self, node):
        """
        Private method to check for calls that wrap a condition to return
        a bool.
        
        @param node reference to the AST node to be checked
        @type ast.If
        """
        # if cond:
        #     return True
        # else:
        #     return False
        if not (
            len(node.body) != 1 or
            not isinstance(node.body[0], ast.Return) or
            not isinstance(node.body[0].value, BOOL_CONST_TYPES) or
            not (
                node.body[0].value.value is True or
                node.body[0].value.value is False
            ) or
            len(node.orelse) != 1 or
            not isinstance(node.orelse[0], ast.Return) or
            not isinstance(node.orelse[0].value, BOOL_CONST_TYPES) or
            not (
                node.orelse[0].value.value is True or
                node.orelse[0].value.value is False
            )
        ):
            condition = unparse(node.test)
            self.__error(node.lineno - 1, node.col_offset, "Y103", condition)
    
    def __check104(self, node):
        """
        Private method to check for "iterate and yield" patterns.
        
        @param node reference to the AST node to be checked
        @type ast.For
        """
        # for item in iterable:
        #     yield item
        if not (
            len(node.body) != 1 or
            not isinstance(node.body[0], ast.Expr) or
            not isinstance(node.body[0].value, ast.Yield) or
            not isinstance(node.target, ast.Name) or
            not isinstance(node.body[0].value.value, ast.Name) or
            node.target.id != node.body[0].value.value.id or
            node.orelse != []
        ):
            iterable = unparse(node.iter)
            self.__error(node.lineno - 1, node.col_offset, "Y104", iterable)
    
    def __check105(self, node):
        """
        Private method to check for "try-except-pass" patterns.
        
        @param node reference to the AST node to be checked
        @type ast.Try
        """
        # try:
        #     foo()
        # except ValueError:
        #     pass
        if not (
            len(node.body) != 1 or
            len(node.handlers) != 1 or
            not isinstance(node.handlers[0], ast.ExceptHandler) or
            len(node.handlers[0].body) != 1 or
            not isinstance(node.handlers[0].body[0], ast.Pass) or
            node.orelse != []
        ):
            if node.handlers[0].type is None:
                exception = "Exception"
            else:
                exception = unparse(node.handlers[0].type)
            self.__error(node.lineno - 1, node.col_offset, "Y105", exception)
    
    def __check106(self, node):
        """
        Private method to check for calls where an exception is raised in else.
        
        @param node reference to the AST node to be checked
        @type ast.If
        """
        # if cond:
        #     return True
        # else:
        #     raise Exception
        just_one = (
            len(node.body) == 1 and
            len(node.orelse) >= 1 and
            isinstance(node.orelse[-1], ast.Raise) and
            not isinstance(node.body[-1], ast.Raise)
        )
        many = (
            len(node.body) > 2 * len(node.orelse) and
            len(node.orelse) >= 1 and
            isinstance(node.orelse[-1], ast.Raise) and
            not isinstance(node.body[-1], ast.Raise)
        )
        if just_one or many:
            self.__error(node.lineno - 1, node.col_offset, "Y106")
    
    def __check107(self, node):
        """
        Private method to check for calls where try/except and finally have
        'return'.
        
        @param node reference to the AST node to be checked
        @type ast.Try
        """
        # def foo():
        #     try:
        #         1 / 0
        #         return "1"
        #     except:
        #         return "2"
        #     finally:
        #         return "3"
        tryHasReturn = False
        for stmt in node.body:
            if isinstance(stmt, ast.Return):
                tryHasReturn = True
                break

        exceptHasReturn = False
        for stmt2 in node.handlers:
            if isinstance(stmt2, ast.Return):
                exceptHasReturn = True
                break

        finallyHasReturn = False
        finallyReturn = None
        for stmt in node.finalbody:
            if isinstance(stmt, ast.Return):
                finallyHasReturn = True
                finallyReturn = stmt
                break

        if (tryHasReturn or exceptHasReturn) and finallyHasReturn:
            if finallyReturn is not None:
                self.__error(finallyReturn.lineno - 1,
                             finallyReturn.col_offset, "Y107")
    
    def __check108(self, node):
        """
        Private method to check for if-elses which could be a ternary
        operator assignment.
        
        @param node reference to the AST node to be checked
        @type ast.If
        """
        # if a:
        #     b = c
        # else:
        #     b = d
        if (
            len(node.body) == 1 and
            len(node.orelse) == 1 and
            isinstance(node.body[0], ast.Assign) and
            isinstance(node.orelse[0], ast.Assign) and
            len(node.body[0].targets) == 1 and
            len(node.orelse[0].targets) == 1 and
            isinstance(node.body[0].targets[0], ast.Name) and
            isinstance(node.orelse[0].targets[0], ast.Name) and
            node.body[0].targets[0].id == node.orelse[0].targets[0].id
        ):
            assign = unparse(node.body[0].targets[0])
            body = unparse(node.body[0].value)
            cond = unparse(node.test)
            orelse = unparse(node.orelse[0].value)
            self.__error(node.lineno - 1, node.col_offset, "Y108",
                         assign, body, cond, orelse)
    
    def __check109(self, node):
        """
        Private method to check for multiple equalities with the same value
        are combined via "or".
        
        @param node reference to the AST node to be checked
        @type ast.BoolOp
        """
        # if a == b or a == c:
        #     d
        if isinstance(node.op, ast.Or):
            equalities = [
                value
                for value in node.values
                if isinstance(value, ast.Compare) and
                len(value.ops) == 1 and
                isinstance(value.ops[0], ast.Eq)
            ]
            ids = []  # (name, compared_to)
            for eq in equalities:
                if isinstance(eq.left, ast.Name):
                    ids.append((eq.left, eq.comparators[0]))
                if (
                    len(eq.comparators) == 1 and
                    isinstance(eq.comparators[0], ast.Name)
                ):
                    ids.append((eq.comparators[0], eq.left))

            id2count = {}
            for identifier, comparedTo in ids:
                if identifier.id not in id2count:
                    id2count[identifier.id] = []
                id2count[identifier.id].append(comparedTo)
            for value, values in id2count.items():
                if len(values) == 1:
                    continue
                
                self.__error(node.lineno - 1, node.col_offset, "Y109",
                             value, unparse(ast.List(elts=values)),
                             unparse(node))
    
    def __check110_111(self, node):
        """
        Private method to check if any / all could be used.
        
        @param node reference to the AST node to be checked
        @type ast.For
        """
        # for x in iterable:
        #     if check(x):
        #         return True
        # return False
        #
        # for x in iterable:
        #     if check(x):
        #         return False
        # return True
        if (
            len(node.body) == 1
            and isinstance(node.body[0], ast.If)
            and len(node.body[0].body) == 1
            and isinstance(node.body[0].body[0], ast.Return)
            and isinstance(node.body[0].body[0].value, BOOL_CONST_TYPES)
            and hasattr(node.body[0].body[0].value, "value")
        ):
            check = unparse(node.body[0].test)
            target = unparse(node.target)
            iterable = unparse(node.iter)
            if node.body[0].body[0].value.value is True:
                self.__error(node.lineno - 1, node.col_offset, "Y110",
                             check, target, iterable)
            elif node.body[0].body[0].value.value is False:
                check = "not " + check
                if check.startswith("not not"):
                    check = check[len("not not "):]
                self.__error(node.lineno - 1, node.col_offset, "Y111",
                             check, target, iterable)
    
    def __check112(self, node):
        """
        Public method to check for non-capitalized calls to environment
        variables.
        
        @param node reference to the AST node to be checked
        @type ast.Expr
        """
        # os.environ["foo"]
        # os.environ.get("bar")
        isIndexCall = (
            isinstance(node.value, ast.Subscript)
            and isinstance(node.value.value, ast.Attribute)
            and isinstance(node.value.value.value, ast.Name)
            and node.value.value.value.id == "os"
            and node.value.value.attr == "environ"
            and (
                (
                    isinstance(node.value.slice, ast.Index)
                    and isinstance(node.value.slice.value, STR_TYPES)
                )
                or isinstance(node.value.slice, ast.Constant)
            )
        )
        if isIndexCall:
            subscript = node.value
            slice_ = subscript.slice
            if isinstance(slice_, ast.Index):
                # Python < 3.9
                stringPart = slice_.value  # type: ignore
                if isinstance(stringPart, ast.Str):
                    envName = stringPart.s  # Python 3.6 / 3.7 fallback
                else:
                    envName = stringPart.value
            elif isinstance(slice_, ast.Constant):
                # Python 3.9
                envName = slice_.value

            # Check if this has a change
            hasChange = envName != envName.upper()

        isGetCall = (
            isinstance(node.value, ast.Call)
            and isinstance(node.value.func, ast.Attribute)
            and isinstance(node.value.func.value, ast.Attribute)
            and isinstance(node.value.func.value.value, ast.Name)
            and node.value.func.value.value.id == "os"
            and node.value.func.value.attr == "environ"
            and node.value.func.attr == "get"
            and len(node.value.args) in [1, 2]
            and isinstance(node.value.args[0], STR_TYPES)
        )
        if isGetCall:
            call = node.value
            stringPart = call.args[0]
            if isinstance(stringPart, ast.Str):
                envName = stringPart.s  # Python 3.6 / 3.7 fallback
            else:
                envName = stringPart.value
            # Check if this has a change
            hasChange = envName != envName.upper()
        if not (isIndexCall or isGetCall) or not hasChange:
            return
        if isIndexCall:
            original = unparse(node)
            expected = f"os.environ['{envName.upper()}']"
        elif isGetCall:
            original = unparse(node)
            if len(node.value.args) == 1:
                expected = f"os.environ.get('{envName.upper()}')"
            else:
                defaultValue = unparse(node.value.args[1])
                expected = (
                    f"os.environ.get('{envName.upper()}', '{defaultValue}')"
                )
        else:
            return
        
        self.__error(node.lineno - 1, node.col_offset, "Y112", expected,
                     original)

#
# eflag: noqa = M891

eric ide

mercurial