WebBrowser/SafeBrowsing/SafeBrowsingCache.py

Tue, 25 Jul 2017 19:22:36 +0200

author
Detlev Offenbach <detlev@die-offenbachs.de>
date
Tue, 25 Jul 2017 19:22:36 +0200
branch
safe_browsing
changeset 5818
cae9956be67e
parent 5817
a5f6c9128500
child 5819
69fa45e95673
permissions
-rw-r--r--

Continued implementing the SafeBrowsingCache class.

# -*- coding: utf-8 -*-

# Copyright (c) 2017 Detlev Offenbach <detlev@die-offenbachs.de>
#

"""
Module implementing a cache for Google Safe Browsing.
"""

#
# Some parts of this code were ported from gglsbl.storage and adapted
# to QtSql.
#
# https://github.com/afilipovich/gglsbl
#

from __future__ import unicode_literals

import os

from PyQt5.QtCore import QObject
from PyQt5.QtSql import QSql, QSqlDatabase, QSqlQuery


class ThreatList(object):
    """
    Class holding the three values that identify a threat list.
    """
    def __init__(self, threatType, platformType, threatEntryType):
        """
        Constructor
        
        @param threatType threat type
        @type str
        @param platformType platform type
        @type str
        @param threatEntryType threat entry type
        @type str
        """
        self.threatType, self.platformType, self.threatEntryType = (
            threatType, platformType, threatEntryType)

    @classmethod
    def fromApiEntry(cls, entry):
        """
        Class method to build a threat list object from a threat list entry
        dictionary.
        
        @param entry threat list entry dictionary providing the keys
            'threatType', 'platformType' and 'threatEntryType'
        @type dict
        @return instantiated object
        @rtype ThreatList
        """
        threatType = entry['threatType']
        platformType = entry['platformType']
        threatEntryType = entry['threatEntryType']
        return cls(threatType, platformType, threatEntryType)

    def asTuple(self):
        """
        Public method to get the threat list info as a tuple.
        
        @return tuple containing the threat type, platform type and threat
            entry type
        @rtype tuple of (str, str, str)
        """
        info = (self.threatType, self.platformType, self.threatEntryType)
        return info

    def __repr__(self):
        """
        Special method to generate a printable representation.
        
        @return the three threat list values joined by '/'
        @rtype str
        """
        parts = self.asTuple()
        return '/'.join(parts)


