|
1 # Copyright 2016 Grist Labs, Inc. |
|
2 # |
|
3 # Licensed under the Apache License, Version 2.0 (the "License"); |
|
4 # you may not use this file except in compliance with the License. |
|
5 # You may obtain a copy of the License at |
|
6 # |
|
7 # http://www.apache.org/licenses/LICENSE-2.0 |
|
8 # |
|
9 # Unless required by applicable law or agreed to in writing, software |
|
10 # distributed under the License is distributed on an "AS IS" BASIS, |
|
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
12 # See the License for the specific language governing permissions and |
|
13 # limitations under the License. |
|
14 |
|
15 import ast |
|
16 import collections |
|
17 import token |
|
18 from six import iteritems |
|
19 |
|
20 |
|
def token_repr(tok_type, string):
  """Returns a human-friendly 'TYPE:repr' description of a token with the given type and string."""
  type_name = token.tok_name[tok_type]
  # On Python 2, repr() of a unicode string starts with 'u'; drop it so output matches Python 3.
  string_repr = repr(string).lstrip('u')
  return '{0}:{1}'.format(type_name, string_repr)
|
25 |
|
26 |
|
class Token(collections.namedtuple('Token', 'type string start end line index startpos endpos')):
  """
  Token is an 8-tuple containing the same 5 fields as the tokens produced by the tokenize
  module, and 3 additional ones useful for this module:

  - [0] .type     Token type (see token.py)
  - [1] .string   Token text (a string)
  - [2] .start    Starting (row, column) indices of the token (a 2-tuple of ints)
  - [3] .end      Ending (row, column) indices of the token (a 2-tuple of ints)
  - [4] .line     Original line (string)
  - [5] .index    Index of the token in the list of tokens that it belongs to.
  - [6] .startpos Starting character offset into the input text.
  - [7] .endpos   Ending character offset into the input text.

  ``str(tok)`` gives a compact 'TYPE:repr(string)' form, as produced by ``token_repr()``.
  """
  def __str__(self):
    return token_repr(self.type, self.string)
|
43 |
|
44 |
|
def match_token(token, tok_type, tok_str=None):
  """
  Tells whether ``token`` has type ``tok_type`` and, when ``tok_str`` is not None, whether its
  string is exactly ``tok_str`` as well.
  """
  if token.type != tok_type:
    return False
  return tok_str is None or token.string == tok_str
|
48 |
|
49 |
|
def expect_token(token, tok_type, tok_str=None):
  """
  Checks that ``token`` is of type ``tok_type`` and, if ``tok_str`` is given, that it has that
  exact string as well. Returns None when it matches.

  Raises:
    ValueError: when the token doesn't match. The message shows the expected and the actual
      token, plus the token's start row and its start column plus one.
  """
  if match_token(token, tok_type, tok_str):
    return
  expected = token_repr(tok_type, tok_str)
  actual = str(token)
  raise ValueError("Expected token %s, got %s on line %s col %s" % (
    expected, actual, token.start[0], token.start[1] + 1))
|
59 |
|
# Older Pythons defined COMMENT/NL/ENCODING only in tokenize.py, numbered at or above
# token.N_TOKENS. From Python 3.7 they live in token.py, so they are checked for by name there.
if not hasattr(token, 'COMMENT'):
  def is_non_coding_token(token_type):
    """
    Returns whether tokens of this type are non-coding, i.e. don't affect the syntax tree.
    """
    return token_type >= token.N_TOKENS
else:
  def is_non_coding_token(token_type):
    """
    Returns whether tokens of this type are non-coding, i.e. don't affect the syntax tree.
    """
    return token_type in (token.NL, token.COMMENT, token.ENCODING)
|
74 |
|
def iter_children(node):
  """
  Yields the direct children of an ``ast`` or ``astroid`` node, leaving out children that are
  singleton nodes. Nodes with a ``get_children`` method are treated as astroid nodes.
  """
  if hasattr(node, 'get_children'):
    return iter_children_astroid(node)
  return iter_children_ast(node)
|
80 |
|
81 |
|
def iter_children_func(node):
  """
  Picks the child-iteration function suited to ``node``, for use in place of ``iter_children``
  when many nodes of the same tree are visited: ``iter_children_astroid`` for nodes that have a
  ``get_children`` method (astroid), otherwise ``iter_children_ast``.
  """
  if hasattr(node, 'get_children'):
    return iter_children_astroid
  return iter_children_ast
