Sun, 03 Dec 2023 19:46:34 +0100
Corrected some uses of dict.keys(), dict.values() and dict.items().
# -*- coding: utf-8 -*-

# Copyright (c) 2022 - 2023 Detlev Offenbach <detlev@die-offenbachs.de>
#

"""
Module implementing a Python package vulnerability checker.

The vulnerability data is provided by the open Python vulnerability database
<a href="https://github.com/pyupio/safety-db">Safety DB</a>.
"""

import collections
import contextlib
import enum
import json
import os
import time

from dataclasses import dataclass

from packaging.specifiers import SpecifierSet
from PyQt6.QtCore import QCoreApplication, QObject, QThread, QUrl
from PyQt6.QtNetwork import QNetworkReply, QNetworkRequest

from eric7 import Globals, Preferences
from eric7.EricWidgets import EricMessageBox


@dataclass
class Package:
    """
    Class containing the package data.
    """

    name: str  # package name
    version: str  # version


@dataclass
class Vulnerability:
    """
    Class containing the vulnerability data.
    """

    name: str  # package name (normalized as used by the database)
    spec: str  # version specifier string that matched the package version
    version: str  # package version
    cve: str  # CVE ID
    advisory: str  # CVE advisory text
    vulnerabilityId: str  # vulnerability ID


class VulnerabilityCheckError(enum.Enum):
    """
    Class defining various vulnerability check error states.
    """

    OK = 0
    SummaryDbUnavailable = 1
    FullDbUnavailable = 2


