eric7/PipInterface/PipVulnerabilityChecker.py

Sun, 13 Mar 2022 19:59:03 +0100

author
Detlev Offenbach <detlev@die-offenbachs.de>
date
Sun, 13 Mar 2022 19:59:03 +0100
branch
eric7
changeset 8977
663521af48b2
child 8978
38c3ddf21537
permissions
-rw-r--r--

Started implementing a vulnerability checker based on the data of the Safety DB.

# -*- coding: utf-8 -*-

# Copyright (c) 2022 Detlev Offenbach <detlev@die-offenbachs.de>
#

"""
Module implementing a Python package vulnerability checker.

The vulnerability data is provided by the open Python vulnerability database
<a href="https://github.com/pyupio/safety-db">Safety DB</a>.
"""

import contextlib
import enum
import json
import os
import time
from collections import namedtuple
from dataclasses import dataclass

from packaging.specifiers import SpecifierSet

from PyQt6.QtCore import QCoreApplication, QObject, QThread, QUrl
from PyQt6.QtNetwork import QNetworkReply, QNetworkRequest

from EricWidgets import EricMessageBox

import Globals
import Preferences

# Lightweight record identifying a package by its name and its version.
Package = namedtuple("Package", "name version")


@dataclass
class Vulnerability:
    """
    Class containing the vulnerability data.

    Each instance is a plain data record describing one vulnerability
    of one package. The order of the fields defines the order of the
    positional arguments of the generated constructor.
    """
    name: str               # package name
    spec: dict              # package specification record
    #                         NOTE(review): presumably describes the affected
    #                         version range; confirm the expected structure
    version: str            # package version
    cve: str                # CVE ID
    advisory: str           # CVE advisory text
    vulnerabilityId: str    # vulnerability ID


class VulnerabilityCheckError(enum.Enum):
    """
    Class defining various vulnerability check error states.

    A member of this enumeration is the first element of the tuple returned
    by the vulnerability check.
    """
    # the check was performed without a database problem
    OK = 0
    # the summary database ("insecure.json") could not be obtained
    SummaryDbUnavailable = 1
    # intended for the case that the full database ("insecure_full.json")
    # could not be obtained
    FullDbUnavailable = 2


