src/eric7/WebBrowser/SafeBrowsing/SafeBrowsingAPIClient.py

branch
eric7
changeset 9221
bf71ee032bb4
parent 9209
b99e7fd55fd3
child 9413
80c06d472826
--- a/src/eric7/WebBrowser/SafeBrowsing/SafeBrowsingAPIClient.py	Wed Jul 13 11:16:20 2022 +0200
+++ b/src/eric7/WebBrowser/SafeBrowsing/SafeBrowsingAPIClient.py	Wed Jul 13 14:55:47 2022 +0200
@@ -11,8 +11,13 @@
 import base64
 
 from PyQt6.QtCore import (
-    pyqtSignal, QObject, QDateTime, QUrl, QByteArray, QCoreApplication,
-    QEventLoop
+    pyqtSignal,
+    QObject,
+    QDateTime,
+    QUrl,
+    QByteArray,
+    QCoreApplication,
+    QEventLoop,
 )
 from PyQt6.QtNetwork import QNetworkRequest, QNetworkReply
 
@@ -24,20 +29,21 @@
 class SafeBrowsingAPIClient(QObject):
     """
     Class implementing the low level interface for Google Safe Browsing.
-    
+
     @signal networkError(str) emitted to indicate a network error
     """
+
     ClientId = "eric7_API_client"
     ClientVersion = "2.0.0"
-    
+
     GsbUrlTemplate = "https://safebrowsing.googleapis.com/v4/{0}?key={1}"
-    
+
     networkError = pyqtSignal(str)
-    
+
     def __init__(self, apiKey, fairUse=True, parent=None):
         """
         Constructor
-        
+
         @param apiKey API key to be used
         @type str
         @param fairUse flag indicating to follow the fair use policy
@@ -46,33 +52,33 @@
         @type QObject
         """
         super().__init__(parent)
-        
+
         self.__apiKey = apiKey
         self.__fairUse = fairUse
-        
+
         self.__nextRequestNoSoonerThan = QDateTime()
         self.__failCount = 0
-        
+
         self.__lookupApiCache = {}
         # Temporary cache used by the lookup API (v4)
         # key: URL as string
         # value: dictionary with these entries:
         #   "validUntil": (QDateTime)
         #   "threatInfo": (list of ThreatList)
-    
+
     def setApiKey(self, apiKey):
         """
         Public method to set the API key.
-        
+
         @param apiKey API key to be set
         @type str
         """
         self.__apiKey = apiKey
-    
+
     def getThreatLists(self):
         """
         Public method to retrieve all available threat lists.
-        
+
         @return tuple containing list of threat lists and an error message
         @rtype tuple of (list of dict containing 'threatType', 'platformType'
             and 'threatEntryType', bool)
@@ -80,12 +86,11 @@
         url = QUrl(self.GsbUrlTemplate.format("threatLists", self.__apiKey))
         req = QNetworkRequest(url)
         reply = WebBrowserWindow.networkManager().get(req)
-        
+
         while reply.isRunning():
-            QCoreApplication.processEvents(
-                QEventLoop.ProcessEventsFlag.AllEvents, 200)
+            QCoreApplication.processEvents(QEventLoop.ProcessEventsFlag.AllEvents, 200)
             # max. 200 ms processing
-        
+
         res = None
         error = ""
         if reply.error() != QNetworkReply.NetworkError.NoError:
@@ -94,18 +99,18 @@
         else:
             result = self.__extractData(reply)
             res = result["threatLists"]
-        
+
         reply.deleteLater()
         return res, error
-    
+
     #######################################################################
     ## Methods below implement the 'Update API (v4)'
     #######################################################################
-    
+
     def getThreatsUpdate(self, clientStates):
         """
         Public method to fetch hash prefix updates for the given threat list.
-        
+
         @param clientStates dictionary of client states with keys like
             (threatType, platformType, threatEntryType)
         @type dict
@@ -120,10 +125,12 @@
             },
             "listUpdateRequests": [],
         }
-        
-        for (threatType, platformType, threatEntryType), currentState in (
-            clientStates.items()
-        ):
+
+        for (
+            threatType,
+            platformType,
+            threatEntryType,
+        ), currentState in clientStates.items():
             requestBody["listUpdateRequests"].append(
                 {
                     "threatType": threatType,
@@ -132,23 +139,22 @@
                     "state": currentState,
                     "constraints": {
                         "supportedCompressions": ["RAW"],
-                    }
+                    },
                 }
             )
-        
+
         data = QByteArray(json.dumps(requestBody).encode("utf-8"))
