eric7/Plugins/CheckerPlugins/CodeStyleChecker/Imports/LocalImportVisitor.py

Thu, 02 Dec 2021 18:53:26 +0100

author
Detlev Offenbach <detlev@die-offenbachs.de>
date
Thu, 02 Dec 2021 18:53:26 +0100
branch
eric7
changeset 8802
129a973fc33e
parent 8801
8fbb21be8579
child 8881
54e42bc2437a
permissions
-rw-r--r--

Continued implementing a checker for import statements (import order).

# -*- coding: utf-8 -*-

# Copyright (c) 2021 Detlev Offenbach <detlev@die-offenbachs.de>
#

"""
Module implementing a node visitor for checking local import statements.
"""

import ast

#
# The visitor is adapted from flake8-local-import v1.0.6
#


class LocalImportVisitor(ast.NodeVisitor):
    """
    Class implementing a node visitor for checking local import statements.
    
    While traversing the tree each child node is annotated with the
    attributes 'parent', 'previous' (preceding sibling) and 'is_local'
    (located somewhere inside a function definition). Detected issues are
    collected in the 'violations' list as tuples of the offending node and
    a message code:
    <ul>
    <li>I101: local import not placed at the beginning of a function
        body</li>
    <li>I102: local import of a module not contained in the standard
        modules reported by the checker</li>
    <li>I103: local import of a module contained in the standard modules
        reported by the checker</li>
    </ul>
    """
    # node types opening a function scope; coroutines are function bodies
    # as well and must be treated like ordinary functions
    __functionTypes = (ast.FunctionDef, ast.AsyncFunctionDef)
    
    def __init__(self, args, checker):
        """
        Constructor
        
        @param args dictionary containing the checker arguments
        @type dict
        @param checker reference to the checker
        @type ImportsChecker
        """
        self.__appImportNames = args.get("ApplicationPackageNames", [])
        self.__checker = checker
        
        # list of (node, message code) tuples
        self.violations = []
    
    def visit(self, node):
        """
        Public method to traverse the tree of an AST node.
        
        @param node AST node to parse
        @type ast.AST
        @return result of the node specific visitor method
        @rtype Any
        """
        previous = None
        # 'is_local' is inherited by all descendants of a function
        isLocal = (
            isinstance(node, self.__functionTypes) or
            getattr(node, 'is_local', False)
        )
        for child in ast.iter_child_nodes(node):
            child.parent = node
            child.previous = previous
            child.is_local = isLocal
            previous = child
        
        return super().visit(node)
    
    def visit_FunctionDef(self, node):
        """
        Public method to handle a function definition.
        
        The first statement of the function body is marked with the
        attribute 'is_doc_str', if it is a string constant expression.
        
        @param node reference to the node to be processed
        @type ast.FunctionDef or ast.AsyncFunctionDef
        """
        if node.body:
            firstStatement = node.body[0]
            
            # only a string constant is a docstring; other constant
            # expressions (e.g. numbers) are ordinary statements
            if (
                isinstance(firstStatement, ast.Expr) and
                isinstance(firstStatement.value, ast.Constant) and
                isinstance(firstStatement.value.value, str)
            ):
                firstStatement.is_doc_str = True
        
        self.generic_visit(node)
    
    # coroutine definitions are handled exactly like function definitions
    visit_AsyncFunctionDef = visit_FunctionDef
    
    def visit_Import(self, node):
        """
        Public method to handle an import statement.
        
        @param node reference to the node to be processed
        @type ast.Import
        """
        if not getattr(node, 'is_local', False):
            self.generic_visit(node)
            return
        
        for name in node.names:
            self.__assertExternalModule(node, name.name or '')
        
        self.__visitImportNode(node)
    
    def visit_ImportFrom(self, node):
        """
        Public method to handle an import from statement.
        
        @param node reference to the node to be processed
        @type ast.ImportFrom
        """
        if not getattr(node, 'is_local', False):
            self.generic_visit(node)
            return
        
        # 'module' is None for statements like 'from . import name'
        self.__assertExternalModule(node, node.module or '')
        
        self.__visitImportNode(node)
    
    def __visitImportNode(self, node):
        """
        Private method to handle an import or import from statement.
        
        A violation 'I101' is recorded, if the statement is not a direct
        child of a function definition or if it is preceded by anything
        other than the function arguments, the docstring or another import
        statement.
        
        @param node reference to the node to be processed
        @type ast.Import or ast.ImportFrom
        """
        parent = getattr(node, 'parent', None)
        if isinstance(parent, ast.Module):
            self.generic_visit(node)
            return
        
        previous = getattr(node, 'previous', None)
        
        # the 'arguments' node is the sibling preceding the first
        # statement of a function body
        isAllowedPrevious = (
            (isinstance(previous, ast.Expr) and
             getattr(previous, 'is_doc_str', False)) or
            isinstance(previous, (ast.Import, ast.ImportFrom, ast.arguments))
        )
        
        if (
            not isinstance(parent, self.__functionTypes) or
            not isAllowedPrevious
        ):
            self.violations.append((node, "I101"))
        
        self.generic_visit(node)
    
    def __assertExternalModule(self, node, module):
        """
        Private method to assert the given node.
        
        Relative imports and imports of the configured application packages
        are accepted. All other imports are recorded as violation 'I103',
        if the top level package is contained in the standard modules
        reported by the checker, and as violation 'I102' otherwise.
        
        @param node reference to the node to be processed
        @type ast.stmt
        @param module name of the module
        @type str
        """
        parent = getattr(node, 'parent', None)
        if isinstance(parent, ast.Module):
            return
        
        # the trailing dot ensures that only complete package name
        # components match (e.g. 'app' must not match 'application')
        modulePrefix = module + '.'
        
        if (
            getattr(node, 'level', 0) != 0 or
            any(modulePrefix.startswith(appModule + '.')
                for appModule in self.__appImportNames)
        ):
            return
        
        if module.split('.')[0] not in self.__checker.getStandardModules():
            self.violations.append((node, "I102"))
        else:
            self.violations.append((node, "I103"))

eric ide

mercurial