WebBrowser/SafeBrowsing/SafeBrowsingCache.py

Mon, 24 Jul 2017 18:40:07 +0200

author
Detlev Offenbach <detlev@die-offenbachs.de>
date
Mon, 24 Jul 2017 18:40:07 +0200
branch
safe_browsing
changeset 5817
a5f6c9128500
child 5818
cae9956be67e
permissions
-rw-r--r--

Started implementing the SafeBrowsingCache class.

# -*- coding: utf-8 -*-

# Copyright (c) 2017 Detlev Offenbach <detlev@die-offenbachs.de>
#

"""
Module implementing a cache for Google Safe Browsing.
"""

#
# Some parts of this code were ported from gglsbl.storage and adapted
# to QtSql.
#
# https://github.com/afilipovich/gglsbl
#

from __future__ import unicode_literals

import os

from PyQt5.QtCore import QObject
from PyQt5.QtSql import QSqlDatabase, QSqlQuery


class ThreatList(object):
    """
    Class implementing the threat list info.
    
    A threat list is identified by the combination of its threat type,
    platform type and threat entry type.
    """
    def __init__(self, threatType, platformType, threatEntryType):
        """
        Constructor
        
        @param threatType threat type
        @type str
        @param platformType platform type
        @type str
        @param threatEntryType threat entry type
        @type str
        """
        self.__threatType = threatType
        self.__platformType = platformType
        self.__threatEntryType = threatEntryType

    @classmethod
    def fromApiEntry(cls, entry):
        """
        Class method to instantiate a threat list given a threat list entry
        dictionary.
        
        @param entry threat list entry dictionary containing the keys
            'threatType', 'platformType' and 'threatEntryType'
        @type dict
        @return instantiated object
        @rtype ThreatList
        @exception KeyError raised if one of the required keys is missing
        """
        return cls(entry['threatType'], entry['platformType'],
                   entry['threatEntryType'])

    def asTuple(self):
        """
        Public method to convert the object to a tuple.
        
        @return tuple containing the threat list info in the order
            threat type, platform type, threat entry type
        @rtype tuple of (str, str, str)
        """
        # all three values are stored as private (name mangled) attributes
        return (self.__threatType, self.__platformType,
                self.__threatEntryType)

    def __repr__(self):
        """
        Special method to generate a printable representation.
        
        @return printable representation (the three parts joined by '/')
        @rtype str
        """
        return '/'.join(self.asTuple())


