|
1 # -*- coding: utf-8 -*- |
|
2 |
|
3 # Copyright (c) 2022 Detlev Offenbach <detlev@die-offenbachs.de> |
|
4 # |
|
5 |
|
6 """ |
|
7 Module implementing a Python package vulnerability checker. |
|
8 |
|
9 The vulnerability data is provided by the open Python vulnerability database |
|
10 <a href="https://github.com/pyupio/safety-db">Safety DB</a>. |
|
11 """ |
|
12 |
|
13 import contextlib |
|
14 import enum |
|
15 import json |
|
16 import os |
|
17 import time |
|
18 from collections import namedtuple |
|
19 from dataclasses import dataclass |
|
20 |
|
21 from packaging.specifiers import SpecifierSet |
|
22 |
|
23 from PyQt6.QtCore import QCoreApplication, QObject, QThread, QUrl |
|
24 from PyQt6.QtNetwork import QNetworkReply, QNetworkRequest |
|
25 |
|
26 from EricWidgets import EricMessageBox |
|
27 |
|
28 import Globals |
|
29 import Preferences |
|
30 |
|
# Record describing a package by its name and its version.
Package = namedtuple("Package", ("name", "version"))
|
32 |
|
33 |
|
@dataclass
class Vulnerability:
    """
    Data class holding the details of a single vulnerability record.
    """
    # package name
    name: str
    # package specification record
    spec: dict
    # package version
    version: str
    # CVE ID
    cve: str
    # CVE advisory text
    advisory: str
    # vulnerability ID
    vulnerabilityId: str
|
45 |
|
46 |
|
class VulnerabilityCheckError(enum.Enum):
    """
    Enumeration of the possible outcomes of a vulnerability check run.
    """
    # the check could be performed
    OK = 0
    # the summary vulnerability database could not be obtained
    SummaryDbUnavailable = 1
    # the full vulnerability database could not be obtained
    FullDbUnavailable = 2
