WebBrowser/SafeBrowsing/SafeBrowsingCache.py

Sat, 29 Jul 2017 19:41:16 +0200

author
Detlev Offenbach <detlev@die-offenbachs.de>
date
Sat, 29 Jul 2017 19:41:16 +0200
branch
safe_browsing
changeset 5820
b610cb5b501a
parent 5819
69fa45e95673
child 5829
d3448873ced3
permissions
-rw-r--r--

Started implementing the safe browsing manager and management dialog.

# -*- coding: utf-8 -*-

# Copyright (c) 2017 Detlev Offenbach <detlev@die-offenbachs.de>
#

"""
Module implementing a cache for Google Safe Browsing.
"""

#
# Some parts of this code were ported from gglsbl.storage and adapted
# to QtSql.
#
# https://github.com/afilipovich/gglsbl
#

from __future__ import unicode_literals, division

import os

from PyQt5.QtCore import QObject, QByteArray, QCryptographicHash
from PyQt5.QtSql import QSql, QSqlDatabase, QSqlQuery

from .SafeBrowsingUtilities import toHex


class ThreatList(object):
    """
    Class holding the three values that identify a threat list.
    """
    def __init__(self, threatType, platformType, threatEntryType):
        """
        Constructor
        
        @param threatType type of the threat
        @type str
        @param platformType type of the platform
        @type str
        @param threatEntryType type of the threat entry
        @type str
        """
        self.threatType, self.platformType, self.threatEntryType = (
            threatType, platformType, threatEntryType)

    @classmethod
    def fromApiEntry(cls, entry):
        """
        Class method to create a threat list object from a threat list
        entry dictionary.
        
        @param entry threat list entry dictionary providing the keys
            'threatType', 'platformType' and 'threatEntryType'
        @type dict
        @return instantiated object
        @rtype ThreatList
        """
        keys = ('threatType', 'platformType', 'threatEntryType')
        return cls(*[entry[key] for key in keys])

    def asTuple(self):
        """
        Public method to get the threat list info as a tuple.
        
        @return tuple containing the threat type, the platform type and
            the threat entry type
        @rtype tuple of (str, str, str)
        """
        info = (self.threatType, self.platformType, self.threatEntryType)
        return info

    def __repr__(self):
        """
        Special method to generate a printable representation.
        
        @return the three threat list values joined by a slash
        @rtype str
        """
        separator = '/'
        return separator.join(self.asTuple())


class HashPrefixList(object):
    """
    Class implementing a container for the hash prefixes of a threat list.
    """
    def __init__(self, prefixLength, rawHashes):
        """
        Constructor
        
        @param prefixLength length of each individual hash prefix
        @type int
        @param rawHashes concatenation of hash prefixes of the given length
            (sorted in lexicographical order)
        @type str
        """
        self.__rawHashes = rawHashes
        self.__prefixLength = prefixLength
    
    def __len__(self):
        """
        Special method to calculate the number of complete entries.
        
        @return number of hash prefixes
        @rtype int
        """
        totalSize = len(self.__rawHashes)
        return totalSize // self.__prefixLength
    
    def __iter__(self):
        """
        Special method to iterate over the individual hash prefixes.
        
        @return iterator yielding consecutive slices of the raw hashes,
            each one prefix length long
        @rtype iterator
        """
        step = self.__prefixLength
        return (
            self.__rawHashes[start:start + step]
            for start in range(0, len(self.__rawHashes), step)
        )