class PipVulnerabilityChecker(QObject):
    """
    Class implementing a Python package vulnerability checker.
    """

    FullDbFile = "insecure_full.json"
    SummaryDbFile = "insecure.json"

    def __init__(self, pip, parent=None):
        """
        Constructor

        @param pip reference to the global pip interface
        @type Pip
        @param parent reference to the parent widget (defaults to None)
        @type QWidget (optional)
        """
        super().__init__(parent)

        self.__pip = pip

        securityDir = os.path.join(Globals.getConfigDir(), "security")
        os.makedirs(securityDir, mode=0o700, exist_ok=True)
        self.__cacheFile = os.path.join(securityDir, "vulnerability_cache.json")
        if not os.path.exists(self.__cacheFile):
            self.__createCacheFile()

    def __createCacheFile(self):
        """
        Private method to create the cache file.

        The cache file has the following structure.
        {
            "insecure.json": {
                "cachedAt": 12345678,
                "db": {}
            },
            "insecure_full.json": {
                "cachedAt": 12345678,
                "db": {}
            }
        }
        """
        # A 'cachedAt' of 0 marks both entries as expired right away.
        structure = {
            "insecure.json": {
                "cachedAt": 0,
                "db": {},
            },
            "insecure_full.json": {
                "cachedAt": 0,
                "db": {},
            },
        }
        with open(self.__cacheFile, "w", encoding="utf-8") as f:
            json.dump(structure, f, indent=2)

    def __getDataFromCache(self, dbName):
        """
        Private method to get the vulnerability database from the cache.

        An unreadable, malformed or expired cache is treated as a cache miss.

        @param dbName name of the vulnerability database
        @type str
        @return dictionary containing the requested vulnerability data (empty
            if there is no valid cached data)
        @rtype dict
        """
        try:
            with open(self.__cacheFile, "r", encoding="utf-8") as f:
                cachedData = json.load(f)
        except (OSError, json.JSONDecodeError):
            return {}

        entry = cachedData.get(dbName) if isinstance(cachedData, dict) else None
        if not isinstance(entry, dict) or "cachedAt" not in entry:
            return {}

        cacheValidPeriod = Preferences.getPip("VulnerabilityDbCacheValidity")
        if entry["cachedAt"] + cacheValidPeriod > time.time():
            return entry.get("db", {})

        return {}

    def __writeDataToCache(self, dbName, data):
        """
        Private method to write the vulnerability data for a database to the
        cache.

        @param dbName name of the vulnerability database
        @type str
        @param data dictionary containing the vulnerability data
        @type dict
        """
        if not os.path.exists(self.__cacheFile):
            self.__createCacheFile()

        with open(self.__cacheFile, "r", encoding="utf-8") as f:
            try:
                cache = json.load(f)
            except json.JSONDecodeError:
                cache = {}
        if not isinstance(cache, dict):
            # a foreign/garbled cache file; start over instead of failing
            cache = {}

        cache[dbName] = {
            "cachedAt": time.time(),
            "db": data,
        }
        with open(self.__cacheFile, "w", encoding="utf-8") as f:
            json.dump(cache, f, indent=2)

    def __fetchVulnerabilityDatabase(self, full=False, forceUpdate=False):
        """
        Private method to get the data of the vulnerability database.

        If the cached data is still valid, this data will be used.
        Otherwise a copy of the requested database will be downloaded
        and cached.

        @param full flag indicating to get the database containing the full
            data set (defaults to False)
        @type bool (optional)
        @param forceUpdate flag indicating an update of the cache is required
            (defaults to False)
        @type bool (optional)
        @return dictionary containing the vulnerability data (full data set or
            just package name and version specifier); empty if the database
            could not be loaded
        @rtype dict
        """
        dbName = (
            PipVulnerabilityChecker.FullDbFile
            if full
            else PipVulnerabilityChecker.SummaryDbFile
        )

        if not forceUpdate:
            cachedData = self.__getDataFromCache(dbName)
            if cachedData:
                return cachedData

        url = Preferences.getPip("VulnerabilityDbMirror") + dbName
        request = QNetworkRequest(QUrl(url))
        reply = self.__pip.getNetworkAccessManager().get(request)
        # wait for the download to finish while keeping the GUI responsive
        while not reply.isFinished():
            QCoreApplication.processEvents()
            QThread.msleep(100)

        reply.deleteLater()
        if reply.error() == QNetworkReply.NetworkError.NoError:
            data = str(reply.readAll(), Preferences.getSystem("IOEncoding"), "replace")
            # invalid JSON falls through to the error message below
            with contextlib.suppress(json.JSONDecodeError):
                data = json.loads(data)
                self.__writeDataToCache(dbName, data)
                return data

        EricMessageBox.critical(
            None,
            self.tr("Fetching Vulnerability Database"),
            self.tr(
                """<p>The vulnerability database <b>{0}</b> could not"""
                """ be loaded from <b>{1}</b>.</p><p>The vulnerability"""
                """ check is not available.</p>"""
            ).format(dbName, Preferences.getPip("VulnerabilityDbMirror")),
        )
        return {}

    def __getVulnerabilities(self, package, specifier, db):
        """
        Private method to get the vulnerabilities for a package.

        @param package name of the package
        @type str
        @param specifier version specifier string to look for
        @type str
        @param db vulnerability data (full data set)
        @type dict
        @yield dictionary containing the vulnerability data for the package
        @ytype dict
        """
        # The package may be missing from the full data set; yield nothing
        # in that case instead of raising a KeyError.
        for entry in db.get(package, []):
            for entrySpec in entry["specs"]:
                if entrySpec == specifier:
                    yield entry

    def check(self, packages):
        """
        Public method to check the given packages for vulnerabilities.

        @param packages list of packages
        @type list of Package
        @return tuple containing an error status and a dictionary containing
            detected vulnerable packages keyed by package name (an empty list
            in case of an error)
        @rtype tuple of (VulnerabilityCheckError, dict of
            {str: list of Vulnerability})
        """
        db = self.__fetchVulnerabilityDatabase()
        if not db:
            return VulnerabilityCheckError.SummaryDbUnavailable, []

        fullDb = None
        vulnerabilities = collections.defaultdict(list)

        for package in packages:
            # normalize the package name, the safety-db is converting
            # underscores to dashes and uses lowercase
            name = package.name.replace("_", "-").lower()

            # packages unknown to the database have no specifiers to test
            for specifier in db.get(name, []):
                specifierSet = SpecifierSet(specifiers=specifier)
                if not specifierSet.contains(package.version):
                    continue

                if not fullDb:
                    # the full data set is loaded lazily on the first match
                    fullDb = self.__fetchVulnerabilityDatabase(full=True)
                    if not fullDb:
                        return VulnerabilityCheckError.FullDbUnavailable, []

                for data in self.__getVulnerabilities(
                    package=name, specifier=specifier, db=fullDb
                ):
                    # 'id' and 'cve' may be missing or null in the data
                    vulnerabilityId = (data.get("id") or "").replace("pyup.io-", "")
                    cveId = data.get("cve") or ""
                    if cveId:
                        # only the first of several CVE IDs is reported
                        cveId = cveId.split(",", 1)[0].strip()
                    vulnerabilities[package.name].append(
                        Vulnerability(
                            name=name,
                            spec=specifier,
                            version=package.version,
                            cve=cveId,
                            advisory=data.get("advisory", ""),
                            vulnerabilityId=vulnerabilityId,
                        )
                    )

        return VulnerabilityCheckError.OK, vulnerabilities

    def updateVulnerabilityDb(self):
        """
        Public method to update the cache of the vulnerability databases.
        """
        self.__fetchVulnerabilityDatabase(full=False, forceUpdate=True)
        self.__fetchVulnerabilityDatabase(full=True, forceUpdate=True)