--- a/WebBrowser/SafeBrowsing/SafeBrowsingAPIClient.py Mon Jul 17 19:58:37 2017 +0200 +++ b/WebBrowser/SafeBrowsing/SafeBrowsingAPIClient.py Tue Jul 18 19:33:46 2017 +0200 @@ -15,9 +15,10 @@ import json import random +import base64 from PyQt5.QtCore import pyqtSlot, pyqtSignal, QObject, QDateTime, QTimer, \ - QUrl + QUrl, QByteArray from PyQt5.QtNetwork import QNetworkRequest, QNetworkReply from WebBrowser.WebBrowserWindow import WebBrowserWindow @@ -26,6 +27,12 @@ class SafeBrowsingAPIClient(QObject): """ Class implementing the low level interface for Google Safe Browsing. + + @signal networkError(str) emitted to indicate a network error + @signal threatLists(list) emitted to publish the received threat list + @signal threatsUpdate(list) emitted to publish the received threats + update + @signal fullHashes(dict) emitted to publish the full hashes result """ ClientId = "eric6_API_client" ClientVersion = "1.0.0" @@ -34,9 +41,8 @@ networkError = pyqtSignal(str) threatLists = pyqtSignal(list) - - # threatListUpdates:fetch Content-Type: application/json POST - # fullHashes:find Content-Type: application/json POST + threatsUpdate = pyqtSignal(list) + fullHashes = pyqtSignal(dict) def __init__(self, apiKey, fairUse=True, parent=None): """ @@ -53,20 +59,28 @@ self.__fairUse = fairUse self.__nextRequestNoSoonerThan = QDateTime() - self.__replies = [] self.__failCount = 0 + + # get threat lists + self.__threatListsReply = None + + # threats lists updates + self.__threatsUpdatesRequest = None + self.__threatsUpdateReply = None + + # full hashes + self.__fullHashesRequest = None + self.__fullHashesReply = None def getThreatLists(self): """ Public method to retrieve all available threat lists. - - @return threat lists - @rtype list of dictionaries """ url = QUrl(self.GsbUrlTemplate.format("threatLists", self.__apiKey)) req = QNetworkRequest(url) reply = WebBrowserWindow.networkManager().get(req) reply.finished.connect(self.__threatListsReceived) + self.__threatListsReply = reply @pyqtSlot() def __threatListsReceived(self): @@ -74,19 +88,169 @@ Private slot handling the threat lists. """ reply = self.sender() - result, hasError = self.__extractData(reply) - if hasError: - # reschedule - self.networkError.emit(reply.errorString()) - self.__reschedule(reply.error(), self.getThreatLists) + if reply is self.__threatListsReply: + self.__threatListsReply = None + result, hasError = self.__extractData(reply) + if hasError: + # reschedule + self.networkError.emit(reply.errorString()) + self.__reschedule(reply.error(), self.getThreatLists) + else: + self.threatLists.emit(result["threatLists"]) + + reply.deleteLater() + + def getThreatsUpdate(self, clientState=None): + """ + Public method to fetch hash prefix updates for the given threat list. + + @param clientState dictionary of client states with keys like + (threatType, platformType, threatEntryType) + @type dict + """ + if self.__threatsUpdateReply is not None: + # update is in progress + return + + if clientState is None: + if self.__threatsUpdatesRequest: + requestBody = self.__threatsUpdatesRequest + else: + return else: - self.__setWaitDuration(result.get("minimumWaitDuration")) - self.threatLists.emit(result["threatLists"]) - self.__failCount = 0 + requestBody = { + "client": { + "clientId": self.ClientId, + "clientVersion": self.ClientVersion, + }, + "listUpdateRequests": [], + } + + for (threatType, platformType, threatEntryType), currentState in \ + clientState.items(): + requestBody["listUpdateRequests"].append( + { + "threatType": threatType, + "platformType": platformType, + "threatEntryType": threatEntryType, + "state": currentState, + "constraints": { + "supportedCompressions": ["RAW"], + } + } + ) + + self.__threatsUpdatesRequest = requestBody + + data = QByteArray(json.dumps(requestBody).encode("utf-8")) + url = QUrl(self.GsbUrlTemplate.format("threatListUpdates:fetch", + self.__apiKey)) + req = QNetworkRequest(url) + req.setHeader(QNetworkRequest.ContentTypeHeader, "application/json") + reply = WebBrowserWindow.networkManager().post(req, data) + reply.finished.connect(self.__threatsUpdateReceived) + self.__threatsUpdateReply = reply + + @pyqtSlot() + def __threatsUpdateReceived(self): + """ + Private slot handling the threats update. + """ + reply = self.sender() + if reply is self.__threatsUpdateReply: + self.__threatsUpdateReply = None + result, hasError = self.__extractData(reply) + if hasError: + # reschedule + self.networkError.emit(reply.errorString()) + self.__reschedule(reply.error(), self.getThreatsUpdate) + else: + self.__threatsUpdatesRequest = None + self.threatsUpdate.emit(result["listUpdateResponses"]) + + reply.deleteLater() + + def getFullHashes(self, prefixes=None, clientState=None): + """ + Public method to find full hashes matching hash prefixes. + + @param prefixes list of hash prefixes to find + @type list of str (Python 2) or list of bytes (Python 3) + @param clientState dictionary of client states with keys like + (threatType, platformType, threatEntryType) + @type dict + """ + if self.__fullHashesReply is not None: + # full hash request in progress + return - if reply in self.__replies: - self.__replies.remove(reply) - reply.deleteLater() + if prefixes is None or clientState is None: + if self.__fullHashesRequest: + requestBody = self.__fullHashesRequest + else: + return + else: + requestBody = { + "client": { + "clientId": self.ClientId, + "clientVersion": self.ClientVersion, + }, + "clientStates": [], + "threatInfo": { + "threatTypes": [], + "platformTypes": [], + "threatEntryTypes": [], + "threatEntries": [], + }, + } + + for prefix in prefixes: + requestBody["threatInfo"]["threatEntries"].append( + {"hash": base64.b64encode(prefix).decode("ascii")}) + + for (threatType, platformType, threatEntryType), currentState in \ + clientState.items(): + requestBody["clientStates"].append(clientState) + if threatType not in requestBody["threatInfo"]["threatTypes"]: + requestBody["threatInfo"]["threatTypes"].append(threatType) + if platformType not in \ + requestBody["threatInfo"]["platformTypes"]: + requestBody["threatInfo"]["platformTypes"].append( + platformType) + if threatEntryType not in \ + requestBody["threatInfo"]["threatEntryTypes"]: + requestBody["threatInfo"]["threatEntryTypes"].append( + threatEntryType) + + self.__fullHashesRequest = requestBody + + data = QByteArray(json.dumps(requestBody).encode("utf-8")) + url = QUrl(self.GsbUrlTemplate.format("fullHashes:find", + self.__apiKey)) + req = QNetworkRequest(url) + req.setHeader(QNetworkRequest.ContentTypeHeader, "application/json") + reply = WebBrowserWindow.networkManager().post(req, data) + reply.finished.connect(self.__fullHashesReceived) + self.__fullHashesReply = reply + + @pyqtSlot() + def __fullHashesReceived(self): + """ + Private slot handling the full hashes reply. + """ + reply = self.sender() + if reply is self.__fullHashesReply: + self.__fullHashesReply = None + result, hasError = self.__extractData(reply) + if hasError: + # reschedule + self.networkError.emit(reply.errorString()) + self.__reschedule(reply.error(), self.getFullHashes) + else: + self.__fullHashesRequest = None + self.fullHashes.emit(result) + + reply.deleteLater() def __extractData(self, reply): """ @@ -100,7 +264,9 @@ if reply.error() != QNetworkReply.NoError: return None, True + self.__failCount = 0 result = json.loads(str(reply.readAll(), "utf-8")) + self.__setWaitDuration(result.get("minimumWaitDuration")) return result, False def __setWaitDuration(self, minimumWaitDuration):