eric7/PipInterface/PipVulnerabilityChecker.py

branch
eric7
changeset 8977
663521af48b2
child 8978
38c3ddf21537
equal deleted inserted replaced
8976:ca442cd49b9e 8977:663521af48b2
1 # -*- coding: utf-8 -*-
2
3 # Copyright (c) 2022 Detlev Offenbach <detlev@die-offenbachs.de>
4 #
5
6 """
7 Module implementing a Python package vulnerability checker.
8
9 The vulnerability data is provided by the open Python vulnerability database
10 <a href="https://github.com/pyupio/safety-db">Safety DB</a>.
11 """
12
13 import contextlib
14 import enum
15 import json
16 import os
17 import time
18 from collections import namedtuple
19 from dataclasses import dataclass
20
21 from packaging.specifiers import SpecifierSet
22
23 from PyQt6.QtCore import QCoreApplication, QObject, QThread, QUrl
24 from PyQt6.QtNetwork import QNetworkReply, QNetworkRequest
25
26 from EricWidgets import EricMessageBox
27
28 import Globals
29 import Preferences
30
# simple record for a package to be checked: its name and its version string
Package = namedtuple("Package", "name version")
32
33
34 @dataclass
35 class Vulnerability:
36 """
37 Class containing the vulnerability data.
38 """
39 name: str # package name
40 spec: dict # package specification record
41 version: str # package version
42 cve: str # CVE ID
43 advisory: str # CVE advisory text
44 vulnerabilityId: str # vulnerability ID
45
46
class VulnerabilityCheckError(enum.Enum):
    """
    Class enumerating the possible result states of a vulnerability check.
    """
    # the check could be performed
    OK = 0
    # the summary vulnerability database could not be obtained
    SummaryDbUnavailable = 1
    # the full vulnerability database could not be obtained
    FullDbUnavailable = 2
