WebBrowser/SafeBrowsing/SafeBrowsingCache.py

branch
safe_browsing
changeset 5817
a5f6c9128500
child 5818
cae9956be67e
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/WebBrowser/SafeBrowsing/SafeBrowsingCache.py	Mon Jul 24 18:40:07 2017 +0200
@@ -0,0 +1,300 @@
+# -*- coding: utf-8 -*-
+
+# Copyright (c) 2017 Detlev Offenbach <detlev@die-offenbachs.de>
+#
+
+"""
+Module implementing a cache for Google Safe Browsing.
+"""
+
+#
+# Some part of this code were ported from gglsbl.storage and adapted
+# to QtSql.
+#
+# https://github.com/afilipovich/gglsbl
+#
+
+from __future__ import unicode_literals
+
+import os
+
+from PyQt5.QtCore import QObject
+from PyQt5.QtSql import QSqlDatabase, QSqlQuery
+
+
+class ThreatList(object):
+    """
+    Class implementing the threat list info.
+    """
+    def __init__(self, threatType, platformType, threatEntryType):
+        """
+        Constructor
+        
+        @param threatType threat type
+        @type str
+        @param platformType platform type
+        @type str
+        @param threatEntryType threat entry type
+        @type str
+        """
+        self.__threatType = threatType
+        self.__platformType = platformType
+        self.__threatEntryType = threatEntryType
+
+    @classmethod
+    def fromApiEntry(cls, entry):
+        """
+        Class method to instantiate a threat list given a threat list entry
+        dictionary.
+        
+        @param entry threat list entry dictionary
+        @type dict
+        @return instantiated object
+        @rtype ThreatList
+        """
+        return cls(entry['threatType'], entry['platformType'],
+                   entry['threatEntryType'])
+
+    def asTuple(self):
+        """
+        Public method to convert the object to a tuple.
+        
+        @return tuple containing the threat list info
+        @rtype tuple of (str, str, str)
+        """
+        return (self.__threatType, self.platformType, self.__threatEntryType)
+
+    def __repr__(self):
+        """
+        Special method to generate a printable representation.
+        
+        @return printable representation
+        @rtype str
+        """
+        return '/'.join(self.asTuple())
+
+
+class SafeBrowsingCache(QObject):
+    """
+    Class implementing a cache for Google Safe Browsing.
+    """
+    
+    create_threat_list_stmt = """
+        CREATE TABLE threat_list
+        (threat_type character varying(128) NOT NULL,
+         platform_type character varying(128) NOT NULL,
+         threat_entry_type character varying(128) NOT NULL,
+         client_state character varying(42),
+         timestamp timestamp without time zone DEFAULT current_timestamp,
+         PRIMARY KEY (threat_type, platform_type, threat_entry_type)
+        )
+    """
+    drop_threat_list_stmt = """DROP TABLE IF EXISTS threat_list"""
+    
+    create_full_hashes_stmt = """
+        CREATE TABLE full_hash
+        (value BLOB NOT NULL,
+         threat_type character varying(128) NOT NULL,
+         platform_type character varying(128) NOT NULL,
+         threat_entry_type character varying(128) NOT NULL,
+         downloaded_at timestamp without time zone DEFAULT current_timestamp,
+         expires_at timestamp without time zone
+            NOT NULL DEFAULT current_timestamp,
+         malware_threat_type varchar(32),
+         PRIMARY KEY (value, threat_type, platform_type, threat_entry_type)
+        )
+    """
+    drop_full_hashes_stmt = """DROP TABLE IF EXISTS full_hash"""
+    
+    create_hash_prefix_stmt = """
+        CREATE TABLE hash_prefix
+        (value BLOB NOT NULL,
+         cue character varying(4) NOT NULL,
+         threat_type character varying(128) NOT NULL,
+         platform_type character varying(128) NOT NULL,
+         threat_entry_type character varying(128) NOT NULL,
+         timestamp timestamp without time zone DEFAULT current_timestamp,
+         negative_expires_at timestamp without time zone
+            NOT NULL DEFAULT current_timestamp,
+         PRIMARY KEY (value, threat_type, platform_type, threat_entry_type),
+         FOREIGN KEY(threat_type, platform_type, threat_entry_type)
+         REFERENCES threat_list(threat_type, platform_type, threat_entry_type)
+         ON DELETE CASCADE
+        )
+    """
+    drop_hash_prefix_stmt = """DROP TABLE IF EXISTS hash_prefix"""
+    
+    create_full_hash_cue_idx = """
+        CREATE INDEX idx_hash_prefix_cue ON hash_prefix (cue)
+    """
+    drop_full_hash_cue_idx = """DROP INDEX IF EXISTS idx_hash_prefix_cue"""
+    
+    create_full_hash_expires_idx = """
+        CREATE INDEX idx_full_hash_expires_at ON full_hash (expires_at)
+    """
+    drop_full_hash_expires_idx = """
+        DROP INDEX IF EXISTS idx_full_hash_expires_at
+    """
+    
+    create_full_hash_value_idx = """
+        CREATE INDEX idx_full_hash_value ON full_hash (value)
+    """
+    drop_full_hash_value_idx = """DROP INDEX IF EXISTS idx_full_hash_value"""
+    
+    def __init__(self, dbPath, parent=None):
+        """
+        Constructor
+        
+        @param dbPath path to store the cache DB into
+        @type str
+        @param parent reference to the parent object
+        @type QObject
+        """
+        super(SafeBrowsingCache, self).__init__(parent)
+        
+        self.__connectionName = "SafeBrowsingCache"
+        
+        if not os.path.exists(dbPath):
+            os.makedirs(dbPath)
+        
+        self.__dbFileName = os.path.join(dbPath, "SafeBrowsingCache.db")
+        preparationNeeded = not os.path.exists(self.__dbFileName)
+        
+        self.__openCacheDb()
+        if preparationNeeded:
+            self.__prepareCacheDb()
+    
+    def __openCacheDb(self):
+        """
+        Private method to open the cache database.
+        
+        @return flag indicating the open state
+        @rtype bool
+        """
+        db = QSqlDatabase.database(self.__connectionName, False)
+        if not db.isValid():
+            # the database connection is a new one
+            db = QSqlDatabase.addDatabase("QSQLITE", self.__connectionName)
+            db.setDatabaseName(self.__dbFileName)
+            opened = db.open()
+            if not opened:
+                QSqlDatabase.removeDatabase(self.__connectionName)
+        else:
+            opened = True
+        return opened
+    
+    def __prepareCacheDb(self):
+        """
+        Private method to prepare the cache database.
+        """
+        db = QSqlDatabase.database(self.__connectionName)
+        db.transaction()
+        try:
+            query = QSqlQuery(db)
+            # step 1: drop old tables
+            query.exec_(self.drop_threat_list_stmt)
+            query.exec_(self.drop_full_hashes_stmt)
+            query.exec_(self.drop_hash_prefix_stmt)
+            # step 2: drop old indices
+            query.exec_(self.drop_full_hash_cue_idx)
+            query.exec_(self.drop_full_hash_expires_idx)
+            query.exec_(self.drop_full_hash_value_idx)
+            # step 3: create tables
+            query.exec_(self.create_threat_list_stmt)
+            query.exec_(self.create_full_hashes_stmt)
+            query.exec_(self.create_hash_prefix_stmt)
+            # step 4: create indices
+            query.exec_(self.create_full_hash_cue_idx)
+            query.exec_(self.create_full_hash_expires_idx)
+            query.exec_(self.create_full_hash_value_idx)
+        finally:
+            del query
+            db.commit()
+    
+    def lookupFullHashes(self, hashValues):
+        """
+        Public method to get a list of threat lists and expiration flag
+        for the given hashes if a hash is blacklisted.
+        
+        @param hashValues list of hash values to look up
+        @type list of bytes
+        @return list of tuples containing the threat list info and the
+            expiration flag
+        @rtype list of tuple of (ThreatList, bool)
+        """
+        queryStr = """
+            SELECT threat_type, platform_type, threat_entry_type,
+            expires_at < current_timestamp AS has_expired
+            FROM full_hash WHERE value IN ({})
+        """
+        output = []
+        
+        db = QSqlDatabase.database(self.__connectionName)
+        if db.isOpen():
+            db.transaction()
+            try:
+                query = QSqlQuery(db)
+                query.prepare(
+                    queryStr.format(",".join(["?" * len(hashValues)])))
+                for hashValue in hashValues:
+                    query.addBindValue(hashValue)
+                
+                query.exec_()
+                
+                while query.next():
+                    threatType = query.value(0)
+                    platformType = query.value(1)
+                    threatEntryType = query.value(2)
+                    hasExpired = query.value(3)     # TODO: check if bool
+                    threatList = ThreatList(threatType, platformType,
+                                            threatEntryType)
+                    output.append((threatList, hasExpired))
+                del query
+            finally:
+                db.commit()
+        
+        return output
+    
+    def lookupHashPrefix(self, prefixes):
+        """
+        Public method to look up hash prefixes in the local cache.
+        
+        @param prefix hash prefix to look up
+        @type list of bytes
+        @return list of tuples containing the threat list, full hash and
+            negative cache expiration flag
+        @rtype list of tuple of (ThreatList, bytes, bool)
+        """
+        queryStr = """
+            SELECT value,threat_type,platform_type,threat_entry_type,
+            negative_expires_at < current_timestamp AS negative_cache_expired
+            FROM hash_prefix WHERE cue IN ({})
+        """
+        output = []
+        
+        db = QSqlDatabase.database(self.__connectionName)
+        if db.isOpen():
+            db.transaction()
+            try:
+                query = QSqlQuery(db)
+                query.prepare(
+                    queryStr.format(",".join(["?" * len(prefixes)])))
+                for prefix in prefixes:
+                    query.addBindValue(prefix)
+                
+                query.exec_()
+                
+                while query.next():
+                    fullHash = bytes(query.value(0))
+                    threatType = query.value(1)
+                    platformType = query.value(2)
+                    threatEntryType = query.value(3)
+                    negativeCacheExpired = query.value(4)
+                    threatList = ThreatList(threatType, platformType,
+                                            threatEntryType)
+                    output.append((threatList, fullHash, negativeCacheExpired))
+                del query
+            finally:
+                db.commit()
+        
+        return output

eric ide

mercurial