src/eric7/PipInterface/PipVulnerabilityChecker.py

Sun, 03 Dec 2023 19:46:34 +0100

author
Detlev Offenbach <detlev@die-offenbachs.de>
date
Sun, 03 Dec 2023 19:46:34 +0100
branch
eric7
changeset 10373
093dcebe5ecb
parent 10180
3a595df36c9a
child 10428
a071d4065202
permissions
-rw-r--r--

Corrected some uses of dict.keys(), dict.values() and dict.items().

# -*- coding: utf-8 -*-

# Copyright (c) 2022 - 2023 Detlev Offenbach <detlev@die-offenbachs.de>
#

"""
Module implementing a Python package vulnerability checker.

The vulnerability data is provided by the open Python vulnerability database
<a href="https://github.com/pyupio/safety-db">Safety DB</a>.
"""

import collections
import contextlib
import enum
import json
import os
import time

from dataclasses import dataclass

from packaging.specifiers import SpecifierSet
from PyQt6.QtCore import QCoreApplication, QObject, QThread, QUrl
from PyQt6.QtNetwork import QNetworkReply, QNetworkRequest

from eric7 import Globals, Preferences
from eric7.EricWidgets import EricMessageBox


@dataclass()
class Package:
    """
    Data class holding the identification of an installed package.
    """

    # name of the package as reported by the package manager
    name: str
    # version string of the installed package
    version: str


@dataclass
class Vulnerability:
    """
    Class containing the data of a detected vulnerability.
    """

    name: str  # normalized package name
    spec: str  # version specifier string of the affected releases
    version: str  # installed package version
    cve: str  # CVE ID (may be empty)
    advisory: str  # CVE advisory text
    vulnerabilityId: str  # vulnerability ID


class VulnerabilityCheckError(enum.Enum):
    """
    Enumeration of the possible outcomes of a vulnerability check.
    """

    # the check could be performed
    OK = 0
    # the summary vulnerability database could not be obtained
    SummaryDbUnavailable = 1
    # the full vulnerability database could not be obtained
    FullDbUnavailable = 2