|
54 |
|
55 |
|
class PipVulnerabilityChecker(QObject):
    """
    Class implementing a Python package vulnerability checker.

    The vulnerability databases are downloaded from the configured mirror
    via the network access manager of the pip interface and are kept in a
    JSON cache file below the 'security' sub-directory of the configuration
    directory.
    """
    def __init__(self, pip, parent=None):
        """
        Constructor

        @param pip reference to the global pip interface
        @type Pip
        @param parent reference to the parent widget (defaults to None)
        @type QWidget (optional)
        """
        super().__init__(parent)

        self.__pip = pip

        securityDir = os.path.join(Globals.getConfigDir(), "security")
        os.makedirs(securityDir, mode=0o700, exist_ok=True)
        self.__cacheFile = os.path.join(securityDir,
                                        "vulnerability_cache.json")
        if not os.path.exists(self.__cacheFile):
            self.__createCacheFile()

    def __createCacheFile(self):
        """
        Private method to create the cache file.

        The cache file has the following structure.
        {
            "insecure.json": {
                "cachedAt": 12345678,
                "db": {}
            },
            "insecure_full.json": {
                "cachedAt": 12345678,
                "db": {}
            }
        }
        """
        # a 'cachedAt' value of 0 marks both entries as expired
        structure = {
            dbName: {"cachedAt": 0, "db": {}}
            for dbName in ("insecure.json", "insecure_full.json")
        }
        with open(self.__cacheFile, "w", encoding="utf-8") as f:
            json.dump(structure, f, indent=2)

    def __readCacheFile(self):
        """
        Private method to read the complete contents of the cache file.

        @return dictionary containing the cached databases; an empty
            dictionary, if the file is missing, cannot be read or does not
            contain a JSON object
        @rtype dict
        """
        try:
            with open(self.__cacheFile, "r", encoding="utf-8") as f:
                cache = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and decoding errors;
            # a broken cache is treated like an empty one
            return {}

        return cache if isinstance(cache, dict) else {}

    def __getDataFromCache(self, dbName):
        """
        Private method to get the vulnerability database from the cache.

        @param dbName name of the vulnerability database
        @type str
        @return dictionary containing the requested vulnerability data; an
            empty dictionary, if there is no valid (i.e. complete and not
            yet expired) cache entry
        @rtype dict
        """
        entry = self.__readCacheFile().get(dbName)
        if not isinstance(entry, dict):
            return {}

        cachedAt = entry.get("cachedAt")
        db = entry.get("db")
        if not isinstance(cachedAt, (int, float)) or not isinstance(db, dict):
            # incomplete or malformed entry; force a fresh download
            return {}

        cacheValidPeriod = Preferences.getPip("VulnerabilityDbCacheValidity")
        if cachedAt + cacheValidPeriod > time.time():
            return db

        return {}

    def __writeDataToCache(self, dbName, data):
        """
        Private method to write the vulnerability data for a database to the
        cache.

        @param dbName name of the vulnerability database
        @type str
        @param data dictionary containing the vulnerability data
        @type dict
        @exception OSError raised if the cache file cannot be written
        """
        if not os.path.exists(self.__cacheFile):
            self.__createCacheFile()

        # entries of other databases are preserved
        cache = self.__readCacheFile()
        cache[dbName] = {
            "cachedAt": time.time(),
            "db": data,
        }
        with open(self.__cacheFile, "w", encoding="utf-8") as f:
            json.dump(cache, f, indent=2)

    def __fetchVulnerabilityDatabase(self, full=False):
        """
        Private method to get the data of the vulnerability database.

        If the cached data is still valid, this data will be used.
        Otherwise a copy of the requested database will be downloaded
        and cached. If the download fails, an error message is shown to
        the user.

        @param full flag indicating to get the database containing the full
            data set (defaults to False)
        @type bool (optional)
        @return dictionary containing the vulnerability data (full data set or
            just package name and version specifier); an empty dictionary,
            if the database could not be obtained
        @rtype dict
        """
        dbName = "insecure_full.json" if full else "insecure.json"

        cachedData = self.__getDataFromCache(dbName)
        if cachedData:
            return cachedData

        mirror = Preferences.getPip("VulnerabilityDbMirror")
        request = QNetworkRequest(QUrl(mirror + dbName))
        reply = self.__pip.getNetworkAccessManager().get(request)
        # wait for the download to finish while still processing events
        while not reply.isFinished():
            QCoreApplication.processEvents()
            QThread.msleep(100)

        reply.deleteLater()
        if reply.error() == QNetworkReply.NetworkError.NoError:
            text = str(reply.readAll(),
                       Preferences.getSystem("IOEncoding"),
                       'replace')
            try:
                data = json.loads(text)
            except json.JSONDecodeError:
                data = None
            if isinstance(data, dict):
                # caching is best effort only; a failure to write the cache
                # must not discard the successfully downloaded data
                with contextlib.suppress(OSError):
                    self.__writeDataToCache(dbName, data)
                return data

        EricMessageBox.critical(
            None,
            self.tr("Fetching Vulnerability Database"),
            self.tr("""<p>The vulnerability database <b>{0}</b> could not"""
                    """ be loaded from <b>{1}</b>.</p><p>The vulnerability"""
                    """ check is not available.</p>""").format(dbName, mirror)
        )
        return {}

    def __getVulnerabilities(self, package, specifier, db):
        """
        Private method to get the vulnerabilities for a package.

        @param package name of the package
        @type str
        @param specifier package specifier
        @type str
        @param db vulnerability data
        @type dict
        @yield dictionary containing the vulnerability data for the package
        @ytype dict
        """
        # a package missing from the database simply yields nothing
        for entry in db.get(package, []):
            for entrySpec in entry.get("specs", []):
                if entrySpec == specifier:
                    yield entry

    def check(self, packages):
        """
        Public method to check the given packages for vulnerabilities.

        @param packages list of packages
        @type list of Package
        @return tuple containing an error status and the list of vulnerable
            packages detected
        @rtype tuple of (VulnerabilityCheckError, list of Vulnerability)
        """
        db = self.__fetchVulnerabilityDatabase()
        if not db:
            return VulnerabilityCheckError.SummaryDbUnavailable, []

        # the full database is only fetched once a match was found
        fullDb = None
        vulnerabilities = []

        for package in packages:
            # normalize the package name, the safety-db is converting
            # underscores to dashes and uses lowercase
            name = package.name.replace("_", "-").lower()

            # the summary database lists the affected version specifiers
            for specifier in db.get(name, []):
                specifierSet = SpecifierSet(specifiers=specifier)
                if not specifierSet.contains(package.version):
                    continue

                if not fullDb:
                    fullDb = self.__fetchVulnerabilityDatabase(full=True)
                    if not fullDb:
                        # the user was informed by the fetch method already;
                        # do not retry for every matching package
                        return VulnerabilityCheckError.FullDbUnavailable, []

                for data in self.__getVulnerabilities(
                    package=name, specifier=specifier, db=fullDb
                ):
                    vulnerabilityId = (
                        (data.get("id") or "").replace("pyup.io-", "")
                    )
                    cveId = data.get("cve") or ""
                    if cveId:
                        # only the first of several CVE IDs is reported
                        cveId = cveId.split(",", 1)[0].strip()
                    # NOTE(review): 'spec' receives the matching specifier
                    # string although the Vulnerability field is annotated
                    # as dict - confirm the intended type
                    vulnerabilities.append(Vulnerability(
                        name=name,
                        spec=specifier,
                        version=package.version,
                        cve=cveId,
                        advisory=data.get("advisory") or "",
                        vulnerabilityId=vulnerabilityId,
                    ))

        return VulnerabilityCheckError.OK, vulnerabilities