eric6/Plugins/CheckerPlugins/CodeStyleChecker/MiscellaneousChecker.py

changeset 7045
c2bf08f87a1d
parent 7042
2be5b245e1b8
child 7057
0e8d3b0c4889
diff -r 03fca68ab2c1 -r c2bf08f87a1d eric6/Plugins/CheckerPlugins/CodeStyleChecker/MiscellaneousChecker.py
--- a/eric6/Plugins/CheckerPlugins/CodeStyleChecker/MiscellaneousChecker.py	Thu Jun 27 19:28:15 2019 +0200
+++ b/eric6/Plugins/CheckerPlugins/CodeStyleChecker/MiscellaneousChecker.py	Fri Jun 28 17:51:14 2019 +0200
@@ -54,6 +54,11 @@
         ## Dictionaries with sorted keys
         "M201",
         
+        ## Naive datetime usage
+        "M301", "M302", "M303", "M304", "M305", "M306", "M307", "M308",
+        "M311", "M312", "M313", "M314", "M315",
+        "M321",
+        
         ## Bugbear
         "M501", "M502", "M503", "M504", "M505", "M506", "M507", "M508",
         "M509",
@@ -176,6 +181,9 @@
             (self.__checkReturn, ("M831", "M832", "M833", "M834")),
             (self.__checkLineContinuation, ("M841",)),
             (self.__checkCommentedCode, ("M891",)),
