Fri, 02 Apr 2021 18:13:12 +0200
Code Style Checker
- continued to implement checkers for potential code simplifications
# -*- coding: utf-8 -*-

# Copyright (c) 2021 Detlev Offenbach <detlev@die-offenbachs.de>
#

"""
Module implementing a node visitor checking for code that could be simplified.
"""

import ast
import collections
import itertools

try:
    from ast import unparse
except ImportError:
    # Python < 3.9
    from .ast_unparse import unparse

######################################################################
## The following code is derived from the flake8-simplify package.
##
## Original License:
##
## MIT License
##
## Copyright (c) 2020 Martin Thoma
##
## Permission is hereby granted, free of charge, to any person obtaining a copy
## of this software and associated documentation files (the "Software"), to
## deal in the Software without restriction, including without limitation the
## rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
## sell copies of the Software, and to permit persons to whom the Software is
## furnished to do so, subject to the following conditions:
##
## The above copyright notice and this permission notice shall be included in
## all copies or substantial portions of the Software.
##
## THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
## IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
## FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
## AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
## LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
## FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
## IN THE SOFTWARE.
######################################################################


def _getAstTypes(*names):
    """
    Function to get those of the named AST node classes, which are available
    in the running Python version.

    Note: ast.NameConstant, ast.Str and ast.Num are deprecated aliases of
    ast.Constant since Python 3.8 and were removed in Python 3.14. ast.Index
    is only produced by the parser of Python < 3.9.

    @param names names of the AST node classes
    @type str
    @return tuple of the available AST node classes (may be empty)
    @rtype tuple of type
    """
    return tuple(
        nodeType
        for nodeType in (getattr(ast, name, None) for name in names)
        if nodeType is not None
    )


BOOL_CONST_TYPES = _getAstTypes("Constant", "NameConstant")
AST_CONST_TYPES = _getAstTypes("Constant", "NameConstant", "Str", "Num")
STR_TYPES = _getAstTypes("Constant", "Str")
# empty tuple if ast.Index does not exist; isinstance(x, ()) is always False
_INDEX_TYPES = _getAstTypes("Index")


