src/eric7/Plugins/CheckerPlugins/CodeStyleChecker/Security/Checks/pytorchLoad.py

branch
eric7
changeset 11297
2c773823fb7d
equal deleted inserted replaced
11296:2894ef54fe84 11297:2c773823fb7d
1 # -*- coding: utf-8 -*-
2
3 # Copyright (c) 2024 - 2025 Detlev Offenbach <detlev@die-offenbachs.de>
4 #
5
6 """
Module implementing checks for the use of 'torch.load'.
8 """
9
10 #
11 # This is a modified version of the one found in the bandit package.
12 #
13 # Original Copyright (c) 2024 Stacklok, Inc.
14 #
15 # SPDX-License-Identifier: Apache-2.0
16 #
17
18
def getChecks():
    """
    Public method to get a dictionary with checks handled by this module.

    @return dictionary mapping the name of the AST node type to check to a
        list of tuples, each holding a checker function and the tuple of
        message codes it may report
    @rtype dict
    """
    # Only 'Call' nodes are of interest for this module.
    callCheckers = [(checkPytorchLoad, ("S-614",))]
    return {"Call": callCheckers}
32
33
def checkPytorchLoad(reportError, context, _config):
    """
    Function to check for the use of 'torch.load'.

    Using `torch.load` with untrusted data can lead to arbitrary code
    execution. The safe alternative is to use `weights_only=True` or
    the safetensors library.

    A call is reported with code 'S-614' when the checked module imports
    'torch', the qualified call name contains a 'torch' component and ends in
    'load', and the call does not pass `weights_only=True`. Calls whose
    qualified name is not available as a string are skipped.

    @param reportError function to be used to report errors
    @type func
    @param context security context object
    @type SecurityContext
    @param _config dictionary with configuration data (unused)
    @type dict
    """
    imported = context.isModuleImportedExact("torch")
    qualname = context.callFunctionNameQual

    # Without a 'torch' import this cannot be 'torch.load'. A qualified name
    # that is not a string cannot be analyzed either; it used to fall through
    # to the split() below and raise an AttributeError.
    if not imported or not isinstance(qualname, str):
        return

    qualnameList = qualname.split(".")
    if "torch" not in qualnameList or qualnameList[-1] != "load":
        return

    # For torch.load, check if weights_only=True is specified. The argument
    # value is accepted either as the boolean True or as its string form.
    weightsOnly = context.getCallArgValue("weights_only")
    if weightsOnly == "True" or weightsOnly is True:
        return

    # NOTE(review): the '- 1' presumably converts the AST's 1-based line
    # number to the convention expected by reportError; confirm.
    reportError(
        context.node.lineno - 1,
        context.node.col_offset,
        "S-614",
        "M",
        "H",
    )

eric ide

mercurial