WebBrowser/SafeBrowsing/SafeBrowsingAPIClient.py

branch
safe_browsing
changeset 5811
5358a3c7995f
parent 5809
5b53c17b7d93
child 5816
93c74269d59e
diff -r 5b53c17b7d93 -r 5358a3c7995f WebBrowser/SafeBrowsing/SafeBrowsingAPIClient.py
--- a/WebBrowser/SafeBrowsing/SafeBrowsingAPIClient.py	Mon Jul 17 19:58:37 2017 +0200
+++ b/WebBrowser/SafeBrowsing/SafeBrowsingAPIClient.py	Tue Jul 18 19:33:46 2017 +0200
@@ -15,9 +15,10 @@
 
 import json
 import random
+import base64
 
 from PyQt5.QtCore import pyqtSlot, pyqtSignal, QObject, QDateTime, QTimer, \
-    QUrl
+    QUrl, QByteArray
 from PyQt5.QtNetwork import QNetworkRequest, QNetworkReply
 
 from WebBrowser.WebBrowserWindow import WebBrowserWindow
@@ -26,6 +27,12 @@
 class SafeBrowsingAPIClient(QObject):
     """
     Class implementing the low level interface for Google Safe Browsing.
+    
+    @signal networkError(str) emitted to indicate a network error
+    @signal threatLists(list) emitted to publish the received threat list
+    @signal threatsUpdate(list) emitted to publish the received threats
+        update
+    @signal fullHashes(dict) emitted to publish the full hashes result
     """
     ClientId = "eric6_API_client"
     ClientVersion = "1.0.0"
@@ -34,9 +41,8 @@
     
     networkError = pyqtSignal(str)
     threatLists = pyqtSignal(list)
-    
-    # threatListUpdates:fetch   Content-Type: application/json      POST
-    # fullHashes:find           Content-Type: application/json      POST
+    threatsUpdate = pyqtSignal(list)
+    fullHashes = pyqtSignal(dict)
     
     def __init__(self, apiKey, fairUse=True, parent=None):
         """
@@ -53,20 +59,28 @@
         self.__fairUse = fairUse
         
         self.__nextRequestNoSoonerThan = QDateTime()
-        self.__replies = []
         self.__failCount = 0
+        
+        # get threat lists
+        self.__threatListsReply = None
+        
+        # threats lists updates
+        self.__threatsUpdatesRequest = None
+        self.__threatsUpdateReply = None
+        
+        # full hashes
+        self.__fullHashesRequest = None
+        self.__fullHashesReply = None
     
     def getThreatLists(self):
         """
         Public method to retrieve all available threat lists.
-        
-        @return threat lists
-        @rtype list of dictionaries
         """
         url = QUrl(self.GsbUrlTemplate.format("threatLists", self.__apiKey))
         req = QNetworkRequest(url)
         reply = WebBrowserWindow.networkManager().get(req)
         reply.finished.connect(self.__threatListsReceived)
+        self.__threatListsReply = reply
     
     @pyqtSlot()
     def __threatListsReceived(self):
@@ -74,19 +88,169 @@
         Private slot handling the threat lists.
         """
         reply = self.sender()
-        result, hasError = self.__extractData(reply)
-        if hasError:
-            # reschedule
-            self.networkError.emit(reply.errorString())
-            self.__reschedule(reply.error(), self.getThreatLists)
+        if reply is self.__threatListsReply:
+            self.__threatListsReply = None
+            result, hasError = self.__extractData(reply)
+            if hasError:
+                # reschedule
+                self.networkError.emit(reply.errorString())
+                self.__reschedule(reply.error(), self.getThreatLists)
+            else:
+                self.threatLists.emit(result["threatLists"])
+            
+            reply.deleteLater()
+    
+    def getThreatsUpdate(self, clientState=None):
+        """
+        Public method to fetch hash prefix updates for the given threat list.
+        
+        @param clientState dictionary of client states with keys like
+            (threatType, platformType, threatEntryType)
+        @type dict
+        """
+        if self.__threatsUpdateReply is not None:
+            # update is in progress
+            return
+        
+        if clientState is None:
+            if self.__threatsUpdatesRequest:
+                requestBody = self.__threatsUpdatesRequest
+            else:
+                return
         else:
-            self.__setWaitDuration(result.get("minimumWaitDuration"))
-            self.threatLists.emit(result["threatLists"])
-            self.__failCount = 0
+            requestBody = {
+                "client": {
+                    "clientId": self.ClientId,
+                    "clientVersion": self.ClientVersion,
+                },
+                "listUpdateRequests": [],
+            }
+            
+            for (threatType, platformType, threatEntryType), currentState in \
+                    clientState.items():
+                requestBody["listUpdateRequests"].append(
+                    {
+                        "threatType": threatType,
+                        "platformType": platformType,
+                        "threatEntryType": threatEntryType,
+                        "state": currentState,
+                        "constraints": {
+                            "supportedCompressions": ["RAW"],
+                        }
+                    }
+                )
+            
+            self.__threatsUpdatesRequest = requestBody
+        
+        data = QByteArray(json.dumps(requestBody).encode("utf-8"))
+        url = QUrl(self.GsbUrlTemplate.format("threatListUpdates:fetch",
+                                              self.__apiKey))
+        req = QNetworkRequest(url)
+        req.setHeader(QNetworkRequest.ContentTypeHeader, "application/json")
+        reply = WebBrowserWindow.networkManager().post(req, data)
+        reply.finished.connect(self.__threatsUpdateReceived)
+        self.__threatsUpdateReply = reply
+    
+    @pyqtSlot()
+    def __threatsUpdateReceived(self):
+        """
+        Private slot handling the threats update.
+        """
+        reply = self.sender()
+        if reply is self.__threatsUpdateReply:
+            self.__threatsUpdateReply = None
+            result, hasError = self.__extractData(reply)
+            if hasError:
+                # reschedule
+                self.networkError.emit(reply.errorString())
+                self.__reschedule(reply.error(), self.getThreatsUpdate)
+            else:
+                self.__threatsUpdatesRequest = None
+                self.threatsUpdate.emit(result["listUpdateResponses"])
+            
+            reply.deleteLater()
+    
+    def getFullHashes(self, prefixes=None, clientState=None):
+        """
+        Public method to find full hashes matching hash prefixes.
+        
+        @param prefixes list of hash prefixes to find
+        @type list of str (Python 2) or list of bytes (Python 3)
+        @param clientState dictionary of client states with keys like
+            (threatType, platformType, threatEntryType)
+        @type dict
+        """
+        if self.__fullHashesReply is not None:
+            # full hash request in progress
+            return
         
-        if reply in self.__replies:
-            self.__replies.remove(reply)
-        reply.deleteLater()
+        if prefixes is None or clientState is None:
+            if self.__fullHashesRequest:
+                requestBody = self.__fullHashesRequest
+            else:
+                return
+        else:
+            requestBody = {
+                "client": {
+                    "clientId": self.ClientId,
+                    "clientVersion": self.ClientVersion,
+                },
+                "clientStates": [],
+                "threatInfo": {
+                    "threatTypes": [],
+                    "platformTypes": [],
+                    "threatEntryTypes": [],
+                    "threatEntries": [],
+                },
+            }
+            
+            for prefix in prefixes:
+                requestBody["threatInfo"]["threatEntries"].append(
+                    {"hash": base64.b64encode(prefix).decode("ascii")})
+            
+            for (threatType, platformType, threatEntryType), currentState in \
+                    clientState.items():
+                requestBody["clientStates"].append(clientState)
+                if threatType not in requestBody["threatInfo"]["threatTypes"]:
+                    requestBody["threatInfo"]["threatTypes"].append(threatType)
+                if platformType not in \
+                        requestBody["threatInfo"]["platformTypes"]:
+                    requestBody["threatInfo"]["platformTypes"].append(
+                        platformType)
+                if threatEntryType not in \
+                        requestBody["threatInfo"]["threatEntryTypes"]:
+                    requestBody["threatInfo"]["threatEntryTypes"].append(
+                        threatEntryType)
+            
+            self.__fullHashesRequest = requestBody
+        
+        data = QByteArray(json.dumps(requestBody).encode("utf-8"))
+        url = QUrl(self.GsbUrlTemplate.format("fullHashes:find",
+                                              self.__apiKey))
+        req = QNetworkRequest(url)
+        req.setHeader(QNetworkRequest.ContentTypeHeader, "application/json")
+        reply = WebBrowserWindow.networkManager().post(req, data)
+        reply.finished.connect(self.__fullHashesReceived)
+        self.__fullHashesReply = reply
+    
+    @pyqtSlot()
+    def __fullHashesReceived(self):
+        """
+        Private slot handling the full hashes reply.
+        """
+        reply = self.sender()
+        if reply is self.__fullHashesReply:
+            self.__fullHashesReply = None
+            result, hasError = self.__extractData(reply)
+            if hasError:
+                # reschedule
+                self.networkError.emit(reply.errorString())
+                self.__reschedule(reply.error(), self.getFullHashes)
+            else:
+                self.__fullHashesRequest = None
+                self.fullHashes.emit(result)
+            
+            reply.deleteLater()
     
     def __extractData(self, reply):
         """
@@ -100,7 +264,9 @@
         if reply.error() != QNetworkReply.NoError:
             return None, True
         
+        self.__failCount = 0
         result = json.loads(str(reply.readAll(), "utf-8"))
+        self.__setWaitDuration(result.get("minimumWaitDuration"))
         return result, False
     
     def __setWaitDuration(self, minimumWaitDuration):

eric ide

mercurial