+            (self.__checkDateTime, ("M301", "M302", "M303", "M304", "M305",
+                                    "M306", "M307", "M308", "M311", "M312",
+                                    "M313", "M314", "M315", "M321")),
         ]
         
         self.__defaultArgs = {
@@ -255,6 +263,23 @@
         self.__error(offset[0] - 1, offset[1] or 0,
                      'M901', exc_type.__name__, exc.args[0])
     
+    def __generateTree(self):
+        """
+        Private method to generate an AST for our source.
+        
+        @return generated AST
+        @rtype ast.AST
+        """
+        source = "".join(self.__source)
+        # Check type for py2: if not str it's unicode
+        if sys.version_info[0] == 2:
+            try:
+                source = source.encode('utf-8')
+            except UnicodeError:
+                pass
+        
+        return compile(source, self.__filename, 'exec', ast.PyCF_ONLY_AST)
+    
     def run(self):
         """
         Public method to check the given source against miscellaneous
@@ -268,16 +293,8 @@
             # don't do anything, if no codes were selected
             return
         
-        source = "".join(self.__source)
-        # Check type for py2: if not str it's unicode
-        if sys.version_info[0] == 2:
-            try:
-                source = source.encode('utf-8')
-            except UnicodeError:
-                pass
         try:
-            self.__tree = compile(source, self.__filename, 'exec',
-                                  ast.PyCF_ONLY_AST)
+            self.__tree = self.__generateTree()
         except (SyntaxError, TypeError):
             self.__reportInvalidSyntax()
             return
@@ -371,7 +388,7 @@
     
     def __checkLineContinuation(self):
         """
-        Private method to check öine continuation using '\'.
+        Private method to check line continuation using backslash.
         """
         # generate source lines without comments
         linesIterator = iter(self.__source)
@@ -832,6 +849,28 @@
             node = violation[0]
             reason = violation[1]
             self.__error(node.lineno - 1, node.col_offset, reason)
+    
+    def __checkDateTime(self):
+        """
+        Private method to check use of naive datetime functions.
+        """
+        if sys.version_info[0] == 3:
+            # this check is only performed for Python 3
+            
+            # step 1: generate an augmented node tree containing parent info
+            #         for each child node
+            tree = self.__generateTree()
+            for node in ast.walk(tree):
+                for childNode in ast.iter_child_nodes(node):
+                    childNode._dtCheckerParent = node
+            
+            # step 2: perform checks and report issues
+            visitor = DateTimeVisitor()
+            visitor.visit(tree)
+            for violation in visitor.violations:
+                node = violation[0]
+                reason = violation[1]
+                self.__error(node.lineno - 1, node.col_offset, reason)
 
 
 class TextVisitor(ast.NodeVisitor):
@@ -1735,5 +1774,193 @@
                 return True
         
         return False
+
+
+class DateTimeVisitor(ast.NodeVisitor):
+    """
+    Class implementing a node visitor to check datetime function calls.
+    
+    Note: This class is modelled after flake8_datetimez checker.
+    """
+    def __init__(self):
+        """
+        Constructor
+        """
+        super(DateTimeVisitor, self).__init__()
+        
+        self.violations = []
+    
+    def __getFromKeywords(self, keywords, name):
+        """
+        Private method to get a keyword node given its name.
+        
+        @param keywords list of keyword argument nodes
+        @type list of ast.AST
+        @param name name of the keyword node
+        @type str
+        @return keyword node
+        @rtype ast.AST
+        """
+        for keyword in keywords:
+            if keyword.arg == name:
+                return keyword
+        
+        return None
+    
+    def visit_Call(self, node):
+        """
+        Public method to handle a function call.
+
+        Every datetime related function call is check for use of the naive
+        variant (i.e. use without TZ info).
+        
+        @param node reference to the node to be processed
+        @type ast.Call
+        """
+        # datetime.something()
+        isDateTimeClass = (
+            isinstance(node.func, ast.Attribute) and
+            isinstance(node.func.value, ast.Name) and
+            node.func.value.id == 'datetime')
+        
+        # datetime.datetime.something()
+        isDateTimeModuleAndClass = (
+            isinstance(node.func, ast.Attribute) and
+            isinstance(node.func.value, ast.Attribute) and
+            node.func.value.attr == 'datetime' and
+            isinstance(node.func.value.value, ast.Name) and
+            node.func.value.value.id == 'datetime')
+        
+        if isDateTimeClass:
+            if node.func.attr == 'datetime':
+                # datetime.datetime(2000, 1, 1, 0, 0, 0, 0,
+                #                   datetime.timezone.utc)
+                isCase1 = (len(node.args) >= 8 and
+                           not (isinstance(node.args[7], ast.NameConstant) and
+                                node.args[7].value is None))
+                
+                # datetime.datetime(2000, 1, 1, tzinfo=datetime.timezone.utc)
+                tzinfoKeyword = self.__getFromKeywords(node.keywords, 'tzinfo')
+                isCase2 = (tzinfoKeyword is not None and
+                           not (isinstance(tzinfoKeyword.value,
+                                           ast.NameConstant) and
+                                tzinfoKeyword.value.value is None))
+                
+                if not (isCase1 or isCase2):
+                    self.violations.append((node, "M301"))
+            
+            elif node.func.attr == 'time':
+                # time(12, 10, 45, 0, datetime.timezone.utc)
+                isCase1 = (len(node.args) >= 5 and
+                           not (isinstance(node.args[4], ast.NameConstant) and
+                                node.args[4].value is None))
+                
+                # datetime.time(12, 10, 45, tzinfo=datetime.timezone.utc)
+                tzinfoKeyword = self.__getFromKeywords(node.keywords, 'tzinfo')
+                isCase2 = (tzinfoKeyword is not None and
+                           not (isinstance(tzinfoKeyword.value,
+                                           ast.NameConstant) and
+                                tzinfoKeyword.value.value is None))
+                
+                if not (isCase1 or isCase2):
+                    self.violations.append((node, "M321"))
+            
+            elif node.func.attr == 'date':
+                self.violations.append((node, "M311"))
+        
+        if isDateTimeClass or isDateTimeModuleAndClass:
+            if node.func.attr == 'today':
+                self.violations.append((node, "M302"))
+            
+            elif node.func.attr == 'utcnow':
+                self.violations.append((node, "M303"))
+            
+            elif node.func.attr == 'utcfromtimestamp':
+                self.violations.append((node, "M304"))
+            
+            elif node.func.attr in 'now':
+                # datetime.now(UTC)
+                isCase1 = (len(node.args) == 1 and
+                           len(node.keywords) == 0 and
+                           not (isinstance(node.args[0], ast.NameConstant) and
+                                node.args[0].value is None))
+                
+                # datetime.now(tz=UTC)
+                tzKeyword = self.__getFromKeywords(node.keywords, 'tz')
+                isCase2 = (tzKeyword is not None and
+                           not (isinstance(tzKeyword.value,
+                                           ast.NameConstant) and
+                                tzKeyword.value.value is None))
+                
+                if not (isCase1 or isCase2):
+                    self.violations.append((node, "M305"))
+            
+            elif node.func.attr == 'fromtimestamp':
+                # datetime.fromtimestamp(1234, UTC)
+                isCase1 = (len(node.args) == 2 and
+                           len(node.keywords) == 0 and
+                           not (isinstance(node.args[1], ast.NameConstant) and
+                                node.args[1].value is None))
+                
+                # datetime.fromtimestamp(1234, tz=UTC)
+                tzKeyword = self.__getFromKeywords(node.keywords, 'tz')
+                isCase2 = (tzKeyword is not None and
+                           not (isinstance(tzKeyword.value,
+                                           ast.NameConstant) and
+                                tzKeyword.value.value is None))
+                
+                if not (isCase1 or isCase2):
+                    self.violations.append((node, "M306"))
+            
+            elif node.func.attr == 'strptime':
+                # datetime.strptime(...).replace(tzinfo=UTC)
+                parent = getattr(node, '_dtCheckerParent', None)
+                pparent = getattr(parent, '_dtCheckerParent', None)
+                if not (isinstance(parent, ast.Attribute) and
+                        parent.attr == 'replace'):
+                    isCase1 = False
+                elif not isinstance(pparent, ast.Call):
+                    isCase1 = False
+                else:
+                    tzinfoKeyword = self.__getFromKeywords(pparent.keywords,
+                                                           'tzinfo')
+                    isCase1 = (tzinfoKeyword is not None and
+                               not (isinstance(tzinfoKeyword.value,
+                                               ast.NameConstant) and
+                                    tzinfoKeyword.value.value is None))
+                
+                if not isCase1:
+                    self.violations.append((node, "M307"))
+            
+            elif node.func.attr == 'fromordinal':
+                self.violations.append((node, "M308"))
+        
+        # date.something()
+        isDateClass = (isinstance(node.func, ast.Attribute) and
+                       isinstance(node.func.value, ast.Name) and
+                       node.func.value.id == 'date')
+        
+        # datetime.date.something()
+        isDateModuleAndClass = (isinstance(node.func, ast.Attribute) and
+                                isinstance(node.func.value, ast.Attribute) and
+                                node.func.value.attr == 'date' and
+                                isinstance(node.func.value.value, ast.Name) and
+                                node.func.value.value.id == 'datetime')
+        
+        if isDateClass or isDateModuleAndClass:
+            if node.func.attr == 'today':
+                self.violations.append((node, "M312"))
+            
+            elif node.func.attr == 'fromtimestamp':
+                self.violations.append((node, "M313"))
+            
+            elif node.func.attr == 'fromordinal':
+                self.violations.append((node, "M314"))
+            
+            elif node.func.attr == 'fromisoformat':
+                self.violations.append((node, "M315"))
+        
+        self.generic_visit(node)
+
 #
 # eflag: noqa = M702

eric ide

mercurial