eric7/PipInterface/PipVulnerabilityChecker.py

branch
eric7
changeset 8977
663521af48b2
child 8978
38c3ddf21537
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/eric7/PipInterface/PipVulnerabilityChecker.py	Sun Mar 13 19:59:03 2022 +0100
@@ -0,0 +1,267 @@
+# -*- coding: utf-8 -*-
+
+# Copyright (c) 2022 Detlev Offenbach <detlev@die-offenbachs.de>
+#
+
+"""
+Module implementing a Python package vulnerability checker.
+
+The vulnerability data is provided by the open Python vulnerability database
+<a href="https://github.com/pyupio/safety-db">Safety DB</a>.
+"""
+
+import contextlib
+import enum
+import json
+import os
+import time
+from collections import namedtuple
+from dataclasses import dataclass
+
+from packaging.specifiers import SpecifierSet
+
+from PyQt6.QtCore import QCoreApplication, QObject, QThread, QUrl
+from PyQt6.QtNetwork import QNetworkReply, QNetworkRequest
+
+from EricWidgets import EricMessageBox
+
+import Globals
+import Preferences
+
+Package = namedtuple("Package", ["name", "version"])
+
+
+@dataclass
+class Vulnerability:
+    """
+    Class containing the vulnerability data.
+    """
+    name: str               # package name
+    spec: dict              # package specification record
+    version: str            # package version
+    cve: str                # CVE ID
+    advisory: str           # CVE advisory text
+    vulnerabilityId: str    # vulnerability ID
+
+
+class VulnerabilityCheckError(enum.Enum):
+    """
+    Class defining various vulnerability check error states.
+    """
+    OK = 0
+    SummaryDbUnavailable = 1
+    FullDbUnavailable = 2
+
+
+class PipVulnerabilityChecker(QObject):
+    """
+    Class implementing a Python package vulnerability checker.
+    """
+    def __init__(self, pip, parent=None):
+        """
+        Constructor
+        
+        @param pip reference to the global pip interface
+        @type Pip
+        @param parent reference to the parent widget (defaults to None)
+        @type QWidget (optional)
+        """
+        super().__init__(parent)
+        
+        self.__pip = pip
+        
+        securityDir = os.path.join(Globals.getConfigDir(), "security")
+        os.makedirs(securityDir, mode=0o700, exist_ok=True)
+        self.__cacheFile = os.path.join(securityDir,
+                                        "vulnerability_cache.json")
+        if not os.path.exists(self.__cacheFile):
+            self.__createCacheFile()
+    
+    def __createCacheFile(self):
+        """
+        Private method to create the cache file.
+        
+        The cache file has the following structure.
+        {
+          "insecure.json": {
+              "cachedAt": 12345678
+              "db": {}
+          },
+          "insecure_full.json": {
+              "cachedAt": 12345678
+              "db": {}
+          },
+        }
+        """
+        structure = {
+            "insecure.json": {
+                "cachedAt": 0,
+                "db": {},
+            },
+            "insecure_full.json": {
+                "cachedAt": 0,
+                "db": {},
+            },
+        }
+        with open(self.__cacheFile, "w") as f:
+            json.dump(structure, f, indent=2)
+    
+    def __getDataFromCache(self, dbName):
+        """
+        Private method to get the vulnerability database from the cache.
+        
+        @param dbName name of the vulnerability database
+        @type str
+        @return dictionary containing the requested vulnerability data
+        @rtype dict
+        """
+        if os.path.exists(self.__cacheFile):
+            with open(self.__cacheFile, "r") as f:
+                with contextlib.suppress(json.JSONDecodeError, OSError):
+                    cachedData = json.load(f)
+                    if (
+                        dbName in cachedData and
+                        "cachedAt" in cachedData[dbName]
+                    ):
+                        cacheValidPeriod = Preferences.getPip(
+                            "VulnerabilityDbCacheValidity")
+                        if (
+                            cachedData[dbName]["cachedAt"] + cacheValidPeriod >
+                            time.time()
+                        ):
+                            return cachedData[dbName]["db"]
+        
+        return {}
+    
+    def __writeDataToCache(self, dbName, data):
+        """
+        Private method to write the vulnerability data for a database to the
+        cache.
+        
+        @param dbName name of the vulnerability database
+        @type str
+        @param data dictionary containing the vulnerability data
+        @type dict
+        """
+        if not os.path.exists(self.__cacheFile):
+            self.__createCacheFile()
+        
+        with open(self.__cacheFile, "r") as f:
+            try:
+                cache = json.load(f)
+            except json.JSONDecodeError:
+                cache = {}
+        
+        cache[dbName] = {
+            "cachedAt": time.time(),
+            "db": data,
+        }
+        with open(self.__cacheFile, "w") as f:
+            json.dump(cache, f, indent=2)
+    
+    def __fetchVulnerabilityDatabase(self, full=False):
+        """
+        Private method to get the data of the vulnerability database.
+        
+        If the cached data is still valid, this data will be used.
+        Otherwise a copy of the requested database will be downloaded
+        and cached.
+        
+        @param full flag indicating to get the database containing the full
+            data set (defaults to False)
+        @type bool (optional)
+        @return dictionary containing the vulnerability data (full data set or
+            just package name and version specifier)
+        """
+        dbName = "insecure_full.json" if full else "insecure.json"
+        
+        cachedData = self.__getDataFromCache(dbName)
+        if cachedData:
+            return cachedData
+        
+        url = Preferences.getPip("VulnerabilityDbMirror") + dbName
+        request = QNetworkRequest(QUrl(url))
+        reply = self.__pip.getNetworkAccessManager().get(request)
+        while not reply.isFinished():
+            QCoreApplication.processEvents()
+            QThread.msleep(100)
+        
+        reply.deleteLater()
+        if reply.error() == QNetworkReply.NetworkError.NoError:
+            data = str(reply.readAll(),
+                       Preferences.getSystem("IOEncoding"),
+                       'replace')
+            with contextlib.suppress(json.JSONDecodeError):
+                data = json.loads(data)
+                self.__writeDataToCache(dbName, data)
+                return data
+        
+        EricMessageBox.critical(
+            None,
+            self.tr("Fetching Vulnerability Database"),
+            self.tr("""<p>The vulnerability database <b>{0}</b> could not"""
+                    """ be loaded from <b>{1}</b>.</p><p>The vulnerability"""
+                    """ check is not available.</p>""")
+        )
+        return {}
+    
+    def __getVulnerabilities(self, package, specifier, db):
+        """
+        Private method to get the vulnerabilities for a package.
+        
+        @param package name of the package
+        @type str
+        @param specifier package specifier
+        @type Specifier
+        @param db vulnerability data
+        @type dict
+        @yield dictionary containing the vulnerability data for the package
+        @ytype dict
+        """
+        for entry in db[package]:
+            for entrySpec in entry["specs"]:
+                if entrySpec == specifier:
+                    yield entry
+    
+    def check(self, packages):
+        """
+        Public method to check the given packages for vulnerabilities.
+        
+        @param packages list of packages
+        @type Package
+        @return tuple containing an error status and the list of vulnerable
+            packages detected
+        @rtype tuple of (VulnerabilityCheckError, list of Vulnerability)
+        """
+        db = self.__fetchVulnerabilityDatabase()
+        if not db:
+            return VulnerabilityCheckError.SummaryDbUnavailable, []
+        
+        fullDb = None
+        vulnerablePackages = frozenset(db.keys())
+        vulnerabilities = []            # TODO: fill this list
+        
+        for package in packages:
+            # normalize the package name, the safety-db is converting
+            # underscores to dashes and uses lowercase
+            name = package.name.replace("_", "-").lower()
+        
+            if name in vulnerablePackages:
+                # we have a candidate here, build the spec set
+                for specifier in db[name]:
+                    specifierSet = SpecifierSet(specifiers=specifier)
+                    if specifierSet.contains(package.version):
+                        if not fullDb:
+                            fullDb = self.__fetchVulnerabilityDatabase(
+                                full=True)
+                        for data in self.__getVulnerabilities(
+                            package=name, specifier=specifier, db=fullDb
+                        ):
+                            vulnarabilityId = (
+                                data.get("id").replace("pyup.io-", "")
+                            )
+                            cveId = data.get("cve")
+                            if cveId:
+                                cveId = cveId.split(",", 1)[0].strip()
+        
+        return VulnerabilityCheckError.OK, vulnerabilities

eric ide

mercurial