class SafeBrowsingCache(QObject):
    """
    Class implementing a cache for Google Safe Browsing.
    
    The class attributes below hold the SQL statements used to create and
    drop the tables and indices of the cache database.
    """
    # table of the known threat lists; a list is identified by the
    # combination of threat type, platform type and threat entry type
    create_threat_list_stmt = """
        CREATE TABLE threat_list
        (threat_type character varying(128) NOT NULL,
         platform_type character varying(128) NOT NULL,
         threat_entry_type character varying(128) NOT NULL,
         client_state character varying(42),
         timestamp timestamp without time zone DEFAULT current_timestamp,
         PRIMARY KEY (threat_type, platform_type, threat_entry_type)
        )
    """
    drop_threat_list_stmt = """DROP TABLE IF EXISTS threat_list"""
    
    # table of full hashes together with their download and expiration
    # times
    create_full_hashes_stmt = """
        CREATE TABLE full_hash
        (value BLOB NOT NULL,
         threat_type character varying(128) NOT NULL,
         platform_type character varying(128) NOT NULL,
         threat_entry_type character varying(128) NOT NULL,
         downloaded_at timestamp without time zone DEFAULT current_timestamp,
         expires_at timestamp without time zone
            NOT NULL DEFAULT current_timestamp,
         malware_threat_type varchar(32),
         PRIMARY KEY (value, threat_type, platform_type, threat_entry_type)
        )
    """
    drop_full_hashes_stmt = """DROP TABLE IF EXISTS full_hash"""
    
    # table of hash prefixes; each row references its threat list and
    # carries a short 'cue' value used for lookups
    # NOTE(review): the ON DELETE CASCADE clause presumably needs foreign
    # key enforcement to be enabled on the connection; nothing in the
    # statements of this class enables it - confirm
    create_hash_prefix_stmt = """
        CREATE TABLE hash_prefix
        (value BLOB NOT NULL,
         cue character varying(4) NOT NULL,
         threat_type character varying(128) NOT NULL,
         platform_type character varying(128) NOT NULL,
         threat_entry_type character varying(128) NOT NULL,
         timestamp timestamp without time zone DEFAULT current_timestamp,
         negative_expires_at timestamp without time zone
            NOT NULL DEFAULT current_timestamp,
         PRIMARY KEY (value, threat_type, platform_type, threat_entry_type),
         FOREIGN KEY(threat_type, platform_type, threat_entry_type)
         REFERENCES threat_list(threat_type, platform_type, threat_entry_type)
         ON DELETE CASCADE
        )
    """
    drop_hash_prefix_stmt = """DROP TABLE IF EXISTS hash_prefix"""
    
    # index on the cue column of the hash_prefix table (despite the
    # 'full_hash' part of the attribute names)
    create_full_hash_cue_idx = """
        CREATE INDEX idx_hash_prefix_cue ON hash_prefix (cue)
    """
    drop_full_hash_cue_idx = """DROP INDEX IF EXISTS idx_hash_prefix_cue"""
    
    # index on the expiration time of the full_hash table
    create_full_hash_expires_idx = """
        CREATE INDEX idx_full_hash_expires_at ON full_hash (expires_at)
    """
    drop_full_hash_expires_idx = """
        DROP INDEX IF EXISTS idx_full_hash_expires_at
    """
    
    # index on the hash value of the full_hash table
    create_full_hash_value_idx = """
        CREATE INDEX idx_full_hash_value ON full_hash (value)
    """
    drop_full_hash_value_idx = """DROP INDEX IF EXISTS idx_full_hash_value"""
    
    def __init__(self, dbPath, parent=None):
        """
        Constructor
        
        @param dbPath path of the directory to store the cache DB into
        @type str
        @param parent reference to the parent object
        @type QObject
        """
        super(SafeBrowsingCache, self).__init__(parent)
        
        self.__connectionName = "SafeBrowsingCache"
        
        # create the target directory, if it is missing
        directoryMissing = not os.path.exists(dbPath)
        if directoryMissing:
            os.makedirs(dbPath)
        
        self.__dbFileName = os.path.join(dbPath, "SafeBrowsingCache.db")
        # determine this before the database gets opened
        isNewDatabase = not os.path.exists(self.__dbFileName)
        
        self.__openCacheDb()
        if isNewDatabase:
            # no database file was found, so tables and indices are created
            self.prepareCacheDb()
    
    def close(self):
        """
        Public method to close the database.
        
        If the cache database connection is open, it is closed and removed
        from the list of database connections.
        """
        db = QSqlDatabase.database(self.__connectionName)
        wasOpen = db.isOpen()
        if wasOpen:
            db.close()
        # release the local handle before the connection gets removed
        del db
        if wasOpen:
            # the connection is registered under the connection name (the
            # previously used attribute '__language' does not exist)
            QSqlDatabase.removeDatabase(self.__connectionName)
    
    def __openCacheDb(self):
        """
        Private method to open the cache database.
        
        @return flag indicating the open state
        @rtype bool
        """
        name = self.__connectionName
        
        db = QSqlDatabase.database(name, False)
        if db.isValid():
            # the connection exists already; it is reported as opened
            return True
        
        # the database connection is a new one
        db = QSqlDatabase.addDatabase("QSQLITE", name)
        db.setDatabaseName(self.__dbFileName)
        opened = db.open()
        if not opened:
            # do not keep a connection that could not be opened
            QSqlDatabase.removeDatabase(name)
        return opened
    
    def prepareCacheDb(self):
        """
        Public method to prepare the cache database.
        
        All tables and indices are dropped and created anew within one
        transaction.
        """
        statements = (
            # step 1: drop old tables
            self.drop_threat_list_stmt,
            self.drop_full_hashes_stmt,
            self.drop_hash_prefix_stmt,
            # step 2: drop old indices
            self.drop_full_hash_cue_idx,
            self.drop_full_hash_expires_idx,
            self.drop_full_hash_value_idx,
            # step 3: create tables
            self.create_threat_list_stmt,
            self.create_full_hashes_stmt,
            self.create_hash_prefix_stmt,
            # step 4: create indices
            self.create_full_hash_cue_idx,
            self.create_full_hash_expires_idx,
            self.create_full_hash_value_idx,
        )
        
        db = QSqlDatabase.database(self.__connectionName)
        db.transaction()
        try:
            query = QSqlQuery(db)
            for statement in statements:
                query.exec_(statement)
        finally:
            del query
            db.commit()
    
    def lookupFullHashes(self, hashValues):
        """
        Public method to get a list of threat lists and expiration flag
        for the given hashes if a hash is blacklisted.
        
        @param hashValues list of hash values to look up
        @type list of bytes
        @return list of tuples containing the threat list info and the
            expiration flag (empty, if nothing was found, no hash values
            were given or the database is not open)
        @rtype list of tuple of (ThreatList, bool)
        """
        queryStr = """
            SELECT threat_type, platform_type, threat_entry_type,
            expires_at < current_timestamp AS has_expired
            FROM full_hash WHERE value IN ({0})
        """
        output = []
        
        if not hashValues:
            # nothing to look up, so there is no need to query the database
            return output
        
        db = QSqlDatabase.database(self.__connectionName)
        if db.isOpen():
            db.transaction()
            try:
                query = QSqlQuery(db)
                # one positional placeholder per hash value, i.e. "?,?,?"
                # (a single "???" string would not match the bound values)
                query.prepare(
                    queryStr.format(",".join(["?"] * len(hashValues))))
                for hashValue in hashValues:
                    query.addBindValue(QByteArray(hashValue),
                                       QSql.In | QSql.Binary)
                
                query.exec_()
                
                while query.next():
                    threatType = query.value(0)
                    platformType = query.value(1)
                    threatEntryType = query.value(2)
                    # convert the comparison result to the documented type
                    hasExpired = bool(query.value(3))
                    threatList = ThreatList(threatType, platformType,
                                            threatEntryType)
                    output.append((threatList, hasExpired))
                del query
            finally:
                db.commit()
        
        return output
    
    def lookupHashPrefix(self, prefixes):
        """
        Public method to look up hash prefixes in the local cache.
        
        The given values are matched against the 'cue' column of the hash
        prefix table.
        NOTE(review): populateHashPrefixList() stores toHex(prefix[:4]) as
        the cue; presumably callers pass the values in that form - confirm.
        
        @param prefixes list of hash prefixes to look up
        @type list of bytes
        @return list of tuples containing the threat list, the stored hash
            value and the negative cache expiration flag (empty, if nothing
            was found, no prefixes were given or the database is not open)
        @rtype list of tuple of (ThreatList, bytes, bool)
        """
        queryStr = """
            SELECT value,threat_type,platform_type,threat_entry_type,
            negative_expires_at < current_timestamp AS negative_cache_expired
            FROM hash_prefix WHERE cue IN ({0})
        """
        output = []
        
        if not prefixes:
            # nothing to look up, so there is no need to query the database
            return output
        
        db = QSqlDatabase.database(self.__connectionName)
        if db.isOpen():
            db.transaction()
            try:
                query = QSqlQuery(db)
                # one positional placeholder per prefix, i.e. "?,?,?"
                # (a single "???" string would not match the bound values)
                query.prepare(
                    queryStr.format(",".join(["?"] * len(prefixes))))
                for prefix in prefixes:
                    query.addBindValue(prefix)
                
                query.exec_()
                
                while query.next():
                    fullHash = bytes(query.value(0))
                    threatType = query.value(1)
                    platformType = query.value(2)
                    threatEntryType = query.value(3)
                    # convert the comparison result to the documented type
                    negativeCacheExpired = bool(query.value(4))
                    threatList = ThreatList(threatType, platformType,
                                            threatEntryType)
                    output.append((threatList, fullHash, negativeCacheExpired))
                del query
            finally:
                db.commit()
        
        return output
    
    def storeFullHash(self, threatList, hashValue, cacheDuration,
                      malwareThreatType):
        """
        Public method to store full hash data in the cache database.
        
        The entry is inserted, if it is not present yet, and its expiration
        time is set to the current time plus the cache duration.
        
        @param threatList threat list info object
        @type ThreatList
        @param hashValue hash to be stored
        @type bytes
        @param cacheDuration duration in seconds the data should remain in
            the cache
        @type int or float
        @param malwareThreatType threat type of the malware
        @type str
        """
        insertQueryStr = """
            INSERT OR IGNORE INTO full_hash
                (value, threat_type, platform_type, threat_entry_type,
                 malware_threat_type, downloaded_at)
            VALUES
                (?, ?, ?, ?, ?, current_timestamp)
        """
        updateQueryStr = """
            UPDATE full_hash SET
                expires_at=datetime(current_timestamp, '+{0} SECONDS')
            WHERE value=? AND threat_type=? AND platform_type=? AND
            threat_entry_type=?
        """.format(int(cacheDuration))
        
        db = QSqlDatabase.database(self.__connectionName)
        if not db.isOpen():
            return
        
        db.transaction()
        try:
            # step 1: add the entry, if it is not known yet
            insertQuery = QSqlQuery(db)
            insertQuery.prepare(insertQueryStr)
            insertQuery.addBindValue(QByteArray(hashValue),
                                     QSql.In | QSql.Binary)
            for listValue in (threatList.threatType,
                              threatList.platformType,
                              threatList.threatEntryType):
                insertQuery.addBindValue(listValue)
            insertQuery.addBindValue(malwareThreatType)
            insertQuery.exec_()
            del insertQuery
            
            # step 2: set the expiration time of the entry
            updateQuery = QSqlQuery(db)
            updateQuery.prepare(updateQueryStr)
            updateQuery.addBindValue(QByteArray(hashValue),
                                     QSql.In | QSql.Binary)
            for listValue in (threatList.threatType,
                              threatList.platformType,
                              threatList.threatEntryType):
                updateQuery.addBindValue(listValue)
            updateQuery.exec_()
            del updateQuery
        finally:
            db.commit()
    
    def deleteHashPrefixList(self, threatList):
        """
        Public method to delete all hash prefixes of a given threat list.
        
        @param threatList threat list info object
        @type ThreatList
        """
        queryStr = """
            DELETE FROM hash_prefix
                WHERE threat_type=? AND platform_type=? AND threat_entry_type=?
        """
        
        db = QSqlDatabase.database(self.__connectionName)
        if not db.isOpen():
            return
        
        db.transaction()
        try:
            deleteQuery = QSqlQuery(db)
            deleteQuery.prepare(queryStr)
            for listValue in (threatList.threatType,
                              threatList.platformType,
                              threatList.threatEntryType):
                deleteQuery.addBindValue(listValue)
            deleteQuery.exec_()
            del deleteQuery
        finally:
            db.commit()
    
    def cleanupFullHashes(self, keepExpiredFor=43200):
        """
        Public method to clean up full hash entries expired more than the
        given time.
        
        @param keepExpiredFor time period in seconds expired entries are
            kept before they get deleted
        @type int or float
        """
        # Delete everything whose expiration time lies more than
        # 'keepExpiredFor' seconds in the past. (Comparing for equality
        # with a time in the future would practically never match.)
        queryStr = """
            DELETE FROM full_hash
                WHERE expires_at < datetime(current_timestamp, '-{0} SECONDS')
        """.format(int(keepExpiredFor))
        
        db = QSqlDatabase.database(self.__connectionName)
        if db.isOpen():
            db.transaction()
            try:
                query = QSqlQuery(db)
                query.prepare(queryStr)
                query.exec_()
                del query
            finally:
                db.commit()
    
    def updateHashPrefixExpiration(self, threatList, hashPrefix,
                                   negativeCacheDuration):
        """
        Public method to update the negative cache expiration time of a
        hash prefix.
        
        @param threatList threat list info object
        @type ThreatList
        @param hashPrefix hash prefix
        @type bytes
        @param negativeCacheDuration time in seconds the entry should remain
            in the cache
        @type int or float
        """
        queryStr = """
            UPDATE hash_prefix
            SET negative_expires_at=datetime(current_timestamp, '+{0} SECONDS')
            WHERE value=? AND threat_type=? AND platform_type=? AND
            threat_entry_type=?
        """.format(int(negativeCacheDuration))
        
        db = QSqlDatabase.database(self.__connectionName)
        if not db.isOpen():
            return
        
        db.transaction()
        try:
            updateQuery = QSqlQuery(db)
            updateQuery.prepare(queryStr)
            # the prefix is bound as binary data
            updateQuery.addBindValue(QByteArray(hashPrefix),
                                     QSql.In | QSql.Binary)
            for listValue in (threatList.threatType,
                              threatList.platformType,
                              threatList.threatEntryType):
                updateQuery.addBindValue(listValue)
            updateQuery.exec_()
            del updateQuery
        finally:
            db.commit()
    
    def getThreatLists(self):
        """
        Public method to get the threat lists stored in the cache.
        
        @return list of stored threat lists together with their client
            state (empty, if the database is not open)
        @rtype list of tuples of (ThreatList, str)
        """
        queryStr = """
            SELECT threat_type,platform_type,threat_entry_type,client_state
            FROM threat_list
        """
        threatLists = []
        
        db = QSqlDatabase.database(self.__connectionName)
        if not db.isOpen():
            return threatLists
        
        db.transaction()
        try:
            selectQuery = QSqlQuery(db)
            selectQuery.prepare(queryStr)
            
            selectQuery.exec_()
            
            while selectQuery.next():
                # columns 0 to 2 identify the list, column 3 is its state
                columns = [selectQuery.value(column) for column in range(4)]
                listInfo = ThreatList(*columns[:3])
                threatLists.append((listInfo, columns[3]))
            del selectQuery
        finally:
            db.commit()
        
        return threatLists
    
    def addThreatList(self, threatList):
        """
        Public method to add a threat list to the cache.
        
        An already stored threat list is left untouched.
        
        @param threatList threat list to be added
        @type ThreatList
        """
        queryStr = """
            INSERT OR IGNORE INTO threat_list
                (threat_type, platform_type, threat_entry_type, timestamp)
            VALUES (?, ?, ?, current_timestamp)
        """
        
        db = QSqlDatabase.database(self.__connectionName)
        if not db.isOpen():
            return
        
        db.transaction()
        try:
            insertQuery = QSqlQuery(db)
            insertQuery.prepare(queryStr)
            for listValue in (threatList.threatType,
                              threatList.platformType,
                              threatList.threatEntryType):
                insertQuery.addBindValue(listValue)
            insertQuery.exec_()
            del insertQuery
        finally:
            db.commit()
    
    def deleteThreatList(self, threatList):
        """
        Public method to delete a threat list from the cache.
        
        @param threatList threat list to be deleted
        @type ThreatList
        """
        queryStr = """
            DELETE FROM threat_list
                WHERE threat_type=? AND platform_type=? AND threat_entry_type=?
        """
        
        db = QSqlDatabase.database(self.__connectionName)
        if not db.isOpen():
            return
        
        db.transaction()
        try:
            deleteQuery = QSqlQuery(db)
            deleteQuery.prepare(queryStr)
            for listValue in (threatList.threatType,
                              threatList.platformType,
                              threatList.threatEntryType):
                deleteQuery.addBindValue(listValue)
            deleteQuery.exec_()
            del deleteQuery
        finally:
            db.commit()
    
    def updateThreatListClientState(self, threatList, clientState):
        """
        Public method to update the client state of a threat list.
        
        The timestamp of the threat list is set to the current time as well.
        
        @param threatList threat list to update the client state for
        @type ThreatList
        @param clientState new client state
        @type str
        """
        queryStr = """
            UPDATE threat_list SET timestamp=current_timestamp, client_state=?
            WHERE threat_type=? AND platform_type=? AND threat_entry_type=?
        """
        
        db = QSqlDatabase.database(self.__connectionName)
        if not db.isOpen():
            return
        
        db.transaction()
        try:
            updateQuery = QSqlQuery(db)
            updateQuery.prepare(queryStr)
            # the new state comes first, followed by the list identification
            updateQuery.addBindValue(clientState)
            for listValue in (threatList.threatType,
                              threatList.platformType,
                              threatList.threatEntryType):
                updateQuery.addBindValue(listValue)
            updateQuery.exec_()
            del updateQuery
        finally:
            db.commit()
    
    def hashPrefixListChecksum(self, threatList):
        """
        Public method to calculate the SHA256 checksum over the hash
        prefixes of a threat list taken in sorted order.
        
        @param threatList threat list to calculate checksum for
        @type ThreatList
        @return SHA256 checksum (None, if the database is not open)
        @rtype bytes
        """
        queryStr = """
            SELECT value FROM hash_prefix
                WHERE threat_type=? AND platform_type=? AND threat_entry_type=?
                ORDER BY value
        """
        
        db = QSqlDatabase.database(self.__connectionName)
        if not db.isOpen():
            return None
        
        db.transaction()
        sha256 = QCryptographicHash(QCryptographicHash.Sha256)
        try:
            selectQuery = QSqlQuery(db)
            selectQuery.prepare(queryStr)
            for listValue in (threatList.threatType,
                              threatList.platformType,
                              threatList.threatEntryType):
                selectQuery.addBindValue(listValue)
            
            selectQuery.exec_()
            
            # feed the prefixes into the hash in the order delivered
            while selectQuery.next():
                sha256.addData(selectQuery.value(0))
            del selectQuery
        finally:
            db.commit()
        
        return bytes(sha256.result())
    
    def populateHashPrefixList(self, threatList, prefixes):
        """
        Public method to populate the hash prefixes for a threat list.
        
        Each prefix is stored together with a cue derived from its first
        four elements.
        
        @param threatList threat list of the hash prefixes
        @type ThreatList
        @param prefixes list of hash prefixes to be inserted
        @type HashPrefixList
        """
        queryStr = """
            INSERT INTO hash_prefix
                (value, cue, threat_type, platform_type, threat_entry_type,
                 timestamp)
                VALUES (?, ?, ?, ?, ?, current_timestamp)
        """
        
        db = QSqlDatabase.database(self.__connectionName)
        if not db.isOpen():
            return
        
        db.transaction()
        try:
            for hashPrefix in prefixes:
                insertQuery = QSqlQuery(db)
                insertQuery.prepare(queryStr)
                insertQuery.addBindValue(QByteArray(hashPrefix),
                                         QSql.In | QSql.Binary)
                # the cue is the hex form of the leading part of the prefix
                insertQuery.addBindValue(toHex(hashPrefix[:4]))
                for listValue in (threatList.threatType,
                                  threatList.platformType,
                                  threatList.threatEntryType):
                    insertQuery.addBindValue(listValue)
                insertQuery.exec_()
                del insertQuery
        finally:
            db.commit()
    
    def getHashPrefixValuesToRemove(self, threatList, indexes):
        """
        Public method to get the hash prefix values to be removed from the
        cache.
        
        The indexes refer to the positions of the hash prefixes of the
        threat list when ordered by their value.
        
        @param threatList threat list to remove prefixes from
        @type ThreatList
        @param indexes list of indexes of prefixes to be removed
        @type list of int
        @return list of hash prefixes to be removed (empty, if the database
            is not open)
        @rtype list of bytes
        """
        queryStr = """
            SELECT value FROM hash_prefix
                WHERE threat_type=? AND platform_type=? AND threat_entry_type=?
                ORDER BY value
        """
        wantedPositions = set(indexes)
        selectedPrefixes = []
        
        db = QSqlDatabase.database(self.__connectionName)
        if not db.isOpen():
            return selectedPrefixes
        
        db.transaction()
        try:
            selectQuery = QSqlQuery(db)
            selectQuery.prepare(queryStr)
            for listValue in (threatList.threatType,
                              threatList.platformType,
                              threatList.threatEntryType):
                selectQuery.addBindValue(listValue)
            
            selectQuery.exec_()
            
            # walk through the sorted rows and pick the requested positions
            position = -1
            while selectQuery.next():
                position += 1
                if position not in wantedPositions:
                    continue
                selectedPrefixes.append(bytes(selectQuery.value(0)))
            del selectQuery
        finally:
            db.commit()
        
        return selectedPrefixes
    
    def removeHashPrefixIndices(self, threatList, indexes):
        """
        Public method to remove hash prefixes from the cache.
        
        The prefixes to be removed are determined via
        getHashPrefixValuesToRemove() and deleted in batches.
        
        @param threatList threat list to delete hash prefixes of
        @type ThreatList
        @param indexes list of indexes of prefixes to be removed
        @type list of int
        """
        queryStr = """
            DELETE FROM hash_prefix
            WHERE threat_type=? AND platform_type=? AND
            threat_entry_type=? AND value IN ({0})
        """
        # maximum number of prefixes deleted by one statement
        batchSize = 40
        
        prefixesToRemove = self.getHashPrefixValuesToRemove(
            threatList, indexes)
        if not prefixesToRemove:
            return
        
        db = QSqlDatabase.database(self.__connectionName)
        if not db.isOpen():
            return
        
        db.transaction()
        try:
            for start in range(0, len(prefixesToRemove), batchSize):
                chunk = prefixesToRemove[start:start + batchSize]
                # one positional placeholder per prefix of the batch
                placeholders = ",".join("?" for _ in chunk)
                
                deleteQuery = QSqlQuery(db)
                deleteQuery.prepare(queryStr.format(placeholders))
                for listValue in (threatList.threatType,
                                  threatList.platformType,
                                  threatList.threatEntryType):
                    deleteQuery.addBindValue(listValue)
                for hashPrefix in chunk:
                    deleteQuery.addBindValue(QByteArray(hashPrefix),
                                             QSql.In | QSql.Binary)
                deleteQuery.exec_()
                del deleteQuery
        finally:
            db.commit()

eric ide

mercurial