|
88 |
|
89 |
|
def iter_children_astroid(node):
  """Returns the children of an astroid node, or an empty list for a JoinedStr (f-string) node."""
  # Children of JoinedStr nodes can't be fully handled yet, so none are reported for them.
  return [] if is_joined_str(node) else node.get_children()
|
96 |
|
97 |
|
# Every ``ast`` class derived from the context / operator base classes. Instances of these are
# treated as singletons and skipped when iterating children; keeping the classes in a set makes
# the per-child membership test cheap.
SINGLETONS = set(
  obj for obj in ast.__dict__.values()
  if isinstance(obj, type) and
  issubclass(obj, (ast.expr_context, ast.boolop, ast.operator, ast.unaryop, ast.cmpop)))
|
100 |
|
def iter_children_ast(node):
  """
  Yields the direct children of a plain ``ast`` node, except those whose class is in SINGLETONS.
  Yields nothing at all for a JoinedStr (f-string) node.
  """
  # JoinedStr children can't be fully handled yet, so they are not reported.
  if not is_joined_str(node):
    for child in ast.iter_child_nodes(node):
      # Singleton children don't reflect particular positions in the code and would break the
      # assumption that the tree consists of distinct nodes. Looking the class up in a prebuilt
      # set is faster than calling isinstance() for each child.
      if child.__class__ in SINGLETONS:
        continue
      yield child
|
112 |
|
113 |
|
# Names of all ``ast`` statement classes (subclasses of ast.stmt).
stmt_class_names = set(
  name for name, obj in ast.__dict__.items()
  if isinstance(obj, type) and issubclass(obj, ast.stmt))

# Names of all ``ast`` expression classes (subclasses of ast.expr), plus a few extra class names
# that are not ``ast`` classes (these appear to be astroid node types; kept so astroid trees are
# classified too).
expr_class_names = set(
  name for name, obj in ast.__dict__.items()
  if isinstance(obj, type) and issubclass(obj, ast.expr))
expr_class_names.update(['AssignName', 'DelName', 'Const', 'AssignAttr', 'DelAttr'])
|
119 |
|
# Checking class names feels hacky compared to isinstance(), but it lets ast and astroid nodes be
# handled the same way, without even importing astroid.
def is_expr(node):
  """Returns whether node is an expression node, judged by its class name."""
  class_name = node.__class__.__name__
  return class_name in expr_class_names
|
125 |
|
def is_stmt(node):
  """Returns whether node is a statement node, judged by its class name."""
  class_name = node.__class__.__name__
  return class_name in stmt_class_names
|
129 |
|
def is_module(node):
  """Returns whether node is a module node, i.e. its class is named 'Module'."""
  return 'Module' == node.__class__.__name__
|
133 |
|
def is_joined_str(node):
  """
  Returns whether node is a JoinedStr node, the node class used to represent f-strings.

  Callers use this to avoid descending below such nodes: at the moment, the nodes under a
  JoinedStr have wrong line/col info, and processing them only leads to errors.
  """
  class_name = node.__class__.__name__
  return class_name == 'JoinedStr'
|
139 |
|
140 |
|
# Sentinel stored in the "value" slot of a visit_tree() stack entry, marking a node whose
# previsit() has not been called yet.
_PREVISIT = object()


