OllamaInterface/OllamaClient.py

Mon, 05 Aug 2024 18:37:16 +0200

author
Detlev Offenbach <detlev@die-offenbachs.de>
date
Mon, 05 Aug 2024 18:37:16 +0200
changeset 4
7dd1b9cd3150
parent 3
ca28466a186d
child 5
6e8af43d537d
permissions
-rw-r--r--

Implemented most of the Chat History widgets.

# -*- coding: utf-8 -*-

# Copyright (c) 2024 Detlev Offenbach <detlev@die-offenbachs.de>
#

"""
Module implementing the 'ollama' client.
"""

import contextlib
import datetime
import enum
import json

from PyQt6.QtCore import (
    QCoreApplication,
    QObject,
    QThread,
    QTimer,
    QUrl,
    pyqtSignal,
    pyqtSlot,
)
from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkReply, QNetworkRequest

from eric7.EricNetwork.EricNetworkProxyFactory import proxyAuthenticationRequired


class OllamaClientState(enum.Enum):
    """
    Class enumerating the processing states of the 'ollama' client.
    """

    # initial state, no request has been sent yet
    Waiting = 0
    # a request has been sent to the server
    Requesting = 1
    # reply data is being received from the server
    Receiving = 2
    # a request has been completed
    Finished = 3