class SafeBrowsingCache(QObject):
    """
    Class implementing a cache for Google Safe Browsing.
    
    The cache is stored in an SQLite database file that is accessed via a
    named QtSql database connection.
    """
    
    create_threat_list_stmt = """
        CREATE TABLE threat_list
        (threat_type character varying(128) NOT NULL,
         platform_type character varying(128) NOT NULL,
         threat_entry_type character varying(128) NOT NULL,
         client_state character varying(42),
         timestamp timestamp without time zone DEFAULT current_timestamp,
         PRIMARY KEY (threat_type, platform_type, threat_entry_type)
        )
    """
    drop_threat_list_stmt = """DROP TABLE IF EXISTS threat_list"""
    
    create_full_hashes_stmt = """
        CREATE TABLE full_hash
        (value BLOB NOT NULL,
         threat_type character varying(128) NOT NULL,
         platform_type character varying(128) NOT NULL,
         threat_entry_type character varying(128) NOT NULL,
         downloaded_at timestamp without time zone DEFAULT current_timestamp,
         expires_at timestamp without time zone
            NOT NULL DEFAULT current_timestamp,
         malware_threat_type varchar(32),
         PRIMARY KEY (value, threat_type, platform_type, threat_entry_type)
        )
    """
    drop_full_hashes_stmt = """DROP TABLE IF EXISTS full_hash"""
    
    create_hash_prefix_stmt = """
        CREATE TABLE hash_prefix
        (value BLOB NOT NULL,
         cue character varying(4) NOT NULL,
         threat_type character varying(128) NOT NULL,
         platform_type character varying(128) NOT NULL,
         threat_entry_type character varying(128) NOT NULL,
         timestamp timestamp without time zone DEFAULT current_timestamp,
         negative_expires_at timestamp without time zone
            NOT NULL DEFAULT current_timestamp,
         PRIMARY KEY (value, threat_type, platform_type, threat_entry_type),
         FOREIGN KEY(threat_type, platform_type, threat_entry_type)
         REFERENCES threat_list(threat_type, platform_type, threat_entry_type)
         ON DELETE CASCADE
        )
    """
    drop_hash_prefix_stmt = """DROP TABLE IF EXISTS hash_prefix"""
    
    # NOTE: despite its name this index is created on the hash_prefix table
    create_full_hash_cue_idx = """
        CREATE INDEX idx_hash_prefix_cue ON hash_prefix (cue)
    """
    drop_full_hash_cue_idx = """DROP INDEX IF EXISTS idx_hash_prefix_cue"""
    
    create_full_hash_expires_idx = """
        CREATE INDEX idx_full_hash_expires_at ON full_hash (expires_at)
    """
    drop_full_hash_expires_idx = """
        DROP INDEX IF EXISTS idx_full_hash_expires_at
    """
    
    create_full_hash_value_idx = """
        CREATE INDEX idx_full_hash_value ON full_hash (value)
    """
    drop_full_hash_value_idx = """DROP INDEX IF EXISTS idx_full_hash_value"""
    
    def __init__(self, dbPath, parent=None):
        """
        Constructor
        
        @param dbPath path to store the cache DB into
        @type str
        @param parent reference to the parent object
        @type QObject
        """
        super(SafeBrowsingCache, self).__init__(parent)
        
        self.__connectionName = "SafeBrowsingCache"
        
        if not os.path.exists(dbPath):
            os.makedirs(dbPath)
        
        self.__dbFileName = os.path.join(dbPath, "SafeBrowsingCache.db")
        # the schema has to be created, if the database file is a new one
        preparationNeeded = not os.path.exists(self.__dbFileName)
        
        # only prepare the schema, if the database could be opened
        if self.__openCacheDb() and preparationNeeded:
            self.__prepareCacheDb()
    
    def __openCacheDb(self):
        """
        Private method to open the cache database.
        
        @return flag indicating the open state
        @rtype bool
        """
        db = QSqlDatabase.database(self.__connectionName, False)
        if not db.isValid():
            # the database connection is a new one
            db = QSqlDatabase.addDatabase("QSQLITE", self.__connectionName)
            db.setDatabaseName(self.__dbFileName)
            opened = db.open()
            if not opened:
                QSqlDatabase.removeDatabase(self.__connectionName)
        else:
            opened = True
        return opened
    
    def __prepareCacheDb(self):
        """
        Private method to prepare the cache database.
        
        All tables and indices are dropped (if they exist) and recreated
        within a single transaction.
        """
        statements = (
            # step 1: drop old tables
            self.drop_threat_list_stmt,
            self.drop_full_hashes_stmt,
            self.drop_hash_prefix_stmt,
            # step 2: drop old indices
            self.drop_full_hash_cue_idx,
            self.drop_full_hash_expires_idx,
            self.drop_full_hash_value_idx,
            # step 3: create tables
            self.create_threat_list_stmt,
            self.create_full_hashes_stmt,
            self.create_hash_prefix_stmt,
            # step 4: create indices
            self.create_full_hash_cue_idx,
            self.create_full_hash_expires_idx,
            self.create_full_hash_value_idx,
        )
        
        db = QSqlDatabase.database(self.__connectionName)
        # create the query before entering the try block so that the
        # 'del query' of the finally clause never hits an unbound name
        query = QSqlQuery(db)
        db.transaction()
        try:
            for statement in statements:
                query.exec_(statement)
        finally:
            del query
            db.commit()
    
    def __selectWithInList(self, queryStr, values, columnCount):
        """
        Private method to execute a SELECT statement containing one
        'IN ({})' placeholder list and to collect the result rows.
        
        @param queryStr query string containing a single '{}' that gets
            replaced by the comma separated positional placeholders
        @type str
        @param values values to be bound to the placeholders
        @type list
        @param columnCount number of result columns to be read per row
        @type int
        @return list of result rows, each a list of column values
        @rtype list of list
        """
        rows = []
        
        db = QSqlDatabase.database(self.__connectionName)
        if not values or not db.isOpen():
            # nothing to look up or no database available
            return rows
        
        query = QSqlQuery(db)
        db.transaction()
        try:
            # one '?' per value, e.g. three values give "?,?,?"
            placeholders = ",".join(["?"] * len(values))
            query.prepare(queryStr.format(placeholders))
            for value in values:
                query.addBindValue(value)
            
            query.exec_()
            
            while query.next():
                rows.append(
                    [query.value(column) for column in range(columnCount)])
        finally:
            del query
            db.commit()
        
        return rows
    
    def lookupFullHashes(self, hashValues):
        """
        Public method to get a list of threat lists and expiration flag
        for the given hashes if a hash is blacklisted.
        
        @param hashValues list of hash values to look up
        @type list of bytes
        @return list of tuples containing the threat list info and the
            expiration flag
        @rtype list of tuple of (ThreatList, bool)
        """
        queryStr = """
            SELECT threat_type, platform_type, threat_entry_type,
            expires_at < current_timestamp AS has_expired
            FROM full_hash WHERE value IN ({})
        """
        output = []
        
        for row in self.__selectWithInList(queryStr, hashValues, 4):
            threatList = ThreatList(row[0], row[1], row[2])
            # the comparison result is delivered as a number by the
            # database; convert it to the documented bool
            hasExpired = bool(row[3])
            output.append((threatList, hasExpired))
        
        return output
    
    def lookupHashPrefix(self, prefixes):
        """
        Public method to look up hash prefixes in the local cache.
        
        The given prefixes are matched against the 'cue' column of the
        hash prefix table.
        
        @param prefixes list of hash prefixes to look up
        @type list of bytes
        @return list of tuples containing the threat list, full hash and
            negative cache expiration flag
        @rtype list of tuple of (ThreatList, bytes, bool)
        """
        queryStr = """
            SELECT value,threat_type,platform_type,threat_entry_type,
            negative_expires_at < current_timestamp AS negative_cache_expired
            FROM hash_prefix WHERE cue IN ({})
        """
        output = []
        
        for row in self.__selectWithInList(queryStr, prefixes, 5):
            fullHash = bytes(row[0])
            threatList = ThreatList(row[1], row[2], row[3])
            # the comparison result is delivered as a number by the
            # database; convert it to the documented bool
            negativeCacheExpired = bool(row[4])
            output.append((threatList, fullHash, negativeCacheExpired))
        
        return output

eric ide

mercurial