class SimplifyNodeVisitor(ast.NodeVisitor):
    """
    Class to traverse the AST node tree and check for code that can be
    simplified.
    """
    def __init__(self, errorCallback):
        """
        Constructor

        @param errorCallback callback function to register an error
        @type func
        """
        super(SimplifyNodeVisitor, self).__init__()

        self.__error = errorCallback

    def visit_Expr(self, node):
        """
        Public method to process an Expr node.

        @param node reference to the Expr node
        @type ast.Expr
        """
        self.__check112(node)

        self.generic_visit(node)

    def visit_BoolOp(self, node):
        """
        Public method to process a BoolOp node.

        @param node reference to the BoolOp node
        @type ast.BoolOp
        """
        self.__check101(node)
        self.__check109(node)

        self.generic_visit(node)

    def visit_If(self, node):
        """
        Public method to process an If node.

        @param node reference to the If node
        @type ast.If
        """
        self.__check102(node)
        self.__check103(node)
        self.__check106(node)
        self.__check108(node)
        self.__check114(node)
        self.__check116(node)

        self.generic_visit(node)

    def visit_For(self, node):
        """
        Public method to process a For node.

        @param node reference to the For node
        @type ast.For
        """
        self.__check104(node)
        self.__check110_111(node)
        self.__check113(node)
        self.__check118b(node)

        self.generic_visit(node)

    def visit_Try(self, node):
        """
        Public method to process a Try node.

        @param node reference to the Try node
        @type ast.Try
        """
        self.__check105(node)
        self.__check107(node)

        self.generic_visit(node)

    def visit_Call(self, node):
        """
        Public method to process a Call node.

        @param node reference to the Call node
        @type ast.Call
        """
        self.__check115(node)

        self.generic_visit(node)

    def visit_With(self, node):
        """
        Public method to process a With node.

        @param node reference to the With node
        @type ast.With
        """
        self.__check117(node)

        self.generic_visit(node)

    def visit_Compare(self, node):
        """
        Public method to process a Compare node.

        @param node reference to the Compare node
        @type ast.Compare
        """
        self.__check118a(node)

        self.generic_visit(node)

    def visit_ClassDef(self, node):
        """
        Public method to process a ClassDef node.

        @param node reference to the ClassDef node
        @type ast.ClassDef
        """
        self.__check119(node)

        self.generic_visit(node)

    #############################################################
    ## Helper methods for the various checkers below
    #############################################################

    def __getConstantValue(self, node):
        """
        Private method to get the value of a constant node independent of
        the Python version.

        @param node reference to the constant node (ast.Constant for
            Python >= 3.8; ast.Str, ast.Num or ast.NameConstant before)
        @type ast.expr
        @return value of the constant
        @rtype Any
        """
        if isinstance(node, ast.Constant):
            return node.value

        # node classes produced by the parser of Python < 3.8
        if hasattr(node, "s"):
            # ast.Str, ast.Bytes
            return node.s
        if hasattr(node, "n"):
            # ast.Num
            return node.n
        # ast.NameConstant
        return getattr(node, "value", None)

    def __getReturnValueString(self, node):
        """
        Private method to get the source code of the value returned by a
        'return' statement.

        @param node reference to the Return node
        @type ast.Return
        @return source code of the returned value or "None" for a bare
            'return'
        @rtype str
        """
        if node.value is None:
            return "None"

        return unparse(node.value)

    def __isOsEnviron(self, node):
        """
        Private method to check, if the given node is a reference to
        'os.environ'.

        @param node reference to the node to be checked
        @type ast.expr
        @return flag indicating an 'os.environ' reference
        @rtype bool
        """
        return (
            isinstance(node, ast.Attribute)
            and node.attr == "environ"
            and isinstance(node.value, ast.Name)
            and node.value.id == "os"
        )

    def __getKeysCallObject(self, node):
        """
        Private method to get the object the 'keys()' method is called on.

        @param node reference to the node to be inspected
        @type ast.expr
        @return reference to the object node of an 'obj.keys()' call or None,
            if the node is not such a call
        @rtype ast.expr or None
        """
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Attribute)
            and node.func.attr == "keys"
            and isinstance(node.func.ctx, ast.Load)
        ):
            return node.func.value

        return None

    def __getDuplicatedIsinstanceCall(self, node):
        """
        Private method to get a list of isinstance arguments which could
        be combined.

        @param node reference to the AST node to be inspected
        @type ast.BoolOp
        @return list of variable names of duplicated isinstance calls
        @rtype list of str
        """
        counter = collections.defaultdict(int)

        for call in node.values:
            # Ensure this is a call of the built-in isinstance() function.
            if not isinstance(call, ast.Call) or len(call.args) != 2:
                continue
            functionName = unparse(call.func)
            if functionName != "isinstance":
                continue

            arg0Name = unparse(call.args[0])
            counter[arg0Name] += 1

        return [name for name, count in counter.items() if count > 1]

    def __isConstantIncrease(self, expression):
        """
        Private method to check the given expression for being an increase
        by a constant integer.

        @param expression reference to the expression node
        @type ast.AugAssign
        @return flag indicating a constant increase
        @rtype bool
        """
        if not (
            isinstance(expression.op, ast.Add)
            and isinstance(expression.value, AST_CONST_TYPES)
        ):
            return False

        value = self.__getConstantValue(expression.value)
        # Only integer increments hint at a manual loop counter; e.g. a
        # string concatenation like 'text += "x"' must not be reported.
        return isinstance(value, int) and not isinstance(value, bool)

    def __getIfBodyPairs(self, node):
        """
        Private method to extract a list of pairs of test and body for an
        If node.

        @param node reference to the If node to be processed
        @type ast.If
        @return list of pairs of test and body
        @rtype list of tuples of (ast.expr, [ast.stmt])
        """
        pairs = [(node.test, node.body)]
        orelse = node.orelse
        while (
            isinstance(orelse, list)
            and len(orelse) == 1
            and isinstance(orelse[0], ast.If)
        ):
            pairs.append((orelse[0].test, orelse[0].body))
            orelse = orelse[0].orelse
        return pairs

    def __isSameBody(self, body1, body2):
        """
        Private method check, if the given bodies are equivalent.

        @param body1 list of statements of the first body
        @type list of ast.stmt
        @param body2 list of statements of the second body
        @type list of ast.stmt
        @return flag indicating identical bodies
        @rtype bool
        """
        if len(body1) != len(body2):
            return False
        for a, b in zip(body1, body2):
            try:
                statementEqual = self.__isStatementEqual(a, b)
            except RecursionError:  # maximum recursion depth
                statementEqual = False
            if not statementEqual:
                return False
        return True

    def __isStatementEqual(self, a: ast.stmt, b: ast.stmt) -> bool:
        """
        Private method to check, if two statements are equal.

        Position information (line numbers, column offsets), expression
        contexts and the 'parent' back reference are ignored.

        @param a reference to the first statement
        @type ast.stmt
        @param b reference to the second statement
        @type ast.stmt
        @return flag indicating if the two statements are equal
        @rtype bool
        """
        if type(a) is not type(b):
            return False

        if isinstance(a, ast.AST):
            for k, v in vars(a).items():
                if k in ("lineno", "col_offset", "end_lineno",
                         "end_col_offset", "ctx", "parent"):
                    continue
                if not self.__isStatementEqual(v, getattr(b, k)):
                    return False
            return True
        elif isinstance(a, list):
            # zip() truncates, so lists of different lengths must be
            # rejected explicitly
            return len(a) == len(b) and all(
                itertools.starmap(self.__isStatementEqual, zip(a, b))
            )
        else:
            return a == b

    #############################################################
    ## Methods to check for possible code simplifications below
    #############################################################

    def __check101(self, node):
        """
        Private method to check for duplicate isinstance() calls.

        @param node reference to the AST node to be checked
        @type ast.BoolOp
        """
        if isinstance(node.op, ast.Or):
            for variable in self.__getDuplicatedIsinstanceCall(node):
                self.__error(node.lineno - 1, node.col_offset, "Y101",
                             variable)

    def __check102(self, node):
        """
        Private method to check for nested if statements without else blocks.

        @param node reference to the AST node to be checked
        @type ast.If
        """
        # ## Pattern 1
        # if a: <---
        #     if b: <---
        #         c
        isPattern1 = (
            node.orelse == []
            and len(node.body) == 1
            and isinstance(node.body[0], ast.If)
            and node.body[0].orelse == []
        )

        # ## Pattern 2 (not checked yet)
        # if a: < irrelevant for here
        #     pass
        # elif b:  <--- this is treated like a nested block
        #     if c: <---
        #         d

        if isPattern1:
            self.__error(node.lineno - 1, node.col_offset, "Y102")

    def __check103(self, node):
        """
        Private method to check for calls that wrap a condition to return
        a bool.

        @param node reference to the AST node to be checked
        @type ast.If
        """
        # if cond:
        #     return True
        # else:
        #     return False
        if not (
            len(node.body) != 1
            or not isinstance(node.body[0], ast.Return)
            or not isinstance(node.body[0].value, BOOL_CONST_TYPES)
            or not (
                node.body[0].value.value is True
                or node.body[0].value.value is False
            )
            or len(node.orelse) != 1
            or not isinstance(node.orelse[0], ast.Return)
            or not isinstance(node.orelse[0].value, BOOL_CONST_TYPES)
            or not (
                node.orelse[0].value.value is True
                or node.orelse[0].value.value is False
            )
        ):
            condition = unparse(node.test)
            self.__error(node.lineno - 1, node.col_offset, "Y103", condition)

    def __check104(self, node):
        """
        Private method to check for "iterate and yield" patterns.

        @param node reference to the AST node to be checked
        @type ast.For
        """
        # for item in iterable:
        #     yield item
        if not (
            len(node.body) != 1
            or not isinstance(node.body[0], ast.Expr)
            or not isinstance(node.body[0].value, ast.Yield)
            or not isinstance(node.target, ast.Name)
            or not isinstance(node.body[0].value.value, ast.Name)
            or node.target.id != node.body[0].value.value.id
            or node.orelse != []
        ):
            iterable = unparse(node.iter)
            self.__error(node.lineno - 1, node.col_offset, "Y104", iterable)

    def __check105(self, node):
        """
        Private method to check for "try-except-pass" patterns.

        @param node reference to the AST node to be checked
        @type ast.Try
        """
        # try:
        #     foo()
        # except ValueError:
        #     pass
        if not (
            len(node.body) != 1
            or len(node.handlers) != 1
            or not isinstance(node.handlers[0], ast.ExceptHandler)
            or len(node.handlers[0].body) != 1
            or not isinstance(node.handlers[0].body[0], ast.Pass)
            or node.orelse != []
        ):
            if node.handlers[0].type is None:
                exception = "Exception"
            else:
                exception = unparse(node.handlers[0].type)
            self.__error(node.lineno - 1, node.col_offset, "Y105", exception)

    def __check106(self, node):
        """
        Private method to check for calls where an exception is raised in else.

        @param node reference to the AST node to be checked
        @type ast.If
        """
        # if cond:
        #     return True
        # else:
        #     raise Exception
        just_one = (
            len(node.body) == 1
            and len(node.orelse) >= 1
            and isinstance(node.orelse[-1], ast.Raise)
            and not isinstance(node.body[-1], ast.Raise)
        )
        many = (
            len(node.body) > 2 * len(node.orelse)
            and len(node.orelse) >= 1
            and isinstance(node.orelse[-1], ast.Raise)
            and not isinstance(node.body[-1], ast.Raise)
        )
        if just_one or many:
            self.__error(node.lineno - 1, node.col_offset, "Y106")

    def __check107(self, node):
        """
        Private method to check for calls where try/except and finally have
        'return'.

        @param node reference to the AST node to be checked
        @type ast.Try
        """
        # def foo():
        #     try:
        #         1 / 0
        #         return "1"
        #     except:
        #         return "2"
        #     finally:
        #         return "3"
        tryHasReturn = any(
            isinstance(stmt, ast.Return) for stmt in node.body
        )
        # The handlers are ExceptHandler nodes; their statements must be
        # inspected for a 'return'.
        exceptHasReturn = any(
            isinstance(stmt, ast.Return)
            for handler in node.handlers
            for stmt in handler.body
        )
        finallyReturn = next(
            (stmt for stmt in node.finalbody if isinstance(stmt, ast.Return)),
            None
        )

        if (tryHasReturn or exceptHasReturn) and finallyReturn is not None:
            self.__error(finallyReturn.lineno - 1,
                         finallyReturn.col_offset,
                         "Y107")

    def __check108(self, node):
        """
        Private method to check for if-elses which could be a ternary
        operator assignment.

        @param node reference to the AST node to be checked
        @type ast.If
        """
        # if a:
        #     b = c
        # else:
        #     b = d
        if (
            len(node.body) == 1
            and len(node.orelse) == 1
            and isinstance(node.body[0], ast.Assign)
            and isinstance(node.orelse[0], ast.Assign)
            and len(node.body[0].targets) == 1
            and len(node.orelse[0].targets) == 1
            and isinstance(node.body[0].targets[0], ast.Name)
            and isinstance(node.orelse[0].targets[0], ast.Name)
            and node.body[0].targets[0].id == node.orelse[0].targets[0].id
        ):
            assign = unparse(node.body[0].targets[0])
            body = unparse(node.body[0].value)
            cond = unparse(node.test)
            orelse = unparse(node.orelse[0].value)
            if len(
                "{0} = {1} if {2} else {3}".format(assign, body, cond, orelse)
            ) > 79:
                self.__error(node.lineno - 1, node.col_offset, "Y108a")
            else:
                self.__error(node.lineno - 1, node.col_offset, "Y108b",
                             assign, body, cond, orelse)

    def __check109(self, node):
        """
        Private method to check for multiple equalities with the same value
        are combined via "or".

        @param node reference to the AST node to be checked
        @type ast.BoolOp
        """
        # if a == b or a == c:
        #     d
        if isinstance(node.op, ast.Or):
            equalities = [
                value
                for value in node.values
                if isinstance(value, ast.Compare)
                and len(value.ops) == 1
                and isinstance(value.ops[0], ast.Eq)
            ]
            ids = []  # (name, compared_to)
            for eq in equalities:
                if isinstance(eq.left, ast.Name):
                    ids.append((eq.left, eq.comparators[0]))
                if (
                    len(eq.comparators) == 1
                    and isinstance(eq.comparators[0], ast.Name)
                ):
                    ids.append((eq.comparators[0], eq.left))

            id2count = collections.defaultdict(list)
            for identifier, comparedTo in ids:
                id2count[identifier.id].append(comparedTo)

            for value, values in id2count.items():
                if len(values) == 1:
                    continue

                self.__error(node.lineno - 1, node.col_offset, "Y109",
                             value, unparse(ast.List(elts=values)),
                             unparse(node))

    def __check110_111(self, node):
        """
        Private method to check if any / all could be used.

        @param node reference to the AST node to be checked
        @type ast.For
        """
        # for x in iterable:
        #     if check(x):
        #         return True
        # return False
        #
        # for x in iterable:
        #     if check(x):
        #         return False
        # return True
        if (
            len(node.body) == 1
            and isinstance(node.body[0], ast.If)
            and len(node.body[0].body) == 1
            and isinstance(node.body[0].body[0], ast.Return)
            and isinstance(node.body[0].body[0].value, BOOL_CONST_TYPES)
            and hasattr(node.body[0].body[0].value, "value")
        ):
            check = unparse(node.body[0].test)
            target = unparse(node.target)
            iterable = unparse(node.iter)
            if node.body[0].body[0].value.value is True:
                self.__error(node.lineno - 1, node.col_offset, "Y110",
                             check, target, iterable)
            elif node.body[0].body[0].value.value is False:
                check = "not " + check
                if check.startswith("not not"):
                    check = check[len("not not "):]
                self.__error(node.lineno - 1, node.col_offset, "Y111",
                             check, target, iterable)

    def __check112(self, node):
        """
        Private method to check for non-capitalized calls to environment
        variables.

        @param node reference to the AST node to be checked
        @type ast.Expr
        """
        # os.environ["foo"]
        # os.environ.get("bar")
        value = node.value
        isIndexCall = False
        envName = None

        if (
            isinstance(value, ast.Subscript)
            and self.__isOsEnviron(value.value)
        ):
            slice_ = value.slice
            if isinstance(slice_, _INDEX_TYPES):
                # Python < 3.9 wraps the subscript in an ast.Index node
                slice_ = slice_.value
            if isinstance(slice_, STR_TYPES):
                envName = self.__getConstantValue(slice_)
                isIndexCall = True
        elif (
            isinstance(value, ast.Call)
            and isinstance(value.func, ast.Attribute)
            and value.func.attr == "get"
            and self.__isOsEnviron(value.func.value)
            and len(value.args) in [1, 2]
            and isinstance(value.args[0], STR_TYPES)
        ):
            envName = self.__getConstantValue(value.args[0])

        # Only string keys can be checked for capitalization; e.g.
        # 'os.environ[0]' must not crash the checker.
        if not isinstance(envName, str) or envName == envName.upper():
            return

        original = unparse(node)
        upperName = envName.upper()
        if isIndexCall:
            expected = f"os.environ['{upperName}']"
        elif len(value.args) == 1:
            expected = f"os.environ.get('{upperName}')"
        else:
            # the unparsed default value already carries its own quotes
            defaultValue = unparse(value.args[1])
            expected = f"os.environ.get('{upperName}', {defaultValue})"

        self.__error(node.lineno - 1, node.col_offset, "Y112", expected,
                     original)

    def __check113(self, node):
        """
        Private method to check for loops in which "enumerate" should be
        used.

        @param node reference to the AST node to be checked
        @type ast.For
        """
        # idx = 0
        # for el in iterable:
        #     ...
        #     idx += 1
        variableCandidates = [
            expression.target
            for expression in node.body
            if isinstance(expression, ast.AugAssign)
            and self.__isConstantIncrease(expression)
            and isinstance(expression.target, ast.Name)
        ]

        for candidate in variableCandidates:
            self.__error(candidate.lineno - 1, candidate.col_offset, "Y113",
                         unparse(candidate))

    def __check114(self, node):
        """
        Private method to check for alternative if clauses with identical
        bodies.

        @param node reference to the AST node to be checked
        @type ast.If
        """
        # if a:
        #     b
        # elif c:
        #     b
        ifBodyPairs = self.__getIfBodyPairs(node)
        errorPairs = [
            (ifbody1, ifbody2)
            for ifbody1, ifbody2 in itertools.combinations(ifBodyPairs, 2)
            if self.__isSameBody(ifbody1[1], ifbody2[1])
        ]
        for ifbody1, ifbody2 in errorPairs:
            self.__error(ifbody1[0].lineno - 1, ifbody1[0].col_offset,
                         "Y114", unparse(ifbody1[0]), unparse(ifbody2[0]))

    def __check115(self, node):
        """
        Private method to check for places where open() is called without
        a context handler.

        @param node reference to the AST node to be checked
        @type ast.Call
        """
        # f = open(...)
        # . ..  # (do something with f)
        # f.close()
        if (
            isinstance(node.func, ast.Name)
            and node.func.id == "open"
            and not isinstance(node.parent, ast.withitem)
        ):
            self.__error(node.lineno - 1, node.col_offset, "Y115")

    def __check116(self, node):
        """
        Private method to check for places with 3 or more consecutive
        if-statements with direct returns.

        * Each if-statement must be a check for equality with the
          same variable
        * Each if-statement must just have a "return"
        * Else must also just have a return

        @param node reference to the AST node to be checked
        @type ast.If
        """
        # if a == "foo":
        #     return "bar"
        # elif a == "bar":
        #     return "baz"
        # elif a == "boo":
        #     return "ooh"
        # else:
        #     return 42
        if (
            isinstance(node.test, ast.Compare)
            and isinstance(node.test.left, ast.Name)
            and len(node.test.ops) == 1
            and isinstance(node.test.ops[0], ast.Eq)
            and len(node.test.comparators) == 1
            and isinstance(node.test.comparators[0], AST_CONST_TYPES)
            and len(node.body) == 1
            and isinstance(node.body[0], ast.Return)
            and len(node.orelse) == 1
            and isinstance(node.orelse[0], ast.If)
        ):
            variable = node.test.left
            child = node.orelse[0]
            elseValue = None
            # a bare 'return' is rendered as "None" in every branch
            keyValuePairs = {
                self.__getConstantValue(node.test.comparators[0]):
                    self.__getReturnValueString(node.body[0]).strip("'")
            }

            while child:
                if not (
                    isinstance(child.test, ast.Compare)
                    and isinstance(child.test.left, ast.Name)
                    and child.test.left.id == variable.id
                    and len(child.test.ops) == 1
                    and isinstance(child.test.ops[0], ast.Eq)
                    and len(child.test.comparators) == 1
                    and isinstance(child.test.comparators[0],
                                   AST_CONST_TYPES)
                    and len(child.body) == 1
                    and isinstance(child.body[0], ast.Return)
                    and len(child.orelse) <= 1
                ):
                    return

                key = self.__getConstantValue(child.test.comparators[0])
                keyValuePairs[key] = self.__getReturnValueString(
                    child.body[0]).strip("'")

                if len(child.orelse) == 1:
                    if isinstance(child.orelse[0], ast.If):
                        child = child.orelse[0]
                    elif isinstance(child.orelse[0], ast.Return):
                        elseValue = self.__getReturnValueString(
                            child.orelse[0])
                        child = None
                    else:
                        return
                else:
                    child = None

            if len(keyValuePairs) < 3:
                return

            if elseValue:
                ret = f"{keyValuePairs}.get({variable.id}, {elseValue})"
            else:
                ret = f"{keyValuePairs}.get({variable.id})"

            self.__error(node.lineno - 1, node.col_offset, "Y116", ret)

    def __check117(self, node):
        """
        Private method to check for multiple with-statements with same scope.

        @param node reference to the AST node to be checked
        @type ast.With
        """
        # with A() as a:
        #     with B() as b:
        #         print("hello")
        if (
            len(node.body) == 1
            and isinstance(node.body[0], ast.With)
        ):
            withItems = [
                f"{unparse(withitem)}"
                for withitem in node.items + node.body[0].items
            ]
            mergedWith = f"with {', '.join(withItems)}:"
            self.__error(node.lineno - 1, node.col_offset, "Y117",
                         mergedWith)

    def __check118a(self, node):
        """
        Private method to check for usages of "key in dict.keys()".

        @param node reference to the AST node to be checked
        @type ast.Compare
        """
        # key in dict.keys()
        if (
            len(node.ops) == 1
            and isinstance(node.ops[0], ast.In)
            and len(node.comparators) == 1
        ):
            dictNode = self.__getKeysCallObject(node.comparators[0])
            if dictNode is not None:
                keyStr = unparse(node.left)
                dictStr = unparse(dictNode)
                self.__error(node.lineno - 1, node.col_offset, "Y118",
                             keyStr, dictStr)

    def __check118b(self, node):
        """
        Private method to check for usages of "for key in dict.keys()".

        @param node reference to the AST node to be checked
        @type ast.For
        """
        # for key in dict.keys():
        #     # do something
        dictNode = self.__getKeysCallObject(node.iter)
        if dictNode is not None:
            keyStr = unparse(node.target)
            dictStr = unparse(dictNode)
            self.__error(node.lineno - 1, node.col_offset, "Y118", keyStr,
                         dictStr)

    def __check119(self, node):
        """
        Private method to check for classes that should be "dataclasses".

        A class is reported, if it has no decorators and no base classes,
        defines at least one method and all its methods are among those
        a dataclass can generate.

        @param node reference to the AST node to be checked
        @type ast.ClassDef
        """
        if not node.decorator_list and not node.bases:
            dataclassFunctions = [
                "__init__",
                "__eq__",
                "__hash__",
                "__repr__",
                "__str__",
            ]
            methods = [
                element for element in node.body
                if isinstance(element, ast.FunctionDef)
            ]
            if methods and all(
                method.name in dataclassFunctions for method in methods
            ):
                self.__error(node.lineno - 1, node.col_offset, "Y119",
                             node.name)

#
# eflag: noqa = M891