|
1 # -*- coding: utf-8 -*- |
|
2 |
|
3 # Copyright (c) 2024 Detlev Offenbach <detlev@die-offenbachs.de> |
|
4 # |
|
5 |
|
6 """ |
|
7 Module implementing checks for the use of 'torch.load' and 'torch.save'. |
|
8 """ |
|
9 |
|
10 # |
|
11 # This is a modified version of the one found in the bandit package. |
|
12 # |
|
13 # Original Copyright (c) 2024 Stacklok, Inc. |
|
14 # |
|
15 # SPDX-License-Identifier: Apache-2.0 |
|
16 # |
|
17 |
|
18 |
|
def getChecks():
    """
    Public method to get a dictionary with checks handled by this module.

    The dictionary maps the AST node type a checker is interested in to the
    list of (checker function, tuple of message codes) pairs handling it.

    @return dictionary containing checker lists containing checker function and
        list of codes
    @rtype dict
    """
    # only 'Call' nodes are of interest for the torch load/save check
    callCheckers = [(checkPytorchLoadSave, ("S614",))]
    return {"Call": callCheckers}
|
32 |
|
33 |
|
def checkPytorchLoadSave(reportError, context, _config):
    """
    Function to check for the use of 'torch.load' and 'torch.save'.

    Using `torch.load` with untrusted data can lead to arbitrary code
    execution, and improper use of `torch.save` might expose sensitive
    data or lead to data corruption.

    An issue is reported for a call when the 'torch' module was imported,
    the qualified name of the called function contains a 'torch' component
    and ends in 'load' or 'save', and the context does not report the call
    argument 'map_location' as having the value 'cpu'.

    @param reportError function to be used to report errors
    @type func
    @param context security context object
    @type SecurityContext
    @param _config dictionary with configuration data (unused)
    @type dict
    """
    # Bail out unless 'torch' was imported AND the called function has a
    # qualified name we can analyse. The former combined guard
    # ('not imported and isinstance(qualname, str)') let a non-string name
    # (e.g. None) fall through to 'split()' and raise an AttributeError.
    if not context.isModuleImportedExact("torch"):
        return

    qualname = context.callFunctionNameQual
    if not isinstance(qualname, str):
        return

    qualnameList = qualname.split(".")
    func = qualnameList[-1]

    # NOTE(review): 'map_location' only controls where tensors are placed;
    # it does not make unpickling untrusted data safe. This exemption is
    # kept for behavioral compatibility - confirm whether it is intended.
    if (
        "torch" in qualnameList
        and func in ("load", "save")
        and not context.checkCallArgValue("map_location", "cpu")
    ):
        reportError(
            context.node.lineno - 1,
            context.node.col_offset,
            "S614",
            "M",
            "H",
        )