|
1 # -*- coding: utf-8 -*- |
|
2 |
|
3 # Copyright (c) 2025 Detlev Offenbach <detlev@die-offenbachs.de> |
|
4 # |
|
5 |
|
6 """ |
|
7 Module implementing utility functions for the PydanticVisitor class. |
|
8 """ |
|
9 |
|
10 import ast |
|
11 |
|
12 ####################################################################### |
|
13 ## adapted from: flake8-pydantic v0.4.0 |
|
14 ## |
|
15 ## Original: Copyright (c) 2023 Victorien |
|
16 ####################################################################### |
|
17 |
|
18 |
|
def getDecoratorNames(decoratorList):
    """
    Function to extract the set of decorator names.

    For a plain decorator (`@name`) the name itself is used, for an attribute
    access (`@module.name`) the attribute name is used. A called decorator
    (`@name(...)` or `@module.name(...)`) is reduced to the called expression
    first. Any other decorator expression (allowed since PEP 614, e.g.
    `@handlers[0]` or `@factory()()`) has no simple name and is ignored.

    @param decoratorList list of decorators to be processed
    @type list of ast.expr
    @return set containing the decorator names
    @rtype set of str
    """
    names = set()

    for dec in decoratorList:
        # '@decorator(...)': take the name from the called expression
        target = dec.func if isinstance(dec, ast.Call) else dec

        if isinstance(target, ast.Name):
            names.add(target.id)
        elif isinstance(target, ast.Attribute):
            names.add(target.attr)
        # other expressions carry no simple name; previously a call of such an
        # expression raised an AttributeError on '.id'

    return names
|
41 |
|
42 |
|
def _hasPydanticModelBase(node, *, includeRootModel):
    """
    Function to check, if a class definition inherits from Pydantic model classes.

    Bases are matched by name (`BaseModel`) or by attribute name
    (`pydantic.BaseModel`). Parametrized bases such as `RootModel[list[int]]`
    are matched by their unsubscripted name.

    @param node reference to the node to be analyzed
    @type ast.ClassDef
    @keyparam includeRootModel flag indicating to include the root model
    @type bool
    @return flag indicating that the class definition inherits from a Pydantic model
        class
    @rtype bool
    """
    modelClassNames = {"BaseModel"}
    if includeRootModel:
        modelClassNames.add("RootModel")

    for base in node.bases:
        # 'RootModel' is almost always used parametrized ('RootModel[T]'),
        # which the AST represents as a Subscript around the class name.
        target = base.value if isinstance(base, ast.Subscript) else base

        if isinstance(target, ast.Name) and target.id in modelClassNames:
            return True
        if isinstance(target, ast.Attribute) and target.attr in modelClassNames:
            return True
    return False
|
65 |
|
66 |
|
def _hasModelConfig(node):
    """
    Function to check, if the class body assigns a `model_config` attribute.

    Both an annotated assignment (`model_config: ... = ...` or a bare
    `model_config: ...`) and a plain assignment (`model_config = ...`, also as
    one of several chained targets) are recognized. Only simple name targets
    are considered.

    @param node reference to the node to be analyzed
    @type ast.ClassDef
    @return flag indicating that the class has a `model_config` attribute set
    @rtype bool
    """

    def _isModelConfigName(target):
        # only a bare name counts, not attribute or tuple targets
        return isinstance(target, ast.Name) and target.id == "model_config"

    for statement in node.body:
        if isinstance(statement, ast.AnnAssign):
            candidates = [statement.target]
        elif isinstance(statement, ast.Assign):
            candidates = statement.targets
        else:
            continue

        if any(_isModelConfigName(candidate) for candidate in candidates):
            return True

    return False