class PipVulnerabilityChecker(QObject):
    """
    Class implementing a Python package vulnerability checker.
    """

    FullDbFile = "insecure_full.json"
    SummaryDbFile = "insecure.json"

    def __init__(self, pip, parent=None):
        """
        Constructor

        @param pip reference to the global pip interface
        @type Pip
        @param parent reference to the parent widget (defaults to None)
        @type QWidget (optional)
        """
        super().__init__(parent)

        self.__pip = pip

        securityDir = os.path.join(Globals.getConfigDir(), "security")
        os.makedirs(securityDir, mode=0o700, exist_ok=True)
        self.__cacheFile = os.path.join(securityDir, "vulnerability_cache.json")
        if not os.path.exists(self.__cacheFile):
            self.__createCacheFile()

    def __createCacheFile(self):
        """
        Private method to create the cache file.

        The cache file has the following structure.
        {
          "insecure.json": {
              "cachedAt": 12345678,
              "db": {}
          },
          "insecure_full.json": {
              "cachedAt": 12345678,
              "db": {}
          }
        }

        The initial entries carry a "cachedAt" timestamp of 0, so they are
        considered outdated when they are looked up.
        """
        structure = {
            dbName: {
                "cachedAt": 0,
                "db": {},
            }
            for dbName in (
                PipVulnerabilityChecker.SummaryDbFile,
                PipVulnerabilityChecker.FullDbFile,
            )
        }
        with open(self.__cacheFile, "w", encoding="utf-8") as f:
            json.dump(structure, f, indent=2)

    def __getDataFromCache(self, dbName):
        """
        Private method to get the vulnerability database from the cache.

        @param dbName name of the vulnerability database
        @type str
        @return dictionary containing the requested vulnerability data or an
            empty dictionary, if the cache is missing, unreadable or outdated
        @rtype dict
        """
        # A missing, unreadable or corrupt cache file just means "no cached
        # data". OSError covers a missing/unreadable file, ValueError covers
        # both json.JSONDecodeError and UnicodeDecodeError.
        with contextlib.suppress(OSError, ValueError):
            with open(self.__cacheFile, "r", encoding="utf-8") as f:
                cachedData = json.load(f)

            entry = cachedData.get(dbName) if isinstance(cachedData, dict) else None
            if isinstance(entry, dict) and "cachedAt" in entry:
                cacheValidPeriod = Preferences.getPip("VulnerabilityDbCacheValidity")
                if entry["cachedAt"] + cacheValidPeriod > time.time():
                    return entry.get("db", {})

        return {}

    def __writeDataToCache(self, dbName, data):
        """
        Private method to write the vulnerability data for a database to the
        cache.

        @param dbName name of the vulnerability database
        @type str
        @param data dictionary containing the vulnerability data
        @type dict
        @exception OSError raised to indicate that the cache file could not
            be written
        """
        if not os.path.exists(self.__cacheFile):
            self.__createCacheFile()

        cache = {}
        # an unreadable or corrupt cache is simply replaced by a fresh one
        with contextlib.suppress(OSError, ValueError), open(
            self.__cacheFile, "r", encoding="utf-8"
        ) as f:
            cache = json.load(f)
        if not isinstance(cache, dict):
            cache = {}

        cache[dbName] = {
            "cachedAt": time.time(),
            "db": data,
        }
        with open(self.__cacheFile, "w", encoding="utf-8") as f:
            json.dump(cache, f, indent=2)

    def __fetchVulnerabilityDatabase(self, full=False, forceUpdate=False):
        """
        Private method to get the data of the vulnerability database.

        If the cached data is still valid, this data will be used.
        Otherwise a copy of the requested database will be downloaded
        and cached.

        @param full flag indicating to get the database containing the full
            data set (defaults to False)
        @type bool (optional)
        @param forceUpdate flag indicating an update of the cache is required
            (defaults to False)
        @type bool (optional)
        @return dictionary containing the vulnerability data (full data set or
            just package name and version specifier) or an empty dictionary,
            if the database could not be loaded
        @rtype dict
        """
        dbName = (
            PipVulnerabilityChecker.FullDbFile
            if full
            else PipVulnerabilityChecker.SummaryDbFile
        )

        if not forceUpdate:
            cachedData = self.__getDataFromCache(dbName)
            if cachedData:
                return cachedData

        url = Preferences.getPip("VulnerabilityDbMirror") + dbName
        request = QNetworkRequest(QUrl(url))
        reply = self.__pip.getNetworkAccessManager().get(request)
        # wait synchronously for the download while keeping the GUI responsive
        while not reply.isFinished():
            QCoreApplication.processEvents()
            QThread.msleep(100)

        reply.deleteLater()
        if reply.error() == QNetworkReply.NetworkError.NoError:
            text = str(reply.readAll(), Preferences.getSystem("IOEncoding"), "replace")
            try:
                data = json.loads(text)
            except json.JSONDecodeError:
                data = None
            # the database is looked up by package name, so it must be a dict
            if isinstance(data, dict):
                # caching is only an optimization; a failure to write the
                # cache must not discard the freshly downloaded data
                with contextlib.suppress(OSError):
                    self.__writeDataToCache(dbName, data)
                return data

        EricMessageBox.critical(
            None,
            self.tr("Fetching Vulnerability Database"),
            self.tr(
                """<p>The vulnerability database <b>{0}</b> could not"""
                """ be loaded from <b>{1}</b>.</p><p>The vulnerability"""
                """ check is not available.</p>"""
            ).format(dbName, Preferences.getPip("VulnerabilityDbMirror")),
        )
        return {}

    def __getVulnerabilities(self, package, specifier, db):
        """
        Private method to get the vulnerabilities for a package.

        @param package name of the package
        @type str
        @param specifier package version specifier
        @type str
        @param db vulnerability data
        @type dict
        @yield dictionary containing the vulnerability data for the package
        @ytype dict
        """
        # a package or entry missing from the full database yields nothing
        # instead of raising KeyError
        for entry in db.get(package, []):
            if specifier in entry.get("specs", []):
                yield entry

    def check(self, packages):
        """
        Public method to check the given packages for vulnerabilities.

        @param packages list of packages
        @type list of Package
        @return tuple containing an error status and a dictionary containing
            lists of detected vulnerabilities keyed by package name (an empty
            list is returned instead of the dictionary, if a vulnerability
            database could not be loaded)
        @rtype tuple of (VulnerabilityCheckError, dict of list of Vulnerability)
        """
        db = self.__fetchVulnerabilityDatabase()
        if not db:
            return VulnerabilityCheckError.SummaryDbUnavailable, []

        fullDb = None  # loaded lazily, only once a vulnerable candidate is found
        vulnerablePackages = frozenset(db)
        vulnerabilities = collections.defaultdict(list)

        for package in packages:
            # normalize the package name, the safety-db is converting
            # underscores to dashes and uses lowercase
            name = package.name.replace("_", "-").lower()
            if name not in vulnerablePackages:
                continue

            # we have a candidate here, check each affected version range
            for specifier in db[name]:
                if not SpecifierSet(specifiers=specifier).contains(package.version):
                    continue

                if fullDb is None:
                    fullDb = self.__fetchVulnerabilityDatabase(full=True)
                    if not fullDb:
                        # Without the full database no details can be
                        # reported; stop instead of retrying the download
                        # (and showing the error dialog) for every candidate.
                        return VulnerabilityCheckError.FullDbUnavailable, []

                for data in self.__getVulnerabilities(
                    package=name, specifier=specifier, db=fullDb
                ):
                    vulnerabilityId = (data.get("id") or "").replace("pyup.io-", "")
                    cveId = data.get("cve") or ""
                    if cveId:
                        # report only the first CVE of a comma separated list
                        cveId = cveId.split(",", 1)[0].strip()
                    vulnerabilities[package.name].append(
                        Vulnerability(
                            name=name,
                            spec=specifier,
                            version=package.version,
                            cve=cveId,
                            advisory=data.get("advisory", ""),
                            vulnerabilityId=vulnerabilityId,
                        )
                    )

        return VulnerabilityCheckError.OK, vulnerabilities

    def updateVulnerabilityDb(self):
        """
        Public method to update the cache of the vulnerability databases.
        """
        self.__fetchVulnerabilityDatabase(full=False, forceUpdate=True)
        self.__fetchVulnerabilityDatabase(full=True, forceUpdate=True)

eric ide

mercurial