-        url = QUrl(self.GsbUrlTemplate.format("threatListUpdates:fetch",
-                                              self.__apiKey))
+        url = QUrl(self.GsbUrlTemplate.format("threatListUpdates:fetch", self.__apiKey))
         req = QNetworkRequest(url)
-        req.setHeader(QNetworkRequest.KnownHeaders.ContentTypeHeader,
-                      "application/json")
+        req.setHeader(
+            QNetworkRequest.KnownHeaders.ContentTypeHeader, "application/json"
+        )
         reply = WebBrowserWindow.networkManager().post(req, data)
-        
+
         while reply.isRunning():
-            QCoreApplication.processEvents(
-                QEventLoop.ProcessEventsFlag.AllEvents, 200)
+            QCoreApplication.processEvents(QEventLoop.ProcessEventsFlag.AllEvents, 200)
             # max. 200 ms processing
-        
+
         res = None
         error = ""
         if reply.error() != QNetworkReply.NetworkError.NoError:
@@ -157,14 +163,14 @@
         else:
             result = self.__extractData(reply)
             res = result["listUpdateResponses"]
-        
+
         reply.deleteLater()
         return res, error
-    
+
     def getFullHashes(self, prefixes, clientState):
         """
         Public method to find full hashes matching hash prefixes.
-        
+
         @param prefixes list of hash prefixes to find
         @type list of bytes
         @param clientState dictionary of client states with keys like
@@ -187,56 +193,50 @@
                 "threatEntries": [],
             },
         }
-        
+
         for prefix in prefixes:
             requestBody["threatInfo"]["threatEntries"].append(
-                {"hash": base64.b64encode(prefix).decode("ascii")})
-        
-        for (threatType, platformType, threatEntryType), currentState in (
-            clientState.items()
-        ):
+                {"hash": base64.b64encode(prefix).decode("ascii")}
+            )
+
+        for (
+            threatType,
+            platformType,
+            threatEntryType,
+        ), currentState in clientState.items():
             requestBody["clientStates"].append(currentState)
             if threatType not in requestBody["threatInfo"]["threatTypes"]:
                 requestBody["threatInfo"]["threatTypes"].append(threatType)
-            if (
-                platformType not in
-                    requestBody["threatInfo"]["platformTypes"]
-            ):
-                requestBody["threatInfo"]["platformTypes"].append(
-                    platformType)
-            if (
-                threatEntryType not in
-                    requestBody["threatInfo"]["threatEntryTypes"]
-            ):
-                requestBody["threatInfo"]["threatEntryTypes"].append(
-                    threatEntryType)
-        
+            if platformType not in requestBody["threatInfo"]["platformTypes"]:
+                requestBody["threatInfo"]["platformTypes"].append(platformType)
+            if threatEntryType not in requestBody["threatInfo"]["threatEntryTypes"]:
+                requestBody["threatInfo"]["threatEntryTypes"].append(threatEntryType)
+
         data = QByteArray(json.dumps(requestBody).encode("utf-8"))
-        url = QUrl(self.GsbUrlTemplate.format("fullHashes:find",
-                                              self.__apiKey))
+        url = QUrl(self.GsbUrlTemplate.format("fullHashes:find", self.__apiKey))
         req = QNetworkRequest(url)
-        req.setHeader(QNetworkRequest.KnownHeaders.ContentTypeHeader,
-                      "application/json")
+        req.setHeader(
+            QNetworkRequest.KnownHeaders.ContentTypeHeader, "application/json"
+        )
         reply = WebBrowserWindow.networkManager().post(req, data)
-        
+
         while reply.isRunning():
-            QCoreApplication.processEvents(
-                QEventLoop.ProcessEventsFlag.AllEvents, 200)
+            QCoreApplication.processEvents(QEventLoop.ProcessEventsFlag.AllEvents, 200)
             # max. 200 ms processing
-        
+
         res = []
         if reply.error() != QNetworkReply.NetworkError.NoError:
             self.networkError.emit(reply.errorString())
         else:
             res = self.__extractData(reply)
-        
+
         reply.deleteLater()
         return res
-    
+
     def __extractData(self, reply):
         """
         Private method to extract the data of a network reply.
-        
+
         @param reply reference to the network reply object
         @type QNetworkReply
         @return extracted data
@@ -245,11 +245,11 @@
         result = json.loads(str(reply.readAll(), "utf-8"))
         self.__setWaitDuration(result.get("minimumWaitDuration"))
         return result
