--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/eric7/PipInterface/PipVulnerabilityChecker.py Sun Mar 13 19:59:03 2022 +0100 @@ -0,0 +1,267 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2022 Detlev Offenbach <detlev@die-offenbachs.de> +# + +""" +Module implementing a Python package vulnerability checker. + +The vulnerability data is provided by the open Python vulnerability database +<a href="https://github.com/pyupio/safety-db">Safety DB</a>. +""" + +import contextlib +import enum +import json +import os +import time +from collections import namedtuple +from dataclasses import dataclass + +from packaging.specifiers import SpecifierSet + +from PyQt6.QtCore import QCoreApplication, QObject, QThread, QUrl +from PyQt6.QtNetwork import QNetworkReply, QNetworkRequest + +from EricWidgets import EricMessageBox + +import Globals +import Preferences + +Package = namedtuple("Package", ["name", "version"]) + + +@dataclass +class Vulnerability: + """ + Class containing the vulnerability data. + """ + name: str # package name + spec: dict # package specification record + version: str # package version + cve: str # CVE ID + advisory: str # CVE advisory text + vulnerabilityId: str # vulnerability ID + + +class VulnerabilityCheckError(enum.Enum): + """ + Class defining various vulnerability check error states. + """ + OK = 0 + SummaryDbUnavailable = 1 + FullDbUnavailable = 2 + + +class PipVulnerabilityChecker(QObject): + """ + Class implementing a Python package vulnerability checker. + """ + def __init__(self, pip, parent=None): + """ + Constructor + + @param pip reference to the global pip interface + @type Pip + @param parent reference to the parent widget (defaults to None) + @type QWidget (optional) + """ + super().__init__(parent) + + self.__pip = pip + + securityDir = os.path.join(Globals.getConfigDir(), "security") + os.makedirs(securityDir, mode=0o700, exist_ok=True) + self.__cacheFile = os.path.join(securityDir, + "vulnerability_cache.json") + if not os.path.exists(self.__cacheFile): + self.__createCacheFile() + + def __createCacheFile(self): + """ + Private method to create the cache file. + + The cache file has the following structure. + { + "insecure.json": { + "cachedAt": 12345678 + "db": {} + }, + "insecure_full.json": { + "cachedAt": 12345678 + "db": {} + }, + } + """ + structure = { + "insecure.json": { + "cachedAt": 0, + "db": {}, + }, + "insecure_full.json": { + "cachedAt": 0, + "db": {}, + }, + } + with open(self.__cacheFile, "w") as f: + json.dump(structure, f, indent=2) + + def __getDataFromCache(self, dbName): + """ + Private method to get the vulnerability database from the cache. + + @param dbName name of the vulnerability database + @type str + @return dictionary containing the requested vulnerability data + @rtype dict + """ + if os.path.exists(self.__cacheFile): + with open(self.__cacheFile, "r") as f: + with contextlib.suppress(json.JSONDecodeError, OSError): + cachedData = json.load(f) + if ( + dbName in cachedData and + "cachedAt" in cachedData[dbName] + ): + cacheValidPeriod = Preferences.getPip( + "VulnerabilityDbCacheValidity") + if ( + cachedData[dbName]["cachedAt"] + cacheValidPeriod > + time.time() + ): + return cachedData[dbName]["db"] + + return {} + + def __writeDataToCache(self, dbName, data): + """ + Private method to write the vulnerability data for a database to the + cache. + + @param dbName name of the vulnerability database + @type str + @param data dictionary containing the vulnerability data + @type dict + """ + if not os.path.exists(self.__cacheFile): + self.__createCacheFile() + + with open(self.__cacheFile, "r") as f: + try: + cache = json.load(f) + except json.JSONDecodeError: + cache = {} + + cache[dbName] = { + "cachedAt": time.time(), + "db": data, + } + with open(self.__cacheFile, "w") as f: + json.dump(cache, f, indent=2) + + def __fetchVulnerabilityDatabase(self, full=False): + """ + Private method to get the data of the vulnerability database. + + If the cached data is still valid, this data will be used. + Otherwise a copy of the requested database will be downloaded + and cached. + + @param full flag indicating to get the database containing the full + data set (defaults to False) + @type bool (optional) + @return dictionary containing the vulnerability data (full data set or + just package name and version specifier) + """ + dbName = "insecure_full.json" if full else "insecure.json" + + cachedData = self.__getDataFromCache(dbName) + if cachedData: + return cachedData + + url = Preferences.getPip("VulnerabilityDbMirror") + dbName + request = QNetworkRequest(QUrl(url)) + reply = self.__pip.getNetworkAccessManager().get(request) + while not reply.isFinished(): + QCoreApplication.processEvents() + QThread.msleep(100) + + reply.deleteLater() + if reply.error() == QNetworkReply.NetworkError.NoError: + data = str(reply.readAll(), + Preferences.getSystem("IOEncoding"), + 'replace') + with contextlib.suppress(json.JSONDecodeError): + data = json.loads(data) + self.__writeDataToCache(dbName, data) + return data + + EricMessageBox.critical( + None, + self.tr("Fetching Vulnerability Database"), + self.tr("""<p>The vulnerability database <b>{0}</b> could not""" + """ be loaded from <b>{1}</b>.</p><p>The vulnerability""" + """ check is not available.</p>""") + ) + return {} + + def __getVulnerabilities(self, package, specifier, db): + """ + Private method to get the vulnerabilities for a package. + + @param package name of the package + @type str + @param specifier package specifier + @type Specifier + @param db vulnerability data + @type dict + @yield dictionary containing the vulnerability data for the package + @ytype dict + """ + for entry in db[package]: + for entrySpec in entry["specs"]: + if entrySpec == specifier: + yield entry + + def check(self, packages): + """ + Public method to check the given packages for vulnerabilities. + + @param packages list of packages + @type Package + @return tuple containing an error status and the list of vulnerable + packages detected + @rtype tuple of (VulnerabilityCheckError, list of Vulnerability) + """ + db = self.__fetchVulnerabilityDatabase() + if not db: + return VulnerabilityCheckError.SummaryDbUnavailable, [] + + fullDb = None + vulnerablePackages = frozenset(db.keys()) + vulnerabilities = [] # TODO: fill this list + + for package in packages: + # normalize the package name, the safety-db is converting + # underscores to dashes and uses lowercase + name = package.name.replace("_", "-").lower() + + if name in vulnerablePackages: + # we have a candidate here, build the spec set + for specifier in db[name]: + specifierSet = SpecifierSet(specifiers=specifier) + if specifierSet.contains(package.version): + if not fullDb: + fullDb = self.__fetchVulnerabilityDatabase( + full=True) + for data in self.__getVulnerabilities( + package=name, specifier=specifier, db=fullDb + ): + vulnarabilityId = ( + data.get("id").replace("pyup.io-", "") + ) + cveId = data.get("cve") + if cveId: + cveId = cveId.split(",", 1)[0].strip() + + return VulnerabilityCheckError.OK, vulnerabilities