|
92 |
|
93 |
|
# Keyword arguments accepted by the Pydantic (v2) 'Field()' function. A 'Field()'
# call using any other keyword is not treated as a Pydantic field.
PYDANTIC_FIELD_ARGUMENTS = {
    "default",
    "default_factory",
    "alias",
    "alias_priority",
    "validation_alias",
    "serialization_alias",
    "title",
    "field_title_generator",
    "description",
    "examples",
    "exclude",
    "discriminator",
    "deprecated",
    "json_schema_extra",
    "frozen",
    "validate_default",
    "repr",
    "init",
    "init_var",
    "kw_only",
    "pattern",
    "strict",
    "coerce_numbers_to_str",
    "gt",
    "ge",
    "lt",
    "le",
    "multiple_of",
    "allow_inf_nan",
    "max_digits",
    "decimal_places",
    "min_length",
    "max_length",
    "union_mode",
    "fail_fast",
}
|
126 |
|
127 |
|
def _hasFieldFunction(node):
    """
    Function to check, if the class has a field defined with the `Field` function.

    An assignment (annotated or not) counts, if its value is a call of `Field`
    or `<something>.Field` and every explicitly named keyword argument is one
    of PYDANTIC_FIELD_ARGUMENTS (presumably to avoid matching `Field` functions
    of other libraries that take additional keywords). `**kwargs` expansions
    are not checked.

    @param node reference to the node to be analyzed
    @type ast.ClassDef
    @return flag indicating that the class has a field defined with the `Field`
        function
    @rtype bool
    """
    for statement in node.body:
        if not isinstance(statement, (ast.Assign, ast.AnnAssign)):
            continue

        call = statement.value  # None for a bare annotation
        if not isinstance(call, ast.Call):
            continue

        callee = call.func
        isFieldCall = (
            isinstance(callee, ast.Name) and callee.id == "Field"  # Field(...)
        ) or (
            isinstance(callee, ast.Attribute) and callee.attr == "Field"
        )  # pydantic.Field(...)
        if not isFieldCall:
            continue

        # keyword.arg is None for '**kwargs' expansions, which are skipped
        keywordNames = [kw.arg for kw in call.keywords if kw.arg is not None]
        if all(name in PYDANTIC_FIELD_ARGUMENTS for name in keywordNames):
            return True

    return False
|
159 |
|
160 |
|
def _hasAnnotatedField(node):
    """
    Function to check if the class has a field making use of `Annotated`.

    Annotated assignments whose annotation is a subscript of `Annotated` or of
    `<something>.Annotated` (e.g. `typing.Annotated[...]`) are recognized.

    @param node reference to the node to be analyzed
    @type ast.ClassDef
    @return flag indicating that the class has a field making use of `Annotated`
    @rtype bool
    """
    for statement in node.body:
        if not isinstance(statement, ast.AnnAssign):
            continue

        annotation = statement.annotation
        if not isinstance(annotation, ast.Subscript):
            continue

        subscripted = annotation.value
        if isinstance(subscripted, ast.Name):
            # f: Annotated[...]
            typeName = subscripted.id
        elif isinstance(subscripted, ast.Attribute):
            # f: typing.Annotated[...]
            typeName = subscripted.attr
        else:
            continue

        if typeName == "Annotated":
            return True

    return False
|
189 |
|
190 |
|
# Names of the Pydantic decorators applied to methods inside a model class,
# grouped by kind (computed fields, serializers, validators).
PYDANTIC_DECORATORS = {
    "computed_field",
    "field_serializer",
    "model_serializer",
    "field_validator",
    "model_validator",
}
|
198 |
|
199 |
|
def _hasPydanticDecorator(node):
    """
    Function to check, if the class makes use of Pydantic decorators, such as
    `computed_field` or `model_validator`.

    Only (non-async) function definitions directly in the class body are
    inspected. Their decorator names are compared against PYDANTIC_DECORATORS.

    @param node reference to the node to be analyzed
    @type ast.ClassDef
    @return flag indicating that the class makes use of Pydantic decorators, such as
        `computed_field` or `model_validator`.
    @rtype bool
    """
    return any(
        PYDANTIC_DECORATORS & getDecoratorNames(statement.decorator_list)
        for statement in node.body
        if isinstance(statement, ast.FunctionDef)
    )
|
217 |
|
218 |
|
# Public 'BaseModel' methods (Pydantic v2) whose definition in a class body
# indicates an override in a Pydantic model.
PYDANTIC_METHODS = {
    "model_construct",
    "model_copy",
    "model_dump",
    "model_dump_json",
    "model_json_schema",
    "model_parametrized_name",
    "model_post_init",
    "model_rebuild",
    "model_validate",
    "model_validate_json",
    "model_validate_strings",
}
|
231 |
|
232 |
|
def _hasPydanticMethod(node: ast.ClassDef) -> bool:
    """
    Function to check, if the class overrides any of the Pydantic methods, such as
    `model_dump`.

    A (non-async) method directly in the class body counts, if its name is in
    PYDANTIC_METHODS or starts with `__pydantic_` or `__get_pydantic_`.

    @param node reference to the node to be analyzed
    @type ast.ClassDef
    @return flag indicating that class overrides any of the Pydantic methods, such as
        `model_dump`
    @rtype bool
    """
    dunderPrefixes = ("__pydantic_", "__get_pydantic_")

    for statement in node.body:
        if not isinstance(statement, ast.FunctionDef):
            continue

        methodName = statement.name
        if methodName in PYDANTIC_METHODS or methodName.startswith(dunderPrefixes):
            return True

    return False
