src/eric7/Plugins/CheckerPlugins/CodeStyleChecker/Security/Checks/pytorchLoadSave.py

branch
eric7
changeset 10996
a3dc181d14e1
child 11090
f5f5f5803935
equal deleted inserted replaced
10995:f94a27bbf6c4 10996:a3dc181d14e1
1 # -*- coding: utf-8 -*-
2
3 # Copyright (c) 2024 Detlev Offenbach <detlev@die-offenbachs.de>
4 #
5
6 """
7 Module implementing checks for the use of 'torch.load' and 'torch.save'.
8 """
9
10 #
11 # This is a modified version of the one found in the bandit package.
12 #
13 # Original Copyright (c) 2024 Stacklok, Inc.
14 #
15 # SPDX-License-Identifier: Apache-2.0
16 #
17
18
def getChecks():
    """
    Public method to get a dictionary with checks handled by this module.

    @return dictionary mapping the AST node type name to a list of tuples,
        each containing a checker function and a tuple of message codes
        reported by it
    @rtype dict
    """
    callCheckers = [
        (checkPytorchLoadSave, ("S614",)),
    ]
    return {"Call": callCheckers}
32
33
def checkPytorchLoadSave(reportError, context, _config):
    """
    Function to check for the use of 'torch.load' and 'torch.save'.

    Using `torch.load` with untrusted data can lead to arbitrary code
    execution, and improper use of `torch.save` might expose sensitive
    data or lead to data corruption.

    A call is reported with code 'S614' when 'torch' is imported, the
    qualified name of the called function contains a 'torch' component and
    ends in 'load' or 'save', and the call is not made with
    map_location="cpu".

    @param reportError function to be used to report errors
    @type func
    @param context security context object
    @type SecurityContext
    @param _config dictionary with configuration data (unused)
    @type dict
    """
    imported = context.isModuleImportedExact("torch")
    qualname = context.callFunctionNameQual
    # Only continue when torch is imported AND the call has a qualified name
    # we can split. The former guard ('not imported and isinstance(...)')
    # let a non-string qualname (e.g. None) fall through to 'split()' and
    # raise AttributeError.
    if not imported or not isinstance(qualname, str):
        return

    qualnameList = qualname.split(".")
    if (
        "torch" in qualnameList
        and qualnameList[-1] in ("load", "save")
        # NOTE(review): calls passing map_location="cpu" are exempted; this
        # mirrors the upstream bandit plugin this module is derived from.
        and not context.checkCallArgValue("map_location", "cpu")
    ):
        reportError(
            # NOTE(review): presumably the reporter expects a zero-based
            # line number, hence the '- 1' - confirm against SecurityChecker.
            context.node.lineno - 1,
            context.node.col_offset,
            "S614",
            "M",
            "H",
        )

eric ide

mercurial