54
55
class PipVulnerabilityChecker(QObject):
    """
    Class implementing a Python package vulnerability checker.
    """
    def __init__(self, pip, parent=None):
        """
        Constructor

        @param pip reference to the global pip interface
        @type Pip
        @param parent reference to the parent widget (defaults to None)
        @type QWidget (optional)
        """
        super().__init__(parent)

        self.__pip = pip

        # the cache is kept in a subdirectory of the configuration directory
        # that is accessible by the owner only
        securityDir = os.path.join(Globals.getConfigDir(), "security")
        os.makedirs(securityDir, mode=0o700, exist_ok=True)
        self.__cacheFile = os.path.join(securityDir,
                                        "vulnerability_cache.json")
        if not os.path.exists(self.__cacheFile):
            self.__createCacheFile()

    def __createCacheFile(self):
        """
        Private method to create the cache file.

        The cache file has the following structure.
        {
            "insecure.json": {
                "cachedAt": 12345678,
                "db": {}
            },
            "insecure_full.json": {
                "cachedAt": 12345678,
                "db": {}
            },
        }
        """
        structure = {
            "insecure.json": {
                "cachedAt": 0,
                "db": {},
            },
            "insecure_full.json": {
                "cachedAt": 0,
                "db": {},
            },
        }
        with open(self.__cacheFile, "w", encoding="utf-8") as f:
            json.dump(structure, f, indent=2)

    def __readCache(self):
        """
        Private method to read the complete contents of the cache file.

        @return dictionary containing the cache contents or an empty
            dictionary, if the cache file is missing, unreadable or corrupt
        @rtype dict
        """
        try:
            with open(self.__cacheFile, "r", encoding="utf-8") as f:
                cache = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError
            return {}

        # a corrupt cache might contain valid JSON of the wrong shape
        return cache if isinstance(cache, dict) else {}

    def __getDataFromCache(self, dbName):
        """
        Private method to get the vulnerability database from the cache.

        @param dbName name of the vulnerability database
        @type str
        @return dictionary containing the requested vulnerability data or an
            empty dictionary, if no valid cached data is available
        @rtype dict
        """
        entry = self.__readCache().get(dbName)
        if not isinstance(entry, dict) or "cachedAt" not in entry:
            return {}

        cacheValidPeriod = Preferences.getPip("VulnerabilityDbCacheValidity")
        if entry["cachedAt"] + cacheValidPeriod > time.time():
            return entry.get("db", {})

        return {}

    def __writeDataToCache(self, dbName, data):
        """
        Private method to write the vulnerability data for a database to the
        cache.

        Writing the cache is done on a best effort basis. Failures to write
        the cache file are ignored.

        @param dbName name of the vulnerability database
        @type str
        @param data dictionary containing the vulnerability data
        @type dict
        """
        cache = self.__readCache()
        cache[dbName] = {
            "cachedAt": time.time(),
            "db": data,
        }

        # a failing cache write must not prevent the freshly downloaded
        # data from being used by the caller
        with contextlib.suppress(OSError), open(
            self.__cacheFile, "w", encoding="utf-8"
        ) as f:
            json.dump(cache, f, indent=2)

    def __fetchVulnerabilityDatabase(self, full=False):
        """
        Private method to get the data of the vulnerability database.

        If the cached data is still valid, this data will be used.
        Otherwise a copy of the requested database will be downloaded
        and cached. If the download fails, an error message is shown.

        @param full flag indicating to get the database containing the full
            data set (defaults to False)
        @type bool (optional)
        @return dictionary containing the vulnerability data (full data set or
            just package name and version specifier) or an empty dictionary,
            if the data is not available
        @rtype dict
        """
        dbName = "insecure_full.json" if full else "insecure.json"

        cachedData = self.__getDataFromCache(dbName)
        if cachedData:
            return cachedData

        mirror = Preferences.getPip("VulnerabilityDbMirror")
        request = QNetworkRequest(QUrl(mirror + dbName))
        reply = self.__pip.getNetworkAccessManager().get(request)
        while not reply.isFinished():
            QCoreApplication.processEvents()
            QThread.msleep(100)

        # deletion is deferred until control returns to the event loop,
        # so the reply may still be read below
        reply.deleteLater()
        if reply.error() == QNetworkReply.NetworkError.NoError:
            try:
                data = json.loads(str(reply.readAll(),
                                      Preferences.getSystem("IOEncoding"),
                                      'replace'))
            except json.JSONDecodeError:
                data = None
            # only a JSON object is a usable database
            if isinstance(data, dict):
                self.__writeDataToCache(dbName, data)
                return data

        EricMessageBox.critical(
            None,
            self.tr("Fetching Vulnerability Database"),
            self.tr("""<p>The vulnerability database <b>{0}</b> could not"""
                    """ be loaded from <b>{1}</b>.</p><p>The vulnerability"""
                    """ check is not available.</p>""")
            .format(dbName, mirror)
        )
        return {}

    def __getVulnerabilities(self, package, specifier, db):
        """
        Private method to get the vulnerabilities for a package.

        @param package name of the package
        @type str
        @param specifier package version specifier string
        @type str
        @param db vulnerability data
        @type dict
        @yield dictionary containing the vulnerability data for the package
        @ytype dict
        """
        # a package may be missing from the (full) database; this is not an
        # error but simply means there is no detailed record for it
        for entry in db.get(package, []):
            if specifier in entry.get("specs", []):
                yield entry

    def check(self, packages):
        """
        Public method to check the given packages for vulnerabilities.

        @param packages list of packages
        @type list of Package
        @return tuple containing an error status and the list of vulnerable
            packages detected
        @rtype tuple of (VulnerabilityCheckError, list of Vulnerability)
        """
        db = self.__fetchVulnerabilityDatabase()
        if not db:
            return VulnerabilityCheckError.SummaryDbUnavailable, []

        # the full database is only fetched once a candidate was found
        fullDb = None
        vulnerabilities = []

        for package in packages:
            # normalize the package name, the safety-db is converting
            # underscores to dashes and uses lowercase
            name = package.name.replace("_", "-").lower()

            # the summary database maps a package name to a list of
            # version specifier strings of vulnerable versions
            for specifier in db.get(name, []):
                specifierSet = SpecifierSet(specifiers=specifier)
                if not specifierSet.contains(package.version):
                    continue

                if fullDb is None:
                    fullDb = self.__fetchVulnerabilityDatabase(full=True)
                    if not fullDb:
                        return VulnerabilityCheckError.FullDbUnavailable, []

                for data in self.__getVulnerabilities(
                    package=name, specifier=specifier, db=fullDb
                ):
                    vulnerabilityId = (
                        (data.get("id") or "").replace("pyup.io-", "")
                    )
                    cveId = data.get("cve") or ""
                    if cveId:
                        # an entry may list several CVEs; use the first one
                        cveId = cveId.split(",", 1)[0].strip()

                    vulnerabilities.append(Vulnerability(
                        name=name,
                        spec=specifier,
                        version=package.version,
                        cve=cveId,
                        advisory=data.get("advisory") or "",
                        vulnerabilityId=vulnerabilityId,
                    ))

        return VulnerabilityCheckError.OK, vulnerabilities

eric ide

mercurial