|
255 |
|
256 |
|
def isPydanticModel(node, *, includeRootModel=True):
    """
    Function to determine if a class definition is a Pydantic model.

    A class without any base class is never considered a Pydantic model.
    Otherwise, multiple heuristics are used to determine if this is the case:
    - The class inherits from `BaseModel` (or `RootModel` if `includeRootModel` is
      `True`).
    - The class has a `model_config` attribute set.
    - The class has a field defined with the `Field` function.
    - The class has a field making use of `Annotated`.
    - The class makes use of Pydantic decorators, such as `computed_field` or
      `model_validator`.
    - The class overrides any of the Pydantic methods, such as `model_dump`.

    The checks are evaluated in this order and stop at the first match.

    @param node reference to the node to be analyzed
    @type ast.ClassDef
    @keyparam includeRootModel flag indicating to include the root model
        (defaults to True)
    @type bool (optional)
    @return flag indicating a Pydantic model class
    @rtype bool
    """
    if not node.bases:
        return False

    if _hasPydanticModelBase(node, includeRootModel=includeRootModel):
        return True

    bodyHeuristics = (
        _hasModelConfig,
        _hasFieldFunction,
        _hasAnnotatedField,
        _hasPydanticDecorator,
        _hasPydanticMethod,
    )
    # generator keeps the evaluation lazy, i.e. stops at the first match
    return any(heuristic(node) for heuristic in bodyHeuristics)
|
290 |
|
291 |
|
def isDataclass(node):
    """
    Function to check, if a class is a dataclass.

    The class is considered a dataclass, if one of its decorators is named
    `dataclass` or `pydantic_dataclass`. The decorator may be used plainly, via
    attribute access (e.g. `dataclasses.dataclass`) or called (e.g.
    `dataclass(frozen=True)`).

    @param node reference to the node to be analyzed
    @type ast.ClassDef
    @return flag indicating that the class is a dataclass
    @rtype bool
    """
    decoratorNames = getDecoratorNames(node.decorator_list)
    return bool({"dataclass", "pydantic_dataclass"} & decoratorNames)
|
306 |
|
307 |
|
def isFunction(node, functionName):
    """
    Function to check, if a function call is referencing a given function name.

    The callee matches, if it is a plain name equal to the given one (`name(...)`)
    or an attribute access whose attribute equals it (`module.name(...)`). Any
    other callee expression never matches.

    @param node reference to the node to be analyzed
    @type ast.Call
    @param functionName name of the function to check for
    @type str
    @return flag indicating that the function call is referencing the given function
        name
    @rtype bool
    """
    callee = node.func

    if isinstance(callee, ast.Name):
        return callee.id == functionName

    if isinstance(callee, ast.Attribute):
        return callee.attr == functionName

    return False
|
323 |
|
324 |
|
def isName(node, name):
    """
    Function to check, if an expression is referencing a given name.

    A plain name (`name`) or an attribute access ending in the name
    (`module.name`) matches. Any other expression never matches.

    @param node reference to the node to be analyzed
    @type ast.expr
    @param name name to check for
    @type str
    @return flag indicating that the expression is referencing the given name
    @rtype bool
    """
    if isinstance(node, ast.Name):
        referenced = node.id
    elif isinstance(node, ast.Attribute):
        referenced = node.attr
    else:
        return False

    return referenced == name
|
339 |
|
340 |
|
def extractAnnotations(node):
    """
    Function to extract the annotations of an expression.

    Plain names are collected, and the function descends into unions written
    with a binary operator (`date | None`), into subscript arguments
    (`dict[str, date]`, `list[list[date]]`, `Optional[date | None]`) and into
    tuples/lists of arguments (`Callable[[date], int]`). The subscripted type
    itself (e.g. `dict`, `Annotated`) and attribute accesses (e.g.
    `datetime.date`) are not collected.

    @param node reference to the node to be processed
    @type ast.expr
    @return set containing the annotation names
    @rtype set[str]
    """
    annotations = set()

    if isinstance(node, ast.Name):
        ##~ foo: date = ...
        annotations.add(node.id)

    elif isinstance(node, ast.BinOp):
        ##~ foo: date | None = ...
        annotations |= extractAnnotations(node.left)
        annotations |= extractAnnotations(node.right)

    elif isinstance(node, ast.Subscript):
        ##~ foo: dict[str, date]
        ##~ foo: Annotated[list[date], ...]
        ##~ foo: list[date | None]
        # recurse into the subscript argument so that nested generics and
        # unions are handled as well (tuples are handled by the branch below)
        annotations |= extractAnnotations(node.slice)

    elif isinstance(node, (ast.Tuple, ast.List)):
        ##~ arguments of a multi-argument subscript or 'Callable[[date], int]'
        for elt in node.elts:
            annotations |= extractAnnotations(elt)

    return annotations