Sun, 13 Mar 2022 19:59:03 +0100
Started implementing a vulnerability checker based on the data of the Safety DB.
# -*- coding: utf-8 -*-

# Copyright (c) 2022 Detlev Offenbach <detlev@die-offenbachs.de>
#

"""
Module implementing a Python package vulnerability checker.

The vulnerability data is provided by the open Python vulnerability database
<a href="https://github.com/pyupio/safety-db">Safety DB</a>.
"""

import contextlib
import enum
import json
import os
import time

from collections import namedtuple
from dataclasses import dataclass

from packaging.specifiers import SpecifierSet

from PyQt6.QtCore import QCoreApplication, QObject, QThread, QUrl
from PyQt6.QtNetwork import QNetworkReply, QNetworkRequest

from EricWidgets import EricMessageBox

import Globals
import Preferences

# Lightweight record describing an installed package to be checked.
Package = namedtuple("Package", ["name", "version"])


@dataclass
class Vulnerability:
    """
    Class containing the vulnerability data.
    """
    name: str               # normalized package name
    # version specifier the package version matched
    # NOTE(review): check() stores the matching specifier string here;
    # confirm this is what consumers of the record expect
    spec: str
    version: str            # package version
    cve: str                # CVE ID (empty if the record has none)
    advisory: str           # CVE advisory text
    vulnerabilityId: str    # vulnerability ID


class VulnerabilityCheckError(enum.Enum):
    """
    Class defining various vulnerability check error states.
    """
    OK = 0
    SummaryDbUnavailable = 1
    FullDbUnavailable = 2