def visit_tree(node, previsit, postvisit):
  """
  Traverses the tree under ``node`` depth-first, using an explicit stack rather than recursion,
  so that deep trees don't hit the 'maximum recursion depth exceeded' error.

  The callbacks are invoked as follows:

  * ``previsit(node, par_value)`` must return a ``(par_value, value)`` pair. The incoming
    ``par_value`` is the first element returned by ``previsit()`` of the parent; the returned
    ``par_value`` is what this node's children receive.

  * ``postvisit(node, par_value, value)`` must return ``value``. ``par_value`` is the same one
    this node's ``previsit()`` received, and ``value`` is the second element that ``previsit()``
    returned for this node. Returned values are ignored, except the root node's, which becomes
    the return value of ``visit_tree()``.

  The root node receives ``par_value=None``. Either or both of ``previsit`` and ``postvisit``
  may be None, in which case a no-op callback is used.
  """
  pre = previsit or (lambda node, pvalue: (None, None))
  post = postvisit or (lambda node, pvalue, value: None)

  children_of = iter_children_func(node)
  seen = set()
  result = None
  # Each entry is (node, par_value, value); value is _PREVISIT until previsit() has run.
  pending = [(node, None, _PREVISIT)]
  while pending:
    current, par_value, value = pending.pop()
    if value is not _PREVISIT:
      # Everything pushed after this node has been processed; finish the node itself.
      result = post(current, par_value, value)
      continue

    # A node seen twice means the input isn't a proper tree; fail rather than loop forever.
    assert current not in seen
    seen.add(current)

    child_par_value, post_value = pre(current, par_value)
    # Re-push the node so its postvisit() runs after all of its descendants.
    pending.append((current, par_value, post_value))

    # Inserting every child at the same index leaves them in reverse order, so the first child
    # ends up on top of the stack and is visited first.
    insert_at = len(pending)
    for child in children_of(current):
      pending.insert(insert_at, (child, child_par_value, _PREVISIT))
  return result
|
186 |
|
187 |
|
188 |
|
def walk(node):
  """
  Yields ``node`` and all of its descendants, depth-first in pre-order: each parent is yielded
  before its children, and siblings come in their original order.

  This is similar to ``ast.walk()``, but with a different order, and it works for both ``ast``
  and ``astroid`` trees. Like ``iter_children()``, it skips singleton nodes generated by ``ast``.
  An explicit stack is used instead of recursion.
  """
  children_of = iter_children_func(node)
  visited = set()
  todo = [node]
  while todo:
    item = todo.pop()
    # A node seen twice means the input isn't a proper tree; fail rather than loop forever.
    assert item not in visited
    visited.add(item)

    yield item

    # Inserting every child at the same index leaves them in reverse order, so the first child
    # is popped next. This is faster than building a list and reversing it.
    insert_at = len(todo)
    for child in children_of(item):
      todo.insert(insert_at, child)
|
212 |
|
213 |
|
def replace(text, replacements):
  """
  Replaces multiple slices of text with new values. This is a convenience method for making code
  modifications of ranges e.g. as identified by ``ASTTokens.get_text_range(node)``. Replacements is
  an iterable of ``(start, end, new_text)`` tuples, where ``text[start:end]`` is the slice to be
  replaced by ``new_text``.

  Replacements are applied in sorted order (by ``start``, then ``end``), so they may be given in
  any order. They are expected not to overlap; overlapping ranges are not detected.

  For example, ``replace("this is a test", [(0, 4, "X"), (8, 9, "THE")])`` produces
  ``"X is THE test"``.
  """
  pos = 0
  parts = []
  for (start, end, new_text) in sorted(replacements):
    parts.append(text[pos:start])   # untouched text since the previous replacement
    parts.append(new_text)
    pos = end
  parts.append(text[pos:])          # untouched tail after the last replacement
  return ''.join(parts)
|
231 |
|
232 |
|
class NodeMethods(object):
  """
  Helper to get `visit_{node_type}` methods given a node's class and cache the results.

  The cache is keyed by node class only, and it stores the method as looked up on the ``obj``
  passed to ``get()``. Later calls for the same class return that cached method even if a
  different ``obj`` is passed, so each visitor object should use its own ``NodeMethods``.
  """
  def __init__(self):
    # Maps node class -> method resolved for it.
    self._cache = {}

  def get(self, obj, cls):
    """
    Using the lowercase name of the class as node_type, returns `obj.visit_{node_type}`,
    or `obj.visit_default` if the type-specific method is not found.

    ``obj.visit_default`` is only looked up when the type-specific method is missing, so a
    visitor that defines methods for every node class it sees doesn't need a default.

    Raises AttributeError if neither method exists on ``obj``.
    """
    method = self._cache.get(cls)
    if method is None:
      name = "visit_" + cls.__name__.lower()
      try:
        method = getattr(obj, name)
      except AttributeError:
        # Resolved lazily: getattr(obj, name, obj.visit_default) would evaluate visit_default
        # up front and fail even when the type-specific method exists.
        method = obj.visit_default
      self._cache[cls] = method
    return method