-    
+
     def __setWaitDuration(self, minimumWaitDuration):
         """
         Private method to set the minimum wait duration.
-        
+
         @param minimumWaitDuration duration to be set
         @type str
         """
@@ -257,39 +257,39 @@
             self.__nextRequestNoSoonerThan = QDateTime()
         else:
             waitDuration = int(float(minimumWaitDuration.rstrip("s")))
-            self.__nextRequestNoSoonerThan = (
-                QDateTime.currentDateTime().addSecs(waitDuration)
+            self.__nextRequestNoSoonerThan = QDateTime.currentDateTime().addSecs(
+                waitDuration
             )
-    
+
     def fairUseDelayExpired(self):
         """
         Public method to check, if the fair use wait period has expired.
-        
+
         @return flag indicating expiration
         @rtype bool
         """
         return (
-            self.__fairUse and
-            QDateTime.currentDateTime() >= self.__nextRequestNoSoonerThan
+            self.__fairUse
+            and QDateTime.currentDateTime() >= self.__nextRequestNoSoonerThan
         ) or not self.__fairUse
-    
+
     def getFairUseDelayExpirationDateTime(self):
         """
         Public method to get the date and time the fair use delay will expire.
-        
+
         @return fair use delay expiration date and time
         @rtype QDateTime
         """
         return self.__nextRequestNoSoonerThan
-    
+
     #######################################################################
     ## Methods below implement the 'Lookup API (v4)'
     #######################################################################
-    
+
     def lookupUrl(self, url, platforms):
         """
         Public method to send an URL to Google for checking.
-        
+
         @param url URL to be checked
         @type QUrl
         @param platforms list of platform types to check against
@@ -299,26 +299,26 @@
         @rtype tuple of (list of ThreatList, str)
         """
         error = ""
-        
+
         # sanitize the URL by removing user info and query data
         url = url.adjusted(
-            QUrl.UrlFormattingOption.RemoveUserInfo |
-            QUrl.UrlFormattingOption.RemoveQuery |
-            QUrl.UrlFormattingOption.RemoveFragment
+            QUrl.UrlFormattingOption.RemoveUserInfo
+            | QUrl.UrlFormattingOption.RemoveQuery
+            | QUrl.UrlFormattingOption.RemoveFragment
         )
         urlStr = url.toString()
-        
+
         # check the local cache first
         if urlStr in self.__lookupApiCache:
             if (
-                self.__lookupApiCache[urlStr]["validUntil"] >
-                    QDateTime.currentDateTime()
+                self.__lookupApiCache[urlStr]["validUntil"]
+                > QDateTime.currentDateTime()
             ):
                 # cached entry is still valid
                 return self.__lookupApiCache[urlStr]["threatInfo"], error
             else:
                 del self.__lookupApiCache[urlStr]
-        
+
         # no valid entry found, ask the safe browsing server
         requestBody = {
             "client": {
@@ -328,27 +328,25 @@
             "threatInfo": {
                 "threatTypes": SafeBrowsingAPIClient.definedThreatTypes(),
                 "platformTypes": platforms,
-                "threatEntryTypes":
-                    SafeBrowsingAPIClient.definedThreatEntryTypes(),
+                "threatEntryTypes": SafeBrowsingAPIClient.definedThreatEntryTypes(),
                 "threatEntries": [
                     {"url": urlStr},
                 ],
             },
         }
-        
+
         data = QByteArray(json.dumps(requestBody).encode("utf-8"))
-        url = QUrl(self.GsbUrlTemplate.format("threatMatches:find",
-                                              self.__apiKey))
+        url = QUrl(self.GsbUrlTemplate.format("threatMatches:find", self.__apiKey))
         req = QNetworkRequest(url)
-        req.setHeader(QNetworkRequest.KnownHeaders.ContentTypeHeader,
-                      "application/json")
+        req.setHeader(
+            QNetworkRequest.KnownHeaders.ContentTypeHeader, "application/json"
+        )
         reply = WebBrowserWindow.networkManager().post(req, data)
-        
+
         while reply.isRunning():
-            QCoreApplication.processEvents(
-                QEventLoop.ProcessEventsFlag.AllEvents, 200)
+            QCoreApplication.processEvents(QEventLoop.ProcessEventsFlag.AllEvents, 200)
             # max. 200 ms processing
-        
+
         threats = []
         if reply.error() != QNetworkReply.NetworkError.NoError:
             error = reply.errorString()
@@ -366,30 +364,29 @@
                     threats.append(threatInfo)
                     if "cacheDuration" in match:
                         cacheDurationSec = int(
-                            match["cacheDuration"].strip().rstrip("s")
-                            .split(".")[0])
+                            match["cacheDuration"].strip().rstrip("s").split(".")[0]
+                        )
                         if cacheDurationSec > cacheDuration:
                             cacheDuration = cacheDurationSec
                 if cacheDuration > 0 and bool(threats):
-                    validUntil = QDateTime.currentDateTime().addSecs(
-                        cacheDuration)
+                    validUntil = QDateTime.currentDateTime().addSecs(cacheDuration)
                     self.__lookupApiCache[urlStr] = {
                         "validUntil": validUntil,
-                        "threatInfo": threats
+                        "threatInfo": threats,
                     }
-        
+
         reply.deleteLater()
         return threats, error
-    
+
     #######################################################################
     ## Methods below implement global (class wide) functionality
     #######################################################################
-    
+
     @classmethod
     def getThreatMessage(cls, threatType):
         """
         Class method to get a warning message for the given threat type.
-        
+
         @param threatType threat type to get the message for
         @type str
         @return threat message
@@ -402,7 +399,8 @@
                 "<h3>Malware Warning</h3>"
                 "<p>The web site you are about to visit may try to install"
                 " harmful programs on your computer in order to steal or"
-                " destroy your data.</p>")
+                " destroy your data.</p>",
+            )
         elif threatType == "social_engineering":
             msg = QCoreApplication.translate(
                 "SafeBrowsingAPI",
@@ -410,41 +408,46 @@
                 "<p>The web site you are about to visit may try to trick you"
                 " into doing something dangerous online, such as revealing"
                 " passwords or personal information, usually through a fake"
-                " website.</p>")
+                " website.</p>",
+            )
         elif threatType == "unwanted_software":
             msg = QCoreApplication.translate(
                 "SafeBrowsingAPI",
                 "<h3>Unwanted Software Warning</h3>"
                 "<p>The software you are about to download may negatively"
-                " affect your browsing or computing experience.</p>")
+                " affect your browsing or computing experience.</p>",
+            )
         elif threatType == "potentially_harmful_application":
             msg = QCoreApplication.translate(
                 "SafeBrowsingAPI",
                 "<h3>Potentially Harmful Application</h3>"
                 "<p>The web site you are about to visit may try to trick you"
                 " into installing applications, that may negatively affect"
-                " your browsing experience.</p>")
+                " your browsing experience.</p>",
+            )
         elif threatType == "malicious_binary":
             msg = QCoreApplication.translate(
                 "SafeBrowsingAPI",
                 "<h3>Malicious Binary Warning</h3>"
                 "<p>The software you are about to download may be harmful"
-                " to your computer.</p>")
+                " to your computer.</p>",
+            )
         else:
             # unknow threat
             msg = QCoreApplication.translate(
                 "SafeBrowsingAPI",
                 "<h3>Unknown Threat Warning</h3>"
                 "<p>The web site you are about to visit was found in the Safe"
-                " Browsing Database but was not classified yet.</p>")
-        
+                " Browsing Database but was not classified yet.</p>",
+            )
+
         return msg
-    
+
     @classmethod
     def getThreatType(cls, threatType):
         """
         Class method to get a display string for a given threat type.
-        
+
         @param threatType threat type to get display string for
         @type str
         @return display string
@@ -452,44 +455,49 @@
         """
         threatType = threatType.lower()
         if threatType == "malware":
-            displayString = QCoreApplication.translate(
-                "SafeBrowsingAPI", "Malware")
+            displayString = QCoreApplication.translate("SafeBrowsingAPI", "Malware")
         elif threatType == "social_engineering":
-            displayString = QCoreApplication.translate(
-                "SafeBrowsingAPI", "Phishing")
+            displayString = QCoreApplication.translate("SafeBrowsingAPI", "Phishing")
         elif threatType == "unwanted_software":
             displayString = QCoreApplication.translate(
-                "SafeBrowsingAPI", "Unwanted Software")
+                "SafeBrowsingAPI", "Unwanted Software"
+            )
         elif threatType == "potentially_harmful_application":
             displayString = QCoreApplication.translate(
-                "SafeBrowsingAPI", "Harmful Application")
+                "SafeBrowsingAPI", "Harmful Application"
+            )
         elif threatType == "malcious_binary":
             displayString = QCoreApplication.translate(
-                "SafeBrowsingAPI", "Malicious Binary")
+                "SafeBrowsingAPI", "Malicious Binary"
+            )
         else:
             displayString = QCoreApplication.translate(
-                "SafeBrowsingAPI", "Unknown Threat")
-        
+                "SafeBrowsingAPI", "Unknown Threat"
+            )
+
         return displayString
-    
+
     @classmethod
     def definedThreatTypes(cls):
         """
         Class method to get all threat types defined in API v4.
-        
+
         @return list of defined threat types
         @rtype list of str
         """
         return [
-            "THREAT_TYPE_UNSPECIFIED", "MALWARE", "SOCIAL_ENGINEERING",
-            "UNWANTED_SOFTWARE", "POTENTIALLY_HARMFUL_APPLICATION",
+            "THREAT_TYPE_UNSPECIFIED",
+            "MALWARE",
+            "SOCIAL_ENGINEERING",
+            "UNWANTED_SOFTWARE",
+            "POTENTIALLY_HARMFUL_APPLICATION",
         ]
-    
+
     @classmethod
     def getThreatEntryString(cls, threatEntry):
         """
         Class method to get the threat entry string.
-        
+
         @param threatEntry threat entry type as defined in the v4 API
         @type str
         @return threat entry string
@@ -498,29 +506,29 @@
         if threatEntry == "URL":
             return "URL"
         elif threatEntry == "EXECUTABLE":
-            return QCoreApplication.translate(
-                "SafeBrowsingAPI", "executable program")
+            return QCoreApplication.translate("SafeBrowsingAPI", "executable program")
         else:
-            return QCoreApplication.translate(
-                "SafeBrowsingAPI", "unknown type")
-    
+            return QCoreApplication.translate("SafeBrowsingAPI", "unknown type")
+
     @classmethod
     def definedThreatEntryTypes(cls):
         """
         Class method to get all threat entry types defined in API v4.
-        
+
         @return list of all defined threat entry types
         @rtype list of str
         """
         return [
-            "THREAT_ENTRY_TYPE_UNSPECIFIED", "URL", "EXECUTABLE",
+            "THREAT_ENTRY_TYPE_UNSPECIFIED",
+            "URL",
+            "EXECUTABLE",
         ]
-    
+
     @classmethod
     def getPlatformString(cls, platformType):
         """
         Class method to get the platform string for a given platform type.
-        
+
         @param platformType platform type as defined in the v4 API
         @type str
         @return platform string
@@ -536,22 +544,21 @@
         }
         if platformType in platformStrings:
             return platformStrings[platformType]
-        
+
         if platformType == "ANY_PLATFORM":
-            return QCoreApplication.translate(
-                "SafeBrowsingAPI", "any defined platform")
+            return QCoreApplication.translate("SafeBrowsingAPI", "any defined platform")
         elif platformType == "ALL_PLATFORMS":
             return QCoreApplication.translate(
-                "SafeBrowsingAPI", "all defined platforms")
+                "SafeBrowsingAPI", "all defined platforms"
+            )
         else:
-            return QCoreApplication.translate(
-                "SafeBrowsingAPI", "unknown platform")
-    
+            return QCoreApplication.translate("SafeBrowsingAPI", "unknown platform")
+
     @classmethod
     def getPlatformTypes(cls, platform):
         """
         Class method to get the platform types for a given platform.
-        
+
         @param platform platform string
         @type str (one of 'linux', 'windows', 'macos')
         @return list of platform types as defined in the v4 API for the
@@ -560,10 +567,10 @@
         @exception ValueError raised to indicate an invalid platform string
         """
         platform = platform.lower()
-        
+
         if platform not in ("linux", "windows", "macos"):
             raise ValueError("Unsupported platform")
-        
+
         platformTypes = ["ANY_PLATFORM", "ALL_PLATFORMS"]
         if platform == "linux":
             platformTypes.append("LINUX")
@@ -571,18 +578,25 @@
             platformTypes.append("WINDOWS")
         else:
             platformTypes.append("OSX")
-        
+
         return platformTypes
-    
+
     @classmethod
     def definedPlatformTypes(cls):
         """
         Class method to get all platform types defined in API v4.
-        
+
         @return list of all defined platform types
         @rtype list of str
         """
         return [
-            "PLATFORM_TYPE_UNSPECIFIED", "WINDOWS", "LINUX", "ANDROID", "OSX",
-            "IOS", "ANY_PLATFORM", "ALL_PLATFORMS", "CHROME",
+            "PLATFORM_TYPE_UNSPECIFIED",
+            "WINDOWS",
+            "LINUX",
+            "ANDROID",
+            "OSX",
+            "IOS",
+            "ANY_PLATFORM",
+            "ALL_PLATFORMS",
+            "CHROME",
         ]

eric ide

mercurial