class PipVulnerabilityChecker(QObject):
    """
    Class implementing a Python package vulnerability checker.
    """
    def __init__(self, pip, parent=None):
        """
        Constructor

        Ensures the private 'security' directory and the cache file exist
        below the configuration directory.

        @param pip reference to the global pip interface
        @type Pip
        @param parent reference to the parent object (defaults to None)
        @type QObject (optional)
        """
        super().__init__(parent)

        self.__pip = pip

        securityDir = os.path.join(Globals.getConfigDir(), "security")
        os.makedirs(securityDir, mode=0o700, exist_ok=True)
        self.__cacheFile = os.path.join(securityDir,
                                        "vulnerability_cache.json")
        if not os.path.exists(self.__cacheFile):
            self.__createCacheFile()

    def __createCacheFile(self):
        """
        Private method to create the cache file.

        The cache file has the following structure.
        {
            "insecure.json": {
                "cachedAt": 12345678,
                "db": {}
            },
            "insecure_full.json": {
                "cachedAt": 12345678,
                "db": {}
            }
        }
        """
        # a 'cachedAt' of 0 marks both entries as expired
        structure = {
            "insecure.json": {
                "cachedAt": 0,
                "db": {},
            },
            "insecure_full.json": {
                "cachedAt": 0,
                "db": {},
            },
        }
        with open(self.__cacheFile, "w", encoding="utf-8") as f:
            json.dump(structure, f, indent=2)

    def __getDataFromCache(self, dbName):
        """
        Private method to get the vulnerability database from the cache.

        An empty dictionary is returned, if the cache file is missing,
        unreadable or malformed, if it has no usable entry for the database
        or if the entry is older than the configured validity period.

        @param dbName name of the vulnerability database
        @type str
        @return dictionary containing the requested vulnerability data
        @rtype dict
        """
        if not os.path.exists(self.__cacheFile):
            return {}

        try:
            with open(self.__cacheFile, "r", encoding="utf-8") as f:
                cachedData = json.load(f)
        except (OSError, ValueError):
            # json.JSONDecodeError and UnicodeDecodeError are both
            # ValueError subclasses; a broken cache is treated as empty
            return {}

        entry = (
            cachedData.get(dbName) if isinstance(cachedData, dict) else None
        )
        if not isinstance(entry, dict) or "cachedAt" not in entry:
            return {}

        cacheValidPeriod = Preferences.getPip("VulnerabilityDbCacheValidity")
        if entry["cachedAt"] + cacheValidPeriod > time.time():
            return entry.get("db", {})

        return {}

    def __writeDataToCache(self, dbName, data):
        """
        Private method to write the vulnerability data for a database to the
        cache.

        Entries of other databases already stored in the cache file are
        preserved. Unusable cache content is discarded and replaced.

        @param dbName name of the vulnerability database
        @type str
        @param data dictionary containing the vulnerability data
        @type dict
        """
        if not os.path.exists(self.__cacheFile):
            self.__createCacheFile()

        with open(self.__cacheFile, "r", encoding="utf-8") as f:
            try:
                cache = json.load(f)
            except ValueError:
                cache = {}
        if not isinstance(cache, dict):
            # valid JSON of the wrong shape can not take the new entry
            cache = {}

        cache[dbName] = {
            "cachedAt": time.time(),
            "db": data,
        }
        with open(self.__cacheFile, "w", encoding="utf-8") as f:
            json.dump(cache, f, indent=2)

    def __fetchVulnerabilityDatabase(self, full=False):
        """
        Private method to get the data of the vulnerability database.

        If the cached data is still valid, this data will be used.
        Otherwise a copy of the requested database will be downloaded
        and cached. If the download or the decoding of the downloaded data
        fails, an error message is shown and an empty dictionary is returned.

        @param full flag indicating to get the database containing the full
            data set (defaults to False)
        @type bool (optional)
        @return dictionary containing the vulnerability data (full data set
            or just package name and version specifier)
        @rtype dict
        """
        dbName = "insecure_full.json" if full else "insecure.json"

        cachedData = self.__getDataFromCache(dbName)
        if cachedData:
            return cachedData

        url = Preferences.getPip("VulnerabilityDbMirror") + dbName
        request = QNetworkRequest(QUrl(url))
        reply = self.__pip.getNetworkAccessManager().get(request)
        # wait for the download while keeping the event loop serviced
        while not reply.isFinished():
            QCoreApplication.processEvents()
            QThread.msleep(100)

        # only schedules the deletion; the reply is still read below
        reply.deleteLater()
        if reply.error() == QNetworkReply.NetworkError.NoError:
            data = str(reply.readAll(),
                       Preferences.getSystem("IOEncoding"),
                       'replace')
            # undecodable data falls through to the error message
            with contextlib.suppress(json.JSONDecodeError):
                data = json.loads(data)
                self.__writeDataToCache(dbName, data)
                return data

        EricMessageBox.critical(
            None,
            self.tr("Fetching Vulnerability Database"),
            self.tr("""<p>The vulnerability database <b>{0}</b> could not"""
                    """ be loaded from <b>{1}</b>.</p><p>The vulnerability"""
                    """ check is not available.</p>""").format(dbName, url)
        )
        return {}

    def __getVulnerabilities(self, package, specifier, db):
        """
        Private method to get the vulnerabilities for a package.

        Packages unknown to the database and records without a 'specs' list
        yield nothing.

        @param package name of the package
        @type str
        @param specifier package specifier
        @type str
        @param db vulnerability data
        @type dict
        @yield dictionary containing the vulnerability data for the package
        @ytype dict
        """
        for entry in db.get(package, []):
            if specifier in entry.get("specs", []):
                yield entry

    def check(self, packages):
        """
        Public method to check the given packages for vulnerabilities.

        The summary database is used to find packages whose version matches
        a vulnerable version specifier. The full database is only fetched
        once the first match is found.

        @param packages list of packages
        @type list of Package
        @return tuple containing an error status and the list of vulnerable
            packages detected. The list is empty, if one of the databases
            is not available.
        @rtype tuple of (VulnerabilityCheckError, list of Vulnerability)
        """
        db = self.__fetchVulnerabilityDatabase()
        if not db:
            return VulnerabilityCheckError.SummaryDbUnavailable, []

        fullDb = None
        vulnerabilities = []

        for package in packages:
            # normalize the package name, the safety-db is converting
            # underscores to dashes and uses lowercase
            name = package.name.replace("_", "-").lower()
            if name not in db:
                continue

            # we have a candidate here, test each version specifier
            for specifier in db[name]:
                specifierSet = SpecifierSet(specifiers=specifier)
                if not specifierSet.contains(package.version):
                    continue

                if fullDb is None:
                    fullDb = self.__fetchVulnerabilityDatabase(full=True)
                    if not fullDb:
                        # the failure was already reported to the user;
                        # do not retry the download for every match
                        return VulnerabilityCheckError.FullDbUnavailable, []

                for data in self.__getVulnerabilities(
                    package=name, specifier=specifier, db=fullDb
                ):
                    vulnerabilityId = (
                        (data.get("id") or "").replace("pyup.io-", "")
                    )
                    # only the first ID of a comma separated list is used
                    cveId = (
                        (data.get("cve") or "").split(",", 1)[0].strip()
                    )
                    vulnerabilities.append(Vulnerability(
                        name=name,
                        spec=specifier,
                        version=package.version,
                        cve=cveId,
                        advisory=data.get("advisory", ""),
                        vulnerabilityId=vulnerabilityId,
                    ))

        return VulnerabilityCheckError.OK, vulnerabilities