class SafeBrowsingCache(QObject):
    """
    Class implementing a cache for Google Safe Browsing.
    
    The class attributes defined here hold the SQL statements used to
    create and drop the tables and indices of the cache database.
    """
    # Table 'threat_list': one row per (threat type, platform type,
    # threat entry type) combination together with a client state string
    # and a timestamp defaulting to the time of insertion.
    create_threat_list_stmt = """
        CREATE TABLE threat_list
        (threat_type character varying(128) NOT NULL,
         platform_type character varying(128) NOT NULL,
         threat_entry_type character varying(128) NOT NULL,
         client_state character varying(42),
         timestamp timestamp without time zone DEFAULT current_timestamp,
         PRIMARY KEY (threat_type, platform_type, threat_entry_type)
        )
    """
    drop_threat_list_stmt = """DROP TABLE IF EXISTS threat_list"""
    
    # Table 'full_hash': hash values (stored as BLOB) per threat list with
    # their download and expiration times and a malware threat type.
    create_full_hashes_stmt = """
        CREATE TABLE full_hash
        (value BLOB NOT NULL,
         threat_type character varying(128) NOT NULL,
         platform_type character varying(128) NOT NULL,
         threat_entry_type character varying(128) NOT NULL,
         downloaded_at timestamp without time zone DEFAULT current_timestamp,
         expires_at timestamp without time zone
            NOT NULL DEFAULT current_timestamp,
         malware_threat_type varchar(32),
         PRIMARY KEY (value, threat_type, platform_type, threat_entry_type)
        )
    """
    drop_full_hashes_stmt = """DROP TABLE IF EXISTS full_hash"""
    
    # Table 'hash_prefix': hash prefix values (stored as BLOB) per threat
    # list together with a short 'cue' column used for lookups and a
    # negative cache expiration time. Rows reference 'threat_list' with
    # ON DELETE CASCADE.
    # NOTE(review): SQLite presumably enforces foreign keys only when
    # 'PRAGMA foreign_keys' is enabled; no such pragma is visible in this
    # class - confirm whether the cascade is relied upon.
    create_hash_prefix_stmt = """
        CREATE TABLE hash_prefix
        (value BLOB NOT NULL,
         cue character varying(4) NOT NULL,
         threat_type character varying(128) NOT NULL,
         platform_type character varying(128) NOT NULL,
         threat_entry_type character varying(128) NOT NULL,
         timestamp timestamp without time zone DEFAULT current_timestamp,
         negative_expires_at timestamp without time zone
            NOT NULL DEFAULT current_timestamp,
         PRIMARY KEY (value, threat_type, platform_type, threat_entry_type),
         FOREIGN KEY(threat_type, platform_type, threat_entry_type)
         REFERENCES threat_list(threat_type, platform_type, threat_entry_type)
         ON DELETE CASCADE
        )
    """
    drop_hash_prefix_stmt = """DROP TABLE IF EXISTS hash_prefix"""
    
    # Index on hash_prefix.cue. Despite the 'full_hash' in the attribute
    # names, this index belongs to the 'hash_prefix' table.
    create_full_hash_cue_idx = """
        CREATE INDEX idx_hash_prefix_cue ON hash_prefix (cue)
    """
    drop_full_hash_cue_idx = """DROP INDEX IF EXISTS idx_hash_prefix_cue"""
    
    # Index on full_hash.expires_at.
    create_full_hash_expires_idx = """
        CREATE INDEX idx_full_hash_expires_at ON full_hash (expires_at)
    """
    drop_full_hash_expires_idx = """
        DROP INDEX IF EXISTS idx_full_hash_expires_at
    """
    
    # Index on full_hash.value.
    create_full_hash_value_idx = """
        CREATE INDEX idx_full_hash_value ON full_hash (value)
    """
    drop_full_hash_value_idx = """DROP INDEX IF EXISTS idx_full_hash_value"""
    
    def __init__(self, dbPath, parent=None):
        """
        Constructor
        
        Creates the cache directory if it does not exist, opens the cache
        database and creates the tables and indices if the database file
        did not exist before.
        
        @param dbPath path to store the cache DB into
        @type str
        @param parent reference to the parent object
        @type QObject
        """
        super(SafeBrowsingCache, self).__init__(parent)
        
        self.__connectionName = "SafeBrowsingCache"
        
        if not os.path.exists(dbPath):
            os.makedirs(dbPath)
        
        self.__dbFileName = os.path.join(dbPath, "SafeBrowsingCache.db")
        # Check for the file before the database is opened, so a newly
        # created database can be told apart from an existing one.
        preparationNeeded = not os.path.exists(self.__dbFileName)
        
        # Only create the schema if the database could actually be opened;
        # a failed open removes the connection again.
        opened = self.__openCacheDb()
        if opened and preparationNeeded:
            self.__prepareCacheDb()
    
    def close(self):
        """
        Public method to close the database.
        
        If the connection is open, it is closed and removed from the list
        of database connections.
        """
        if QSqlDatabase.database(self.__connectionName).isOpen():
            QSqlDatabase.database(self.__connectionName).close()
            # remove the connection under the name it was registered with
            QSqlDatabase.removeDatabase(self.__connectionName)
    
    def __openCacheDb(self):
        """
        Private method to open the cache database.
        
        An already registered, valid connection is reused as is. Otherwise
        a new SQLite connection is registered and opened; it is removed
        again if opening fails.
        
        @return flag indicating the open state
        @rtype bool
        """
        db = QSqlDatabase.database(self.__connectionName, False)
        if db.isValid():
            # the connection is known already
            return True
        
        # the database connection is a new one
        db = QSqlDatabase.addDatabase("QSQLITE", self.__connectionName)
        db.setDatabaseName(self.__dbFileName)
        if db.open():
            return True
        
        QSqlDatabase.removeDatabase(self.__connectionName)
        return False
    
    def __prepareCacheDb(self):
        """
        Private method to prepare the cache database.
        
        Within one transaction all old tables and indices are dropped and
        then created anew.
        """
        statements = (
            # step 1: drop old tables
            self.drop_threat_list_stmt,
            self.drop_full_hashes_stmt,
            self.drop_hash_prefix_stmt,
            # step 2: drop old indices
            self.drop_full_hash_cue_idx,
            self.drop_full_hash_expires_idx,
            self.drop_full_hash_value_idx,
            # step 3: create tables
            self.create_threat_list_stmt,
            self.create_full_hashes_stmt,
            self.create_hash_prefix_stmt,
            # step 4: create indices
            self.create_full_hash_cue_idx,
            self.create_full_hash_expires_idx,
            self.create_full_hash_value_idx,
        )
        
        db = QSqlDatabase.database(self.__connectionName)
        db.transaction()
        try:
            query = QSqlQuery(db)
            for statement in statements:
                query.exec_(statement)
        finally:
            del query
            db.commit()
    
    def lookupFullHashes(self, hashValues):
        """
        Public method to get a list of threat lists and expiration flag
        for the given hashes if a hash is blacklisted.
        
        @param hashValues list of hash values to look up
        @type list of bytes
        @return list of tuples containing the threat list info and the
            expiration flag
        @rtype list of tuple of (ThreatList, bool)
        """
        queryStr = """
            SELECT threat_type, platform_type, threat_entry_type,
            expires_at < current_timestamp AS has_expired
            FROM full_hash WHERE value IN ({0})
        """
        output = []
        
        if not hashValues:
            # nothing to look up
            return output
        
        db = QSqlDatabase.database(self.__connectionName)
        if db.isOpen():
            db.transaction()
            try:
                query = QSqlQuery(db)
                # one positional placeholder per hash value ("?,?,?")
                placeholders = ",".join(["?"] * len(hashValues))
                query.prepare(queryStr.format(placeholders))
                for hashValue in hashValues:
                    query.addBindValue(hashValue, QSql.In | QSql.Binary)
                
                query.exec_()
                
                while query.next():
                    threatList = ThreatList(
                        query.value(0), query.value(1), query.value(2))
                    # the comparison result is converted to the documented
                    # bool type
                    hasExpired = bool(query.value(3))
                    output.append((threatList, hasExpired))
                del query
            finally:
                db.commit()
        
        return output
    
    def lookupHashPrefix(self, prefixes):
        """
        Public method to look up hash prefixes in the local cache.
        
        The given values are matched against the 'cue' column of the
        'hash_prefix' table.
        
        @param prefixes list of hash prefixes to look up
        @type list of bytes
        @return list of tuples containing the threat list, the stored hash
            value and the negative cache expiration flag
        @rtype list of tuple of (ThreatList, bytes, bool)
        """
        queryStr = """
            SELECT value,threat_type,platform_type,threat_entry_type,
            negative_expires_at < current_timestamp AS negative_cache_expired
            FROM hash_prefix WHERE cue IN ({0})
        """
        output = []
        
        if not prefixes:
            # nothing to look up
            return output
        
        db = QSqlDatabase.database(self.__connectionName)
        if db.isOpen():
            db.transaction()
            try:
                query = QSqlQuery(db)
                # one positional placeholder per prefix ("?,?,?")
                placeholders = ",".join(["?"] * len(prefixes))
                query.prepare(queryStr.format(placeholders))
                for prefix in prefixes:
                    query.addBindValue(prefix)
                
                query.exec_()
                
                while query.next():
                    fullHash = bytes(query.value(0))
                    threatList = ThreatList(
                        query.value(1), query.value(2), query.value(3))
                    # the comparison result is converted to the documented
                    # bool type
                    negativeCacheExpired = bool(query.value(4))
                    output.append((threatList, fullHash, negativeCacheExpired))
                del query
            finally:
                db.commit()
        
        return output
    
    def storeFullHash(self, threatList, hashValue, cacheDuration,
                      malwareThreatType):
        """
        Public method to store full hash data in the cache database.
        
        The entry is inserted unless it exists already; afterwards its
        expiration time is set to the current time plus the cache duration.
        
        @param threatList threat list info object
        @type ThreatList
        @param hashValue hash to be stored
        @type bytes
        @param cacheDuration duration the data should remain in the cache
        @type int or float
        @param malwareThreatType threat type of the malware
        @type str
        """
        insertStatement = """
            INSERT OR IGNORE INTO full_hash
                (value, threat_type, platform_type, threat_entry_type,
                 malware_threat_type, downloaded_at)
            VALUES
                (?, ?, ?, ?, ?, current_timestamp)
        """
        updateStatement = """
            UPDATE full_hash SET
                expires_at=datetime(current_timestamp, '+{0} SECONDS')
            WHERE value=? AND threat_type=? AND platform_type=? AND
            threat_entry_type=?
        """
        
        db = QSqlDatabase.database(self.__connectionName)
        if not db.isOpen():
            return
        
        db.transaction()
        try:
            # step 1: add the entry if it is not known yet
            insertQuery = QSqlQuery(db)
            insertQuery.prepare(insertStatement)
            insertQuery.addBindValue(hashValue, QSql.In | QSql.Binary)
            for value in (threatList.threatType, threatList.platformType,
                          threatList.threatEntryType, malwareThreatType):
                insertQuery.addBindValue(value)
            insertQuery.exec_()
            del insertQuery
            
            # step 2: (re)set the expiration time of the entry
            updateQuery = QSqlQuery(db)
            updateQuery.prepare(updateStatement.format(int(cacheDuration)))
            updateQuery.addBindValue(hashValue, QSql.In | QSql.Binary)
            for value in (threatList.threatType, threatList.platformType,
                          threatList.threatEntryType):
                updateQuery.addBindValue(value)
            updateQuery.exec_()
            del updateQuery
        finally:
            db.commit()
    
    def deleteHashPrefixList(self, threatList):
        """
        Public method to delete all hash prefixes belonging to a given
        threat list.
        
        @param threatList threat list info object
        @type ThreatList
        """
        deleteStatement = """
            DELETE FROM hash_prefix
                WHERE threat_type=? AND platform_type=? AND threat_entry_type=?
        """
        
        db = QSqlDatabase.database(self.__connectionName)
        if not db.isOpen():
            return
        
        db.transaction()
        try:
            deleteQuery = QSqlQuery(db)
            deleteQuery.prepare(deleteStatement)
            for value in (threatList.threatType, threatList.platformType,
                          threatList.threatEntryType):
                deleteQuery.addBindValue(value)
            deleteQuery.exec_()
            del deleteQuery
        finally:
            db.commit()
    
    def cleanupFullHashes(self, keepExpiredFor=43200):
        """
        Public method to clean up full hash entries expired more than the
        given time.
        
        @param keepExpiredFor time period in seconds for which expired
            entries are kept before they are deleted
        @type int or float
        """
        queryStr = """
            DELETE FROM full_hash
                WHERE expires_at < datetime(current_timestamp, '{0} SECONDS')
        """
        # Explicitly signed offset pointing into the past, e.g. '-43200',
        # so entries are deleted once they expired before 'now - period'.
        offset = "{0:+d}".format(-int(keepExpiredFor))
        
        db = QSqlDatabase.database(self.__connectionName)
        if db.isOpen():
            db.transaction()
            try:
                query = QSqlQuery(db)
                query.prepare(queryStr.format(offset))
                query.exec_()
                del query
            finally:
                db.commit()

eric ide

mercurial