class OllamaClient(QObject):
    """
    Class implementing the 'ollama' client.

    @signal replyReceived(content:str, role:str) emitted after a response from the
        'ollama' server was received
    @signal modelsList(modelNames:list[str]) emitted after the list of model
        names was obtained from the 'ollama' server
    @signal detailedModelsList(models:list[dict]) emitted after the list of
        models was obtained from the 'ollama' server giving some model details
    @signal runningModelsList(models:list[dict]) emitted after the list of
        running models was obtained from the 'ollama' server giving some model
        execution details
    @signal pullStatus(msg:str, id:str, total:int, completed:int) emitted to indicate
        the status of a pull request as reported by the 'ollama' server
    @signal serverVersion(version:str) emitted after the server version was obtained
        from the 'ollama' server
    @signal finished() emitted to indicate the completion of a request
    @signal errorOccurred(error:str) emitted to indicate a network error occurred
        while processing the request
    @signal serverStateChanged(ok:bool) emitted to indicate a change of the server
        responsiveness
    """

    replyReceived = pyqtSignal(str, str)
    modelsList = pyqtSignal(list)
    detailedModelsList = pyqtSignal(list)
    runningModelsList = pyqtSignal(list)
    # NOTE(review): presumably PyQt maps the Python 'int' signal argument to a
    # 32 bit C++ int, so byte counts above 2 GiB reported by a pull may not fit;
    # switching to a 64 bit type (e.g. "qint64") requires matching changes in all
    # connected, decorated slots - verify.
    pullStatus = pyqtSignal(str, str, int, int)
    serverVersion = pyqtSignal(str)
    finished = pyqtSignal()
    errorOccurred = pyqtSignal(str)
    serverStateChanged = pyqtSignal(bool)

    def __init__(self, plugin, parent=None):
        """
        Constructor

        @param plugin reference to the plugin object
        @type PluginOllamaInterface
        @param parent reference to the parent object (defaults to None)
        @type QObject (optional)
        """
        super().__init__(parent=parent)

        self.__plugin = plugin
        self.__replies = []

        self.__networkManager = QNetworkAccessManager(self)
        self.__networkManager.proxyAuthenticationRequired.connect(
            proxyAuthenticationRequired
        )

        self.__serverResponding = False  # start with a faulty state
        self.__heartbeatTimer = QTimer(self)
        self.__heartbeatTimer.timeout.connect(self.__periodicHeartbeat)

        self.__state = OllamaClientState.Waiting

        self.__plugin.preferencesChanged.connect(self.__setHeartbeatTimer)
        self.__setHeartbeatTimer()

    def chat(self, model, messages):
        """
        Public method to request a chat completion from the 'ollama' server.

        @param model name of the model to be used
        @type str
        @param messages list of message objects
        @type list of dict
        """
        # TODO: not implemented yet
        ollamaRequest = {
            "model": model,
            "messages": messages,
        }
        self.__sendRequest(
            "chat", data=ollamaRequest, processResponse=self.__processChatResponse
        )

    def __processChatResponse(self, response):
        """
        Private method to process the chat response of the 'ollama' server.

        @param response dictionary containing the chat response
        @type dict
        """
        with contextlib.suppress(KeyError):
            message = response["message"]
            if message:
                self.replyReceived.emit(message["content"], message["role"])

    def generate(self, model, prompt, suffix=None):
        """
        Public method to request to generate a completion from the 'ollama' server.

        @param model name of the model to be used
        @type str
        @param prompt prompt to generate a response for
        @type str
        @param suffix text after the model response (defaults to None)
        @type str (optional)
        """
        # TODO: not implemented yet
        ollamaRequest = {
            "model": model,
            "prompt": prompt,
        }
        if suffix is not None:
            ollamaRequest["suffix"] = suffix
        self.__sendRequest(
            "generate",
            data=ollamaRequest,
            processResponse=self.__processGenerateResponse,
        )

    def __processGenerateResponse(self, response):
        """
        Private method to process the generate response of the 'ollama' server.

        @param response dictionary containing the generate response
        @type dict
        """
        with contextlib.suppress(KeyError):
            self.replyReceived.emit(response["response"], "")

    def pull(self, model):
        """
        Public method to ask the 'ollama' server to pull the given model.

        @param model name of the model
        @type str
        """
        # TODO: not implemented yet
        ollamaRequest = {
            "name": model,
        }
        self.__sendRequest(
            "pull", data=ollamaRequest, processResponse=self.__processPullResponse
        )

    def __processPullResponse(self, response):
        """
        Private method to process a pull response of the 'ollama' server.

        @param response dictionary containing the pull response
        @type dict
        """
        with contextlib.suppress(KeyError):
            status = response["status"]
            idStr = response.get("digest", "")[:20]
            total = response.get("total", 0)
            completed = response.get("completed", 0)
            self.pullStatus.emit(status, idStr, total, completed)

    def remove(self, model):
        """
        Public method to ask the 'ollama' server to delete the given model.

        @param model name of the model
        @type str
        """
        # TODO: not implemented yet
        ollamaRequest = {
            "name": model,
        }
        self.__sendRequest("delete", data=ollamaRequest)

    def list(self):
        """
        Public method to request a list of models available locally from the 'ollama'
        server.
        """
        self.__sendRequest("tags", processResponse=self.__processModelsList)

    def __processModelsList(self, response):
        """
        Private method to process the tags response of the 'ollama' server.

        The default tag ':latest' is stripped from the end of the model names.

        @param response dictionary containing the tags response
        @type dict
        """
        latestSuffix = ":latest"
        models = []
        # a missing or null 'models' entry results in an empty list
        for model in response.get("models") or []:
            # entries without a name are skipped individually
            name = model.get("name")
            if name:
                if name.endswith(latestSuffix):
                    name = name[: -len(latestSuffix)]
                models.append(name)
        self.modelsList.emit(models)

    def listDetails(self):
        """
        Public method to request a list of models available locally from the 'ollama'
        server with some model details.
        """
        # TODO: not implemented yet
        self.__sendRequest("tags", processResponse=self.__processDetailedModelsList)

    def __processDetailedModelsList(self, response):
        """
        Private method to process the tags response of the 'ollama' server extracting
        some model details.

        @param response dictionary containing the tags response
        @type dict
        """
        models = []
        for model in response.get("models") or []:
            # an incomplete entry is skipped without dropping the following ones
            with contextlib.suppress(KeyError):
                name = model["name"]
                if name:
                    models.append(
                        {
                            "name": name,
                            "id": model["digest"][:20],  # first 20 characters only
                            "size": model["size"],
                            "modified": self.__parseTimestamp(model["modified_at"]),
                        }
                    )
        self.detailedModelsList.emit(models)

    def listRunning(self):
        """
        Public method to request a list of running models from the 'ollama' server.
        """
        # TODO: not implemented yet
        self.__sendRequest("ps", processResponse=self.__processRunningModelsList)

    def __processRunningModelsList(self, response):
        """
        Private method to process the ps response of the 'ollama' server extracting
        some model execution details.

        @param response dictionary containing the ps response
        @type dict
        """
        models = []
        for model in response.get("models") or []:
            # an incomplete entry is skipped without dropping the following ones
            with contextlib.suppress(KeyError):
                name = model["name"]
                if name:
                    models.append(
                        {
                            "name": name,
                            "id": model["digest"][:20],  # first 20 characters only
                            "size": model["size"],
                            "size_vram": model["size_vram"],
                            "processor": self.__processorString(
                                model["size"], model["size_vram"]
                            ),
                            "expires": self.__parseTimestamp(model["expires_at"]),
                        }
                    )
        self.runningModelsList.emit(models)

    def __processorString(self, size, sizeVram):
        """
        Private method to generate a string describing how a loaded model is split
        between CPU and GPU memory.

        @param size total size of the loaded model
        @type int
        @param sizeVram part of the model size loaded into GPU memory
        @type int
        @return description of the processor usage
        @rtype str
        """
        if sizeVram == 0:
            return self.tr("100% CPU")
        if sizeVram == size:
            return self.tr("100% GPU")
        if sizeVram > size or size == 0:
            return self.tr("unknown")

        # the CPU share is relative to the total model size
        sizeCpu = size - sizeVram
        cpuPercent = round(sizeCpu / size * 100)
        return self.tr("{0}% / {1}% CPU / GPU").format(cpuPercent, 100 - cpuPercent)

    @staticmethod
    def __parseTimestamp(timestamp):
        """
        Private method to convert a time stamp sent by the 'ollama' server into a
        datetime object.

        The conversion is tried directly first. If that fails, the string is
        normalized for Python versions before 3.11, whose fromisoformat() only
        accepts three or six fractional digits and no 'Z' suffix. The fractional
        seconds are truncated or padded to six digits, and a 'Z' suffix is
        replaced by '+00:00'.

        @param timestamp time stamp string
        @type str
        @return converted time stamp
        @rtype datetime.datetime
        @exception ValueError raised to indicate a time stamp that cannot be
            converted even after normalization
        """
        try:
            return datetime.datetime.fromisoformat(timestamp)
        except ValueError:
            pass

        normalized = timestamp.strip()
        if normalized[-1:] in ("Z", "z"):
            normalized = normalized[:-1] + "+00:00"

        dotPos = normalized.find(".")
        if dotPos != -1:
            endPos = dotPos + 1
            while endPos < len(normalized) and normalized[endPos].isdigit():
                endPos += 1
            fraction = normalized[dotPos + 1 : endPos][:6].ljust(6, "0")
            normalized = "{0}.{1}{2}".format(
                normalized[:dotPos], fraction, normalized[endPos:]
            )

        return datetime.datetime.fromisoformat(normalized)

    def version(self):
        """
        Public method to request the version from the 'ollama' server.
        """
        self.__sendRequest("version", processResponse=self.__processVersion)

    def __processVersion(self, response):
        """
        Private method to process the version response of the 'ollama' server.

        @param response dictionary containing the version response
        @type dict
        """
        with contextlib.suppress(KeyError):
            self.serverVersion.emit(response["version"])

    def state(self):
        """
        Public method to get the current client state.

        @return current client state
        @rtype OllamaClientState
        """
        return self.__state

    def __sendRequest(self, endpoint, data=None, processResponse=None):
        """
        Private method to send a request to the 'ollama' server and handle its
        responses.

        @param endpoint 'ollama' API endpoint to be contacted
        @type str
        @param data dictionary containing the data to send to the server
            (defaults to None)
        @type dict (optional)
        @param processResponse function handling the received data (defaults to None)
        @type function (optional)
        """
        self.__state = OllamaClientState.Requesting

        ollamaUrl = QUrl(
            "{0}://{1}:{2}/api/{3}".format(
                self.__plugin.getPreferences("OllamaScheme"),
                self.__plugin.getPreferences("OllamaHost"),
                self.__plugin.getPreferences("OllamaPort"),
                endpoint,
            )
        )
        request = QNetworkRequest(ollamaUrl)
        if data is not None:
            request.setHeader(
                QNetworkRequest.KnownHeaders.ContentTypeHeader, "application/json"
            )
            jsonData = json.dumps(data).encode("utf-8")
            reply = self.__networkManager.post(request, jsonData)
        else:
            reply = self.__networkManager.get(request)

        # Per-reply buffer of received but not yet processed data. An answer may
        # consist of several newline-delimited JSON documents (e.g. streamed chat
        # or pull responses), which network packets can split or combine
        # arbitrarily.
        buffer = bytearray()

        reply.finished.connect(lambda: self.__replyFinished(reply))
        reply.errorOccurred.connect(lambda error: self.__errorOccurred(error, reply))
        reply.readyRead.connect(
            lambda: self.__processData(reply, processResponse, buffer)
        )
        self.__replies.append(reply)

    def __replyFinished(self, reply):
        """
        Private method to handle the finished signal of the reply.

        @param reply reference to the finished network reply object
        @type QNetworkReply
        """
        if reply in self.__replies:
            self.__replies.remove(reply)

        # report the client as finished only if no other request is still pending
        if not self.__replies:
            self.__state = OllamaClientState.Finished

        reply.deleteLater()

        self.finished.emit()

    def __errorOccurred(self, errorCode, reply):
        """
        Private method to handle a network error of the given reply.

        @param errorCode error code reported by the reply
        @type QNetworkReply.NetworkError
        @param reply reference to the network reply object
        @type QNetworkReply
        """
        if errorCode != QNetworkReply.NetworkError.NoError:
            self.errorOccurred.emit(
                self.tr("<p>A network error occurred.</p><p>Error: {0}</p>").format(
                    reply.errorString()
                )
            )

    def __processData(self, reply, processResponse, buffer=None):
        """
        Private method to receive data from the 'ollama' server and process it with a
        given processing function or method.

        The newly available data is appended to the buffer, and every complete line
        is decoded as a JSON document. A trailing line without a line break is
        processed right away if it already forms a complete JSON document, as a
        non-streamed response does. Otherwise it stays in the buffer until more
        data arrives.

        @param reply reference to the network reply object
        @type QNetworkReply
        @param processResponse processing function
        @type function
        @param buffer buffer holding received but unprocessed data of this reply;
            it is modified in place (defaults to None)
        @type bytearray (optional)
        """
        self.__state = OllamaClientState.Receiving

        if buffer is None:
            buffer = bytearray()
        buffer.extend(bytes(reply.readAll()))

        *lines, remainder = buffer.split(b"\n")

        documents = []
        for line in lines:
            # ValueError covers JSONDecodeError and UnicodeDecodeError
            with contextlib.suppress(ValueError):
                documents.append(json.loads(line))

        if remainder.strip():
            try:
                documents.append(json.loads(remainder))
            except ValueError:
                pass  # incomplete document, wait for more data
            else:
                remainder = bytearray()

        # update the buffer before calling out to keep it consistent
        buffer[:] = remainder

        if processResponse:
            for document in documents:
                if document:
                    processResponse(document)

    def heartbeat(self):
        """
        Public method to check, if the 'ollama' server has started and is responsive.

        @return flag indicating a responsive 'ollama' server
        @rtype bool
        """
        ollamaUrl = QUrl(
            "{0}://{1}:{2}/".format(
                self.__plugin.getPreferences("OllamaScheme"),
                self.__plugin.getPreferences("OllamaHost"),
                self.__plugin.getPreferences("OllamaPort"),
            )
        )
        request = QNetworkRequest(ollamaUrl)
        reply = self.__networkManager.head(request)
        # Wait for the reply while still processing events.
        # NOTE(review): processEvents() may deliver further heartbeat timer
        # timeouts during this wait, i.e. heartbeat() may be re-entered.
        while not reply.isFinished():
            QCoreApplication.processEvents()
            QThread.msleep(100)

        responding = reply.error() == QNetworkReply.NetworkError.NoError
        reply.deleteLater()

        return responding

    @pyqtSlot()
    def __setHeartbeatTimer(self):
        """
        Private slot to configure the heartbeat timer.
        """
        interval = self.__plugin.getPreferences("OllamaHeartbeatInterval")
        if interval:
            self.__heartbeatTimer.setInterval(interval * 1000)  # interval in ms
            self.__heartbeatTimer.start()
        else:
            self.__heartbeatTimer.stop()

    @pyqtSlot()
    def __periodicHeartbeat(self):
        """
        Private slot to do a periodic check of the 'ollama' server responsiveness.
        """
        responding = self.heartbeat()
        if responding != self.__serverResponding:
            self.serverStateChanged.emit(responding)
        self.__serverResponding = responding

eric ide

mercurial