|
1 # -*- coding: utf-8 -*- |
|
2 |
|
3 # Copyright (c) 2023 Detlev Offenbach <detlev@die-offenbachs.de> |
|
4 # |
|
5 |
|
6 """ |
|
7 Module implementing a node visitor for checking the import of typing.Union. |
|
8 """ |
|
9 |
|
10 # |
|
11 # The visitor is adapted from flake8-pep604 v1.1.0 |
|
12 # |
|
13 |
|
14 import ast |
|
15 |
|
16 |
|
class AnnotationsUnionVisitor(ast.NodeVisitor):
    """
    Class implementing a node visitor for checking the import and usage of
    typing.Union.

    Aliases of the typing module (import typing as t) and aliases of
    typing.Union itself (from typing import Union as U) are tracked
    separately. Each kind is matched only where it is meaningful:
    "t.Union" for module aliases and "U[...]" for Union aliases.
    """

    ModuleName = "typing"
    AttributeName = "Union"
    FullName = "typing.Union"

    def __init__(self):
        """
        Constructor
        """
        super().__init__()
        # collected nodes (imports, attributes, subscripts) referring to
        # typing.Union
        self.__unionImports = []
        # names the typing module was bound to (import typing as t)
        self.__typingAliases = set()
        # names typing.Union was bound to (from typing import Union as U)
        self.__aliasedUnionImports = set()

    def visit_Import(self, node):
        """
        Public method to handle an ast.Import node.

        "import typing.Union" is recorded as an issue; "import typing as t"
        records "t" as an alias of the typing module.

        @param node reference to the node to be handled
        @type ast.Import
        """
        for name in node.names:
            if name.name == self.FullName:
                self.__unionImports.append(node)
            elif name.name == self.ModuleName and name.asname:
                self.__typingAliases.add(name.asname)

        self.generic_visit(node)

    def visit_ImportFrom(self, node):
        """
        Public method to handle an ast.ImportFrom node.

        "from typing import Union" is recorded as an issue. If Union is
        imported under another name, that name is remembered so that
        subscripts using it are detected as well.

        @param node reference to the node to be handled
        @type ast.ImportFrom
        """
        # Only the absolute "typing" module counts; "from .typing import ..."
        # refers to a project-local module of the same name.
        if node.module == self.ModuleName and not node.level:
            for name in node.names:
                if name.name == self.AttributeName:
                    self.__unionImports.append(node)
                    if name.asname:
                        self.__aliasedUnionImports.add(name.asname)

        self.generic_visit(node)

    def visit_Attribute(self, node):
        """
        Public method to handle an ast.Attribute node.

        Detects "typing.Union" and "<typing alias>.Union".

        @param node reference to the node to be handled
        @type ast.Attribute
        """
        if (
            isinstance(node.value, ast.Name)
            and node.attr == self.AttributeName
            and (
                node.value.id == self.ModuleName
                or node.value.id in self.__typingAliases
            )
        ):
            self.__unionImports.append(node)

        self.generic_visit(node)

    def visit_Subscript(self, node):
        """
        Public method to handle an ast.Subscript node.

        Detects "Union[...]" and "<Union alias>[...]". Subscripts of
        "typing.Union" are reported via the visited ast.Attribute node.

        @param node reference to the node to be handled
        @type ast.Subscript
        """
        if isinstance(node.value, ast.Name) and (
            node.value.id == self.AttributeName
            or node.value.id in self.__aliasedUnionImports
        ):
            self.__unionImports.append(node)

        self.generic_visit(node)

    def getIssues(self):
        """
        Public method to get the collected Union nodes.

        @return list of collected nodes
        @rtype list of ast.AST
        """
        return self.__unionImports