class PipVulnerabilityChecker(QObject):
    """
    Class implementing a Python package vulnerability checker.

    The checker works with two JSON databases: "insecure.json" maps a
    package name to a list of affected version specifiers and
    "insecure_full.json" maps a package name to a list of detailed
    vulnerability records. Downloaded databases are stored in a JSON cache
    file below the configuration directory.
    """
    def __init__(self, pip, parent=None):
        """
        Constructor

        @param pip reference to the global pip interface
        @type Pip
        @param parent reference to the parent object (defaults to None)
        @type QObject (optional)
        """
        super().__init__(parent)

        self.__pip = pip

        securityDir = os.path.join(Globals.getConfigDir(), "security")
        os.makedirs(securityDir, mode=0o700, exist_ok=True)
        self.__cacheFile = os.path.join(securityDir,
                                        "vulnerability_cache.json")
        if not os.path.exists(self.__cacheFile):
            self.__createCacheFile()

    def __createCacheFile(self):
        """
        Private method to create the cache file.

        The cache file has the following structure.
        {
          "insecure.json": {
              "cachedAt": 12345678,
              "db": {}
          },
          "insecure_full.json": {
              "cachedAt": 12345678,
              "db": {}
          }
        }
        """
        # a time stamp of 0 marks both entries as expired
        structure = {
            "insecure.json": {
                "cachedAt": 0,
                "db": {},
            },
            "insecure_full.json": {
                "cachedAt": 0,
                "db": {},
            },
        }
        with open(self.__cacheFile, "w", encoding="utf-8") as f:
            json.dump(structure, f, indent=2)

    def __readCacheFile(self):
        """
        Private method to read the complete content of the cache file.

        @return dictionary containing the cached data of all databases. An
            empty dictionary is returned, if the file is missing, cannot be
            read or does not contain a JSON object.
        @rtype dict
        """
        try:
            with open(self.__cacheFile, "r", encoding="utf-8") as f:
                cache = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError
            return {}

        return cache if isinstance(cache, dict) else {}

    def __getDataFromCache(self, dbName):
        """
        Private method to get the vulnerability database from the cache.

        @param dbName name of the vulnerability database
        @type str
        @return dictionary containing the requested vulnerability data. An
            empty dictionary is returned, if there is no cached data or the
            cached data is older than the configured validity period.
        @rtype dict
        """
        entry = self.__readCacheFile().get(dbName)
        if not isinstance(entry, dict) or "cachedAt" not in entry:
            return {}

        cacheValidPeriod = Preferences.getPip("VulnerabilityDbCacheValidity")
        if entry["cachedAt"] + cacheValidPeriod > time.time():
            return entry.get("db", {})

        return {}

    def __writeDataToCache(self, dbName, data):
        """
        Private method to write the vulnerability data for a database to the
        cache.

        The entries of other databases already contained in the cache file
        are kept. An unreadable or malformed cache file is replaced.

        @param dbName name of the vulnerability database
        @type str
        @param data dictionary containing the vulnerability data
        @type dict
        @exception OSError raised if the cache file cannot be written
        """
        cache = self.__readCacheFile()
        cache[dbName] = {
            "cachedAt": time.time(),
            "db": data,
        }
        with open(self.__cacheFile, "w", encoding="utf-8") as f:
            json.dump(cache, f, indent=2)

    def __fetchVulnerabilityDatabase(self, full=False):
        """
        Private method to get the data of the vulnerability database.

        If the cached data is still valid, this data will be used.
        Otherwise a copy of the requested database will be downloaded
        and cached. If the download fails, an error message is shown.

        @param full flag indicating to get the database containing the full
            data set (defaults to False)
        @type bool (optional)
        @return dictionary containing the vulnerability data (full data set or
            just package name and version specifier). An empty dictionary
            is returned, if the database could not be obtained.
        @rtype dict
        """
        dbName = "insecure_full.json" if full else "insecure.json"

        cachedData = self.__getDataFromCache(dbName)
        if cachedData:
            return cachedData

        mirror = Preferences.getPip("VulnerabilityDbMirror")
        request = QNetworkRequest(QUrl(mirror + dbName))
        reply = self.__pip.getNetworkAccessManager().get(request)
        # wait for the download to finish while still processing events
        while not reply.isFinished():
            QCoreApplication.processEvents()
            QThread.msleep(100)

        reply.deleteLater()
        if reply.error() == QNetworkReply.NetworkError.NoError:
            rawData = str(reply.readAll(),
                          Preferences.getSystem("IOEncoding"),
                          'replace')
            try:
                data = json.loads(rawData)
            except json.JSONDecodeError:
                data = None

            # only a non-empty JSON object is a usable database
            if data and isinstance(data, dict):
                # caching is best effort; a write failure must not discard
                # the data just downloaded
                with contextlib.suppress(OSError):
                    self.__writeDataToCache(dbName, data)
                return data

        EricMessageBox.critical(
            None,
            self.tr("Fetching Vulnerability Database"),
            self.tr("""<p>The vulnerability database <b>{0}</b> could not"""
                    """ be loaded from <b>{1}</b>.</p><p>The vulnerability"""
                    """ check is not available.</p>""").format(dbName, mirror)
        )
        return {}

    def __getVulnerabilities(self, package, specifier, db):
        """
        Private method to get the vulnerabilities for a package.

        Each record of the package listing the given specifier in its
        "specs" list is yielded once. Nothing is yielded for a package not
        contained in the database.

        @param package name of the package
        @type str
        @param specifier package specifier
        @type str
        @param db vulnerability data
        @type dict
        @yield dictionary containing the vulnerability data for the package
        @ytype dict
        """
        for entry in db.get(package, []):
            if specifier in entry.get("specs", []):
                yield entry

    def check(self, packages):
        """
        Public method to check the given packages for vulnerabilities.

        Note: The creation of the Vulnerability records is not implemented
        yet; the returned list is therefore always empty.

        @param packages list of packages
        @type list of Package
        @return tuple containing an error status and the list of vulnerable
            packages detected
        @rtype tuple of (VulnerabilityCheckError, list of Vulnerability)
        """
        db = self.__fetchVulnerabilityDatabase()
        if not db:
            return VulnerabilityCheckError.SummaryDbUnavailable, []

        fullDb = None
        vulnerabilities = []            # TODO: fill this list

        for package in packages:
            # normalize the package name, the safety-db is converting
            # underscores to dashes and uses lowercase
            name = package.name.replace("_", "-").lower()
            if name not in db:
                continue

            # we have a candidate here, test each affected version range
            for specifier in db[name]:
                specifierSet = SpecifierSet(specifiers=specifier)
                if not specifierSet.contains(package.version):
                    continue

                # the full database is loaded once and only when needed
                if fullDb is None:
                    fullDb = self.__fetchVulnerabilityDatabase(full=True)
                    if not fullDb:
                        return (
                            VulnerabilityCheckError.FullDbUnavailable,
                            vulnerabilities,
                        )

                for data in self.__getVulnerabilities(
                    package=name, specifier=specifier, db=fullDb
                ):
                    # TODO: create a Vulnerability record from these values
                    #       and append it to 'vulnerabilities'
                    vulnerabilityId = (
                        (data.get("id") or "").replace("pyup.io-", "")
                    )
                    cveId = data.get("cve")
                    if cveId:
                        # use the first ID of a comma separated list
                        cveId = cveId.split(",", 1)[0].strip()

        return VulnerabilityCheckError.OK, vulnerabilities

eric ide

mercurial