OllamaInterface/OllamaClient.py

Sun, 04 Aug 2024 16:57:01 +0200

author
Detlev Offenbach <detlev@die-offenbachs.de>
date
Sun, 04 Aug 2024 16:57:01 +0200
changeset 3
ca28466a186d
child 4
7dd1b9cd3150
permissions
-rw-r--r--

Implemented the ollama client object (not tested yet).

# -*- coding: utf-8 -*-

# Copyright (c) 2024 Detlev Offenbach <detlev@die-offenbachs.de>
#

"""
Module implementing the 'ollama' client.
"""

import contextlib
import datetime
import enum
import json

from PyQt6.QtCore import pyqtSignal, QObject, QUrl
from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkRequest, QNetworkReply

from eric7.EricNetwork.EricNetworkProxyFactory import proxyAuthenticationRequired


class OllamaClientState(enum.Enum):
    """
    Enumeration of the states the 'ollama' client passes through while a request
    is being handled.
    """

    # no request has been issued yet
    Waiting = 0
    # a request has been sent to the server
    Requesting = 1
    # data of a reply is being received
    Receiving = 2
    # the reply has been completely handled
    Finished = 3


class OllamaClient(QObject):
    """
    Class implementing the 'ollama' client.

    @signal replyReceived(content:str, role:str) emitted after a response from the
        'ollama' server was received
    @signal modelsList(modelNames:list[str]) emitted after the list of model
        names was obtained from the 'ollama' server
    @signal detailedModelsList(models:list[dict]) emitted after the list of
        models was obtained from the 'ollama' server giving some model details
    @signal runningModelsList(models:list[dict]) emitted after the list of
        running models was obtained from the 'ollama' server giving some model
        execution details
    @signal pullStatus(msg:str, id:str, total:int, completed:int) emitted to indicate
        the status of a pull request as reported by the 'ollama' server
    @signal finished() emitted to indicate the completion of a request
    @signal errorOccurred(error:str) emitted to indicate a network error occurred
        while processing the request
    """

    replyReceived = pyqtSignal(str, str)
    modelsList = pyqtSignal(list)
    detailedModelsList = pyqtSignal(list)
    runningModelsList = pyqtSignal(list)
    pullStatus = pyqtSignal(str, str, int, int)
    finished = pyqtSignal()
    errorOccurred = pyqtSignal(str)

    def __init__(self, plugin, parent=None):
        """
        Constructor

        @param plugin reference to the plugin object
        @type PluginOllamaInterface
        @param parent reference to the parent object (defaults to None)
        @type QObject (optional)
        """
        super().__init__(parent=parent)

        self.__plugin = plugin
        self.__replies = []
        # received data per reply that is not yet terminated by a newline
        self.__buffers = {}

        self.__networkManager = QNetworkAccessManager(self)
        self.__networkManager.proxyAuthenticationRequired.connect(
            proxyAuthenticationRequired
        )

        self.__state = OllamaClientState.Waiting

    def chat(self, model, messages):
        """
        Public method to request a chat completion from the 'ollama' server.

        The reply is reported via the 'replyReceived' signal.

        @param model name of the model to be used
        @type str
        @param messages list of message objects
        @type list of dict
        """
        # TODO: not implemented yet
        ollamaRequest = {
            "model": model,
            "messages": messages,
        }
        self.__sendRequest(
            "chat", data=ollamaRequest, processResponse=self.__processChatResponse
        )

    def __processChatResponse(self, response):
        """
        Private method to process the chat response of the 'ollama' server.

        @param response dictionary containing the chat response
        @type dict
        """
        with contextlib.suppress(KeyError):
            message = response["message"]
            if message:
                self.replyReceived.emit(message["content"], message["role"])

    def generate(self, model, prompt, suffix=None):
        """
        Public method to request to generate a completion from the 'ollama' server.

        The reply is reported via the 'replyReceived' signal.

        @param model name of the model to be used
        @type str
        @param prompt prompt to generate a response for
        @type str
        @param suffix text after the model response (defaults to None)
        @type str (optional)
        """
        # TODO: not implemented yet
        ollamaRequest = {
            "model": model,
            "prompt": prompt,
        }
        if suffix is not None:
            ollamaRequest["suffix"] = suffix
        self.__sendRequest(
            "generate",
            data=ollamaRequest,
            processResponse=self.__processGenerateResponse,
        )

    def __processGenerateResponse(self, response):
        """
        Private method to process the generate response of the 'ollama' server.

        @param response dictionary containing the generate response
        @type dict
        """
        with contextlib.suppress(KeyError):
            self.replyReceived.emit(response["response"], "")

    def pull(self, model):
        """
        Public method to ask the 'ollama' server to pull the given model.

        The progress is reported via the 'pullStatus' signal.

        @param model name of the model
        @type str
        """
        # TODO: not implemented yet
        ollamaRequest = {
            "name": model,
        }
        self.__sendRequest(
            "pull", data=ollamaRequest, processResponse=self.__processPullResponse
        )

    def __processPullResponse(self, response):
        """
        Private method to process a pull response of the 'ollama' server.

        @param response dictionary containing the pull response
        @type dict
        """
        with contextlib.suppress(KeyError):
            status = response["status"]
            idStr = response.get("digest", "")[:20]
            total = response.get("total", 0)
            completed = response.get("completed", 0)
            self.pullStatus.emit(status, idStr, total, completed)

    def remove(self, model):
        """
        Public method to ask the 'ollama' server to delete the given model.

        @param model name of the model
        @type str
        """
        # TODO: not implemented yet
        ollamaRequest = {
            "name": model,
        }
        # the 'ollama' delete endpoint only accepts the HTTP DELETE method
        self.__sendRequest("delete", data=ollamaRequest, method=b"DELETE")

    def list(self):
        """
        Public method to request a list of models available locally from the 'ollama'
        server.

        The model names are reported via the 'modelsList' signal.
        """
        # TODO: not implemented yet
        self.__sendRequest("tags", processResponse=self.__processModelsList)

    def __processModelsList(self, response):
        """
        Private method to process the tags response of the 'ollama' server.

        @param response dictionary containing the tags response
        @type dict
        """
        models = []
        with contextlib.suppress(KeyError):
            for model in response["models"]:
                name = model["name"]
                if name:
                    models.append(name)
        self.modelsList.emit(models)

    def listDetails(self):
        """
        Public method to request a list of models available locally from the 'ollama'
        server with some model details.

        The result is not returned but reported asynchronously via the
        'detailedModelsList' signal.
        """
        # TODO: not implemented yet
        self.__sendRequest("tags", processResponse=self.__processDetailedModelsList)

    def __processDetailedModelsList(self, response):
        """
        Private method to process the tags response of the 'ollama' server extracting
        some model details.

        @param response dictionary containing the tags response
        @type dict
        """
        models = []
        with contextlib.suppress(KeyError):
            for model in response["models"]:
                name = model["name"]
                if name:
                    models.append(
                        {
                            "name": name,
                            "id": model["digest"][:20],  # first 20 characters only
                            "size": model["size"],
                            "modified": self.__parseTimestamp(model["modified_at"]),
                        }
                    )
        self.detailedModelsList.emit(models)

    def listRunning(self):
        """
        Public method to request a list of running models from the 'ollama' server.

        The result is reported via the 'runningModelsList' signal.
        """
        # TODO: not implemented yet
        self.__sendRequest("ps", processResponse=self.__processRunningModelsList)

    def __processRunningModelsList(self, response):
        """
        Private method to process the ps response of the 'ollama' server extracting
        some model execution details.

        @param response dictionary containing the ps response
        @type dict
        """
        models = []
        with contextlib.suppress(KeyError):
            for model in response["models"]:
                name = model["name"]
                if name:
                    size = model["size"]
                    sizeVram = model["size_vram"]
                    if sizeVram == 0:
                        processor = self.tr("100% CPU")
                    elif sizeVram == size:
                        processor = self.tr("100% GPU")
                    elif sizeVram > size or size == 0:
                        processor = self.tr("unknown")
                    else:
                        # the part not held in VRAM is processed by the CPU; the
                        # percentage is relative to the total model size
                        sizeCpu = size - sizeVram
                        cpuPercent = round(sizeCpu / size * 100)
                        processor = self.tr("{0}% / {1}% CPU / GPU").format(
                            cpuPercent, 100 - cpuPercent
                        )
                    models.append(
                        {
                            "name": name,
                            "id": model["digest"][:20],  # first 20 characters only
                            "size": size,
                            "size_vram": sizeVram,
                            "processor": processor,
                            "expires": self.__parseTimestamp(model["expires_at"]),
                        }
                    )
        self.runningModelsList.emit(models)

    @staticmethod
    def __parseTimestamp(timestamp):
        """
        Private method to convert an ISO 8601 timestamp sent by the 'ollama' server
        into a datetime object.

        Before Python 3.11 'datetime.fromisoformat()' accepts neither a 'Z' UTC
        designator nor fractional seconds with other than 3 or 6 digits. Such
        timestamps are normalized before being converted again.

        @param timestamp timestamp string to be converted
        @type str
        @return converted timestamp
        @rtype datetime.datetime
        @exception ValueError raised to indicate an unparsable timestamp
        """
        try:
            return datetime.datetime.fromisoformat(timestamp)
        except ValueError:
            import re

            normalized = re.sub(
                r"\.(\d+)",
                lambda match: "." + match.group(1)[:6].ljust(6, "0"),
                timestamp.replace("Z", "+00:00"),
                count=1,
            )
            return datetime.datetime.fromisoformat(normalized)

    def state(self):
        """
        Public method to get the current client state.

        @return current client state
        @rtype OllamaClientState
        """
        return self.__state

    def __sendRequest(self, endpoint, data=None, processResponse=None, method=None):
        """
        Private method to send a request to the 'ollama' server and handle its
        responses.

        @param endpoint 'ollama' API endpoint to be contacted
        @type str
        @param data dictionary containing the data to send to the server
            (defaults to None)
        @type dict (optional)
        @param processResponse function handling the received data (defaults to None)
        @type function (optional)
        @param method HTTP method to be used instead of the default one (POST if
            data is given, GET otherwise) (defaults to None)
        @type bytes (optional)
        """
        self.__state = OllamaClientState.Requesting

        ollamaUrl = QUrl(
            "{0}://{1}:{2}/api/{3}".format(
                self.__plugin.getPreferences("OllamaScheme"),
                self.__plugin.getPreferences("OllamaHost"),
                self.__plugin.getPreferences("OllamaPort"),
                endpoint,
            )
        )
        request = QNetworkRequest(ollamaUrl)
        if data is not None:
            request.setHeader(
                QNetworkRequest.KnownHeaders.ContentTypeHeader, "application/json"
            )
            jsonData = json.dumps(data).encode("utf-8")
        else:
            jsonData = b""

        # PyQt accepts keyword arguments only for optional parameters, therefore
        # the arguments are passed positionally.
        if method is not None:
            reply = self.__networkManager.sendCustomRequest(request, method, jsonData)
        elif data is not None:
            reply = self.__networkManager.post(request, jsonData)
        else:
            reply = self.__networkManager.get(request)

        reply.finished.connect(lambda: self.__replyFinished(reply, processResponse))
        reply.errorOccurred.connect(lambda error: self.__errorOccurred(error, reply))
        reply.readyRead.connect(lambda: self.__processData(reply, processResponse))
        self.__replies.append(reply)

    def __replyFinished(self, reply, processResponse=None):
        """
        Private method to handle the finished signal of the reply.

        @param reply reference to the finished network reply object
        @type QNetworkReply
        @param processResponse processing function for data still pending
            (defaults to None)
        @type function (optional)
        """
        # process data not terminated by a newline (e.g. a non-streamed response)
        remainder = self.__buffers.pop(reply, b"") + bytes(reply.readAll())
        self.__processJson(remainder, processResponse)

        self.__state = OllamaClientState.Finished

        if reply in self.__replies:
            self.__replies.remove(reply)

        reply.deleteLater()

        self.finished.emit()

    def __errorOccurred(self, errorCode, reply):
        """
        Private method to handle a network error of the given reply.

        @param errorCode error code reported by the reply
        @type QNetworkReply.NetworkError
        @param reply reference to the network reply object
        @type QNetworkReply
        """
        if errorCode != QNetworkReply.NetworkError.NoError:
            self.errorOccurred.emit(
                self.tr("<p>A network error occurred.</p><p>Error: {0}</p>").format(
                    reply.errorString()
                )
            )

    def __processData(self, reply, processResponse):
        """
        Private method to receive data from the 'ollama' server and process it with a
        given processing function or method.

        Streamed responses consist of newline separated JSON objects. A received
        chunk may contain several of them and/or an incomplete one. Therefore only
        complete lines are processed and the rest is kept until more data arrives
        or the reply is finished.

        @param reply reference to the network reply object
        @type QNetworkReply
        @param processResponse processing function
        @type function
        """
        self.__state = OllamaClientState.Receiving

        buffer = self.__buffers.pop(reply, b"") + bytes(reply.readAll())
        *lines, remainder = buffer.split(b"\n")
        if remainder:
            self.__buffers[reply] = remainder

        for line in lines:
            self.__processJson(line, processResponse)

    def __processJson(self, data, processResponse):
        """
        Private method to decode a single JSON document and pass it to the given
        processing function.

        Empty, non-JSON or undecodable data is ignored.

        @param data raw data of one JSON document
        @type bytes
        @param processResponse processing function
        @type function
        """
        data = data.strip()
        if not data or not processResponse:
            return

        try:
            response = json.loads(data)
        except ValueError:
            # covers json.JSONDecodeError and invalid UTF-8 data
            return

        if response:
            processResponse(response)

eric ide

mercurial