OllamaInterface/OllamaClient.py

Mon, 07 Apr 2025 18:22:30 +0200

author
Detlev Offenbach <detlev@die-offenbachs.de>
date
Mon, 07 Apr 2025 18:22:30 +0200
changeset 69
eb9340034f26
parent 67
3c2bcbf7eeaf
permissions
-rw-r--r--

Created global tag <release-10.1.8>.

# -*- coding: utf-8 -*-

# Copyright (c) 2024 - 2025 Detlev Offenbach <detlev@die-offenbachs.de>
#

"""
Module implementing the 'ollama' client.
"""

import contextlib
import datetime
import enum
import json

from PyQt6.QtCore import (
    QCoreApplication,
    QObject,
    QThread,
    QTimer,
    QUrl,
    pyqtSignal,
    pyqtSlot,
)
from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkReply, QNetworkRequest

from eric7.EricNetwork.EricNetworkProxyFactory import proxyAuthenticationRequired


class OllamaClientState(enum.Enum):
    """
    Class defining the various states an 'ollama' client request can be in.
    """

    Waiting = 0  # idle; no request in flight
    Requesting = 1  # request was sent; no data received yet
    Receiving = 2  # response data is being received
    Finished = 3  # request processing has completed


class OllamaClient(QObject):
    """
    Class implementing the 'ollama' client.

    @signal replyReceived(content:str, role:str, done:bool) emitted after a response
        from the 'ollama' server was received
    @signal modelsList(modelNames:list[str]) emitted after the list of model
        names was obtained from the 'ollama' server
    @signal pullStatus(msg:str, id:str, total:int, completed:int) emitted to indicate
        the status of a pull request as reported by the 'ollama' server
    @signal pullError(msg:str) emitted to indicate an error during a pull operation
    @signal serverVersion(version:str) emitted after the server version was obtained
        from the 'ollama' server
    @signal finished() emitted to indicate the completion of a request
    @signal errorOccurred(error:str) emitted to indicate a network error occurred
        while processing the request
    @signal serverStateChanged(ok:bool, msg:str) emitted to indicate a change of the
        server responsiveness
    """

    replyReceived = pyqtSignal(str, str, bool)
    modelsList = pyqtSignal(list)
    # NOTE(review): the C++ type strings presumably select an integer type wide
    # enough for byte counts of very large models -- confirm against PyQt docs
    pullStatus = pyqtSignal(str, str, "unsigned long int", "unsigned long int")
    pullError = pyqtSignal(str)
    serverVersion = pyqtSignal(str)
    finished = pyqtSignal()
    errorOccurred = pyqtSignal(str)
    serverStateChanged = pyqtSignal(bool, str)

    def __init__(self, plugin, parent=None):
        """
        Constructor

        @param plugin reference to the plugin object
        @type PluginOllamaInterface
        @param parent reference to the parent object (defaults to None)
        @type QObject (optional)
        """
        super().__init__(parent=parent)

        self.__plugin = plugin
        self.__replies = []  # active (non-pull) network replies
        self.__pullReply = None  # reply of an ongoing pull operation, if any

        self.__networkManager = QNetworkAccessManager(self)
        self.__networkManager.proxyAuthenticationRequired.connect(
            proxyAuthenticationRequired
        )

        self.__serverResponding = None  # start with an unknown state
        self.__heartbeatTimer = QTimer(self)
        self.__heartbeatTimer.timeout.connect(self.__periodicHeartbeat)

        self.__state = OllamaClientState.Waiting
        self.__localServer = False  # remote server mode is the default

        # (re-)configure the heartbeat timer on every preferences change and
        # once initially as soon as the event loop is running
        self.__plugin.preferencesChanged.connect(self.__setHeartbeatTimer)
        QTimer.singleShot(0, self.__setHeartbeatTimer)

    def setMode(self, local):
        """
        Public method to set the client mode to local.

        @param local flag indicating to connect to a locally started ollama server
        @type bool
        """
        self.__localServer = local
        self.__serverResponding = None  # responsiveness is unknown again
        if not self.__plugin.getPreferences("OllamaHeartbeatInterval"):
            # Periodic checks are disabled; schedule a single heartbeat check.
            # A local server is given 10 seconds to come up first.
            delay = 10 * 1000 if self.__localServer else 0
            QTimer.singleShot(delay, self.__periodicHeartbeat)

    def chat(self, model, messages, streaming=True):
        """
        Public method to request a chat completion from the 'ollama' server.

        @param model name of the model to be used
        @type str
        @param messages list of message objects
        @type list of dict
        @param streaming flag indicating to receive a streaming response
        @type bool
        """
        self.__sendRequest(
            "chat",
            data={"model": model, "messages": messages, "stream": streaming},
            processResponse=self.__processChatResponse,
        )

    def __processChatResponse(self, response):
        """
        Private method to process the chat response of the 'ollama' server.

        @param response dictionary containing the chat response
        @type dict
        """
        with contextlib.suppress(KeyError):
            message, done = response["message"], response["done"]
            if message:
                self.replyReceived.emit(message["content"], message["role"], done)

    def generate(self, model, prompt, suffix=None):
        """
        Public method to request to generate a completion from the 'ollama' server.

        @param model name of the model to be used
        @type str
        @param prompt prompt to generate a response for
        @type str
        @param suffix text after the model response (defaults to None)
        @type str (optional)
        """
        requestData = {"model": model, "prompt": prompt}
        if suffix is not None:
            requestData["suffix"] = suffix
        self.__sendRequest(
            "generate",
            data=requestData,
            processResponse=self.__processGenerateResponse,
        )

    def __processGenerateResponse(self, response):
        """
        Private method to process the generate response of the 'ollama' server.

        @param response dictionary containing the generate response
        @type dict
        """
        with contextlib.suppress(KeyError):
            content, done = response["response"], response["done"]
            # generate responses carry no role information
            self.replyReceived.emit(content, "", done)

    def pull(self, model):
        """
        Public method to ask the 'ollama' server to pull the given model.

        @param model name of the model
        @type str
        """
        self.__sendRequest(
            "pull",
            data={"model": model},
            processResponse=self.__processPullResponse,
        )

    def __processPullResponse(self, response):
        """
        Private method to process a pull response of the 'ollama' server.

        @param response dictionary containing the pull response
        @type dict
        """
        if "error" in response:
            self.pullError.emit(response["error"])
            return

        with contextlib.suppress(KeyError):
            self.pullStatus.emit(
                response["status"],
                response.get("digest", "")[:20],  # first 20 characters only
                response.get("total", 0),
                response.get("completed", 0),
            )

    def abortPull(self):
        """
        Public method to abort an ongoing pull operation.
        """
        reply = self.__pullReply
        if reply is not None:
            # closing the reply cancels the running download
            reply.close()

    def remove(self, model):
        """
        Public method to ask the 'ollama' server to delete the given model.

        @param model name of the model
        @type str
        @return flag indicating success
        @rtype bool
        """
        _, status = self.__sendSyncRequest(
            "delete", data={"model": model}, delete=True
        )
        return status == 200  # HTTP status 200 OK

    def list(self):
        """
        Public method to request a list of models available locally from the 'ollama'
        server.
        """
        # asynchronous; the result is delivered via __processModelsList(), which
        # emits the 'modelsList' signal
        self.__sendRequest("tags", processResponse=self.__processModelsList)

    def __processModelsList(self, response):
        """
        Private method to process the tags response of the 'ollama' server.

        @param response dictionary containing the tags response
        @type dict
        """
        models = []
        with contextlib.suppress(KeyError):
            # names collected before a potential KeyError are kept
            models.extend(
                model["name"] for model in response["models"] if model["name"]
            )
        self.modelsList.emit(models)

    def listDetails(self):
        """
        Public method to request a list of models available locally from the 'ollama'
        server with some model details.

        @return list of dictionaries containing the available models and related data
        @rtype list[dict[str, Any]]
        """
        response, _ = self.__sendSyncRequest("tags")

        detailsList = []
        if response is not None:
            with contextlib.suppress(KeyError):
                for model in response["models"]:
                    if model["name"]:
                        detailsList.append(
                            {
                                "name": model["name"],
                                # first 20 characters of the digest only
                                "id": model["digest"][:20],
                                "size": model["size"],
                                "modified": datetime.datetime.fromisoformat(
                                    model["modified_at"]
                                ),
                            }
                        )

        return detailsList

    def listRunning(self):
        """
        Public method to request a list of running models from the 'ollama' server.

        @return list of dictionaries containing the running models and related data
        @rtype list[dict[str, Any]]
        """
        response, _ = self.__sendSyncRequest("ps")

        models = []
        if response is not None:
            with contextlib.suppress(KeyError):
                for model in response["models"]:
                    name = model["name"]
                    if name:
                        if model["size_vram"] == 0:
                            # nothing resides in video RAM
                            processor = self.tr("100% CPU")
                        elif model["size_vram"] == model["size"]:
                            # the complete model resides in video RAM
                            processor = self.tr("100% GPU")
                        elif model["size_vram"] > model["size"] or model["size"] == 0:
                            # inconsistent size report; no sensible split
                            processor = self.tr("unknown")
                        else:
                            sizeCpu = model["size"] - model["size_vram"]
                            # The CPU share must be taken relative to the total
                            # model size. Dividing by 'size_vram' (as before)
                            # could yield values above 100 and a negative GPU
                            # share whenever less than half the model is in
                            # video RAM.
                            cpuPercent = round(sizeCpu / model["size"] * 100)
                            processor = self.tr("{0}% / {1}% CPU / GPU").format(
                                cpuPercent, 100 - cpuPercent
                            )
                        models.append(
                            {
                                "name": name,
                                "id": model["digest"][:20],  # first 20 characters only
                                "size": model["size"],
                                "size_vram": model["size_vram"],
                                "processor": processor,
                                "expires": datetime.datetime.fromisoformat(
                                    model["expires_at"]
                                ),
                            }
                        )

        return models

    def version(self):
        """
        Public method to request the version from the 'ollama' server.
        """
        # asynchronous; the result is delivered via __processVersion(), which
        # emits the 'serverVersion' signal
        self.__sendRequest("version", processResponse=self.__processVersion)

    def __processVersion(self, response):
        """
        Private method to process the version response of the 'ollama' server.

        @param response dictionary containing the version response
        @type dict
        """
        with contextlib.suppress(KeyError):
            self.serverVersion.emit(response["version"])
            heartbeatDisabled = (
                self.__plugin.getPreferences("OllamaHeartbeatInterval") == 0
            )
            if heartbeatDisabled and not self.__serverResponding:
                # a successful version request counts as an implicit
                # connectivity check
                self.__serverResponding = True
                self.serverStateChanged.emit(True, "")

    def state(self):
        """
        Public method to get the current client state.

        @return current client state
        @rtype OllamaClientState
        """
        # simple accessor; the state is updated by the request handling methods
        return self.__state

    def __getServerReply(self, endpoint, data=None, delete=False):
        """
        Private method to send a request to the 'ollama' server and return a reply
        object.

        @param endpoint 'ollama' API endpoint to be contacted
        @type str
        @param data dictionary containing the data to send to the server
            (defaults to None)
        @type dict (optional)
        @param delete flag indicating to send a delete request (defaults to False)
        @type bool (optional)
        @return 'ollama' server reply
        @rtype QNetworkReply
        """
        if self.__localServer:
            urlStr = "http://127.0.0.1:{0}/api/{1}".format(
                self.__plugin.getPreferences("OllamaLocalPort"),
                endpoint,
            )
        else:
            urlStr = "{0}://{1}:{2}/api/{3}".format(
                self.__plugin.getPreferences("OllamaScheme"),
                self.__plugin.getPreferences("OllamaHost"),
                self.__plugin.getPreferences("OllamaPort"),
                endpoint,
            )

        request = QNetworkRequest(QUrl(urlStr))
        if data is None:
            # no payload means a plain GET request
            reply = self.__networkManager.get(request)
        else:
            request.setHeader(
                QNetworkRequest.KnownHeaders.ContentTypeHeader, "application/json"
            )
            payload = json.dumps(data).encode("utf-8")
            reply = (
                self.__networkManager.sendCustomRequest(request, b"DELETE", payload)
                if delete
                else self.__networkManager.post(request, payload)
            )
        reply.errorOccurred.connect(lambda error: self.__errorOccurred(error, reply))
        return reply

    def __sendRequest(self, endpoint, data=None, processResponse=None):
        """
        Private method to send a request to the 'ollama' server and handle its
        responses.

        @param endpoint 'ollama' API endpoint to be contacted
        @type str
        @param data dictionary containing the data to send to the server
            (defaults to None)
        @type dict (optional)
        @param processResponse function handling the received data (defaults to None)
        @type function (optional)
        """
        self.__state = OllamaClientState.Requesting

        reply = self.__getServerReply(endpoint=endpoint, data=data)
        reply.readyRead.connect(lambda: self.__processData(reply, processResponse))
        reply.finished.connect(lambda: self.__replyFinished(reply))
        if endpoint == "pull":
            # remember the pull reply separately so it can be aborted
            self.__pullReply = reply
        else:
            self.__replies.append(reply)

    def __replyFinished(self, reply):
        """
        Private method to handle the finished signal of the reply.

        @param reply reference to the finished network reply object
        @type QNetworkReply
        """
        self.__state = OllamaClientState.Finished

        # forget the reply object again
        if reply is self.__pullReply:
            self.__pullReply = None
        elif reply in self.__replies:
            self.__replies.remove(reply)
        reply.deleteLater()

        self.finished.emit()

    def __errorOccurred(self, errorCode, reply):
        """
        Private method to handle a network error of the given reply.

        @param errorCode error code reported by the reply
        @type QNetworkReply.NetworkError
        @param reply reference to the network reply object
        @type QNetworkReply
        """
        ignoredErrors = (
            QNetworkReply.NetworkError.NoError,
            QNetworkReply.NetworkError.OperationCanceledError,
        )
        if errorCode == QNetworkReply.NetworkError.ConnectionRefusedError:
            # a refused connection indicates an unresponsive server
            self.serverStateChanged.emit(False, self.__serverNotRespondingMessage())
        elif errorCode not in ignoredErrors:
            self.errorOccurred.emit(
                self.tr("<p>A network error occurred.</p><p>Error: {0}</p>").format(
                    reply.errorString()
                )
            )

    def __processData(self, reply, processResponse):
        """
        Private method to receive data from the 'ollama' server and process it with a
        given processing function or method.

        @param reply reference to the network reply object
        @type QNetworkReply
        @param processResponse processing function
        @type function
        """
        self.__state = OllamaClientState.Receiving

        payload = bytes(reply.readAll())
        if not payload:
            return

        # incomplete JSON chunks are silently ignored
        with contextlib.suppress(json.JSONDecodeError):
            decoded = json.loads(payload)
            if decoded and processResponse:
                processResponse(decoded)

    def __sendSyncRequest(self, endpoint, data=None, delete=False):
        """
        Private method to send a request to the 'ollama' server and handle its
        responses.

        @param endpoint 'ollama' API endpoint to be contacted
        @type str
        @param data dictionary containing the data to send to the server
            (defaults to None)
        @type dict (optional)
        @param delete flag indicating to send a delete request (defaults to False)
        @type bool (optional)
        @return tuple containing the data sent by the 'ollama' server and the HTTP
            status code
        @rtype tuple of (Any, int)
        """
        self.__state = OllamaClientState.Requesting

        reply = self.__getServerReply(endpoint=endpoint, data=data, delete=delete)
        # wait synchronously while keeping the event loop alive
        while not reply.isFinished():
            QCoreApplication.processEvents()
            QThread.msleep(100)

        reply.deleteLater()
        self.__state = OllamaClientState.Finished

        statusCode = reply.attribute(QNetworkRequest.Attribute.HttpStatusCodeAttribute)
        if reply.error() == QNetworkReply.NetworkError.NoError:
            with contextlib.suppress(json.JSONDecodeError):
                decoded = json.loads(bytes(reply.readAll()))
                return decoded, statusCode

        return None, statusCode

    def __getHeartbeatUrl(self):
        """
        Private method to get the current heartbeat URL.

        @return URL to be contacted by the heartbeat check
        @rtype str
        """
        if self.__localServer:
            return "http://127.0.0.1:{0}".format(
                self.__plugin.getPreferences("OllamaLocalPort"),
            )

        return "{0}://{1}:{2}/".format(
            self.__plugin.getPreferences("OllamaScheme"),
            self.__plugin.getPreferences("OllamaHost"),
            self.__plugin.getPreferences("OllamaPort"),
        )

    def heartbeat(self):
        """
        Public method to check, if the 'ollama' server has started and is responsive.

        @return flag indicating a responsive 'ollama' server
        @rtype bool
        """
        reply = self.__networkManager.head(
            QNetworkRequest(QUrl(self.__getHeartbeatUrl()))
        )
        # wait for the HEAD request to finish, keeping the event loop alive
        while not reply.isFinished():
            QCoreApplication.processEvents()
            QThread.msleep(100)

        reply.deleteLater()
        return reply.error() == QNetworkReply.NetworkError.NoError

    @pyqtSlot()
    def __setHeartbeatTimer(self):
        """
        Private slot to configure the heartbeat timer.
        """
        interval = self.__plugin.getPreferences("OllamaHeartbeatInterval")
        if not interval:
            self.__heartbeatTimer.stop()
        else:
            self.__heartbeatTimer.setInterval(interval * 1000)  # seconds to ms
            self.__heartbeatTimer.start()

        # perform one initial heartbeat starting from an unknown state
        self.__serverResponding = None
        self.__periodicHeartbeat()

    @pyqtSlot()
    def __periodicHeartbeat(self):
        """
        Private slot to do a periodic check of the 'ollama' server responsiveness.
        """
        responding = self.heartbeat()
        if responding != self.__serverResponding:
            # report the state change to interested parties
            self.serverStateChanged.emit(
                responding,
                "" if responding else self.__serverNotRespondingMessage(),
            )
        self.__serverResponding = responding

    def __serverNotRespondingMessage(self):
        """
        Private method to assemble and return a message for a non-responsive server.

        @return error message
        @rtype str
        """
        if self.__localServer:
            template = self.tr(
                "<p>Error: The local server at <b>{0}</b> is not responding.</p>"
            )
        else:
            template = self.tr(
                "<p>Error: The configured server at <b>{0}</b> is not"
                " responding.</p>"
            )
        return template.format(self.__getHeartbeatUrl())

eric ide

mercurial