--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/OllamaInterface/OllamaClient.py Sun Aug 04 16:57:01 2024 +0200
@@ -0,0 +1,385 @@
+# -*- coding: utf-8 -*-
+
+# Copyright (c) 2024 Detlev Offenbach <detlev@die-offenbachs.de>
+#
+
+"""
+Module implementing the 'ollama' client.
+"""
+
+import contextlib
+import datetime
+import enum
+import json
+
+from PyQt6.QtCore import pyqtSignal, QObject, QUrl
+from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkRequest, QNetworkReply
+
+from eric7.EricNetwork.EricNetworkProxyFactory import proxyAuthenticationRequired
+
+
+class OllamaClientState(enum.Enum):
+    """
+    Class defining the various client states.
+    """
+    Waiting = 0
+    Requesting = 1
+    Receiving = 2
+    Finished = 3
+
+
+class OllamaClient(QObject):
+    """
+    Class implementing the 'ollama' client.
+
+    @signal replyReceived(content:str, role:str) emitted after a response from the
+        'ollama' server was received
+    @signal modelsList(modelNames:list[str]) emitted after the list of model
+        names was obtained from the 'ollama' server
+    @signal detailedModelsList(models:list[dict]) emitted after the list of
+        models was obtained from the 'ollama' server giving some model details
+    @signal runningModelsList(models:list[dict]) emitted after the list of
+        running models was obtained from the 'ollama' server giving some model
+        execution details
+    @signal pullStatus(msg:str, id:str, total:int, completed:int) emitted to indicate
+        the status of a pull request as reported by the 'ollama' server
+    @signal finished() emitted to indicate the completion of a request
+    @signal errorOccurred(error:str) emitted to indicate a network error occurred
+        while processing the request
+    """
+
+    replyReceived = pyqtSignal(str, str)
+    modelsList = pyqtSignal(list)
+    detailedModelsList = pyqtSignal(list)
+    runningModelsList = pyqtSignal(list)
+    pullStatus = pyqtSignal(str, str, int, int)
+    finished = pyqtSignal()
+    errorOccurred = pyqtSignal(str)
+
+    def __init__(self, plugin, parent=None):
+        """
+        Constructor
+
+        @param plugin reference to the plugin object
+        @type PluginOllamaInterface
+        @param parent reference to the parent object (defaults to None)
+        @type QObject (optional)
+        """
+        super().__init__(parent=parent)
+
+        self.__plugin = plugin
+        self.__replies = []
+
+        self.__networkManager = QNetworkAccessManager(self)
+        self.__networkManager.proxyAuthenticationRequired.connect(
+            proxyAuthenticationRequired
+        )
+
+        self.__state = OllamaClientState.Waiting
+
+    def chat(self, model, messages):
+        """
+        Public method to request a chat completion from the 'ollama' server.
+
+        @param model name of the model to be used
+        @type str
+        @param messages list of message objects
+        @type list of dict
+        """
+        # TODO: not implemented yet
+        ollamaRequest = {
+            "model": model,
+            "messages": messages,
+        }
+        self.__sendRequest(
+            "chat", data=ollamaRequest, processResponse=self.__processChatResponse
+        )
+
+    def __processChatResponse(self, response):
+        """
+        Private method to process the chat response of the 'ollama' server.
+
+        @param response dictionary containing the chat response
+        @type dict
+        """
+        with contextlib.suppress(KeyError):
+            message = response["message"]
+            if message:
+                self.replyReceived.emit(message["content"], message["role"])
+
+    def generate(self, model, prompt, suffix=None):
+        """
+        Public method to request to generate a completion from the 'ollama' server.
+
+        @param model name of the model to be used
+        @type str
+        @param prompt prompt to generate a response for
+        @type str
+        @param suffix text after the model response (defaults to None)
+        @type str (optional)
+        """
+        # TODO: not implemented yet
+        ollamaRequest = {
+            "model": model,
+            "prompt": prompt,
+        }
+        if suffix is not None:
+            ollamaRequest["suffix"] = suffix
+        self.__sendRequest(
+            "generate",
+            data=ollamaRequest,
+            processResponse=self.__processGenerateResponse,
+        )
+
+    def __processGenerateResponse(self, response):
+        """
+        Private method to process the generate response of the 'ollama' server.
+
+        @param response dictionary containing the generate response
+        @type dict
+        """
+        with contextlib.suppress(KeyError):
+            self.replyReceived.emit(response["response"], "")
+
+    def pull(self, model):
+        """
+        Public method to ask the 'ollama' server to pull the given model.
+
+        @param model name of the model
+        @type str
+        """
+        # TODO: not implemented yet
+        ollamaRequest = {
+            "name": model,
+        }
+        self.__sendRequest(
+            "pull", data=ollamaRequest, processResponse=self.__processPullResponse
+        )
+
+    def __processPullResponse(self, response):
+        """
+        Private method to process a pull response of the 'ollama' server.
+
+        @param response dictionary containing the pull response
+        @type dict
+        """
+        with contextlib.suppress(KeyError):
+            status = response["status"]
+            idStr = response.get("digest", "")[:20]
+            total = response.get("total", 0)
+            completed = response.get("completed", 0)
+            self.pullStatus.emit(status, idStr, total, completed)
+
+    def remove(self, model):
+        """
+        Public method to ask the 'ollama' server to delete the given model.
+
+        @param model name of the model
+        @type str
+        """
+        # TODO: not implemented yet
+        ollamaRequest = {
+            "name": model,
+        }
+        self.__sendRequest("delete", data=ollamaRequest)
+
+    def list(self):
+        """
+        Public method to request a list of models available locally from the 'ollama'
+        server.
+        """
+        # TODO: not implemented yet
+        self.__sendRequest("tags", processResponse=self.__processModelsList)
+
+    def __processModelsList(self, response):
+        """
+        Private method to process the tags response of the 'ollama' server.
+
+        @param response dictionary containing the tags response
+        @type dict
+        """
+        models = []
+        with contextlib.suppress(KeyError):
+            for model in response["models"]:
+                name = model["name"]
+                if name:
+                    models.append(name)
+        self.modelsList.emit(models)
+
+    def listDetails(self):
+        """
+        Public method to request a list of models available locally from the 'ollama'
+        server with some model details.
+
+        The list of models with their details is reported via the
+        detailedModelsList signal.
+        """
+        # TODO: not implemented yet
+        self.__sendRequest("tags", processResponse=self.__processDetailedModelsList)
+
+    def __processDetailedModelsList(self, response):
+        """
+        Private method to process the tags response of the 'ollama' server extracting
+        some model details.
+
+        @param response dictionary containing the tags response
+        @type dict
+        """
+        models = []
+        with contextlib.suppress(KeyError):
+            for model in response["models"]:
+                name = model["name"]
+                if name:
+                    models.append(
+                        {
+                            "name": name,
+                            "id": model["digest"][:20],  # first 20 characters only
+                            "size": model["size"],
+                            "modified": datetime.datetime.fromisoformat(
+                                model["modified_at"]
+                            ),
+                        }
+                    )
+        self.detailedModelsList.emit(models)
+
+    def listRunning(self):
+        """
+        Public method to request a list of running models from the 'ollama' server.
+        """
+        # TODO: not implemented yet
+        self.__sendRequest("ps", processResponse=self.__processRunningModelsList)
+
+    def __processRunningModelsList(self, response):
+        """
+        Private method to process the ps response of the 'ollama' server extracting
+        some model execution details.
+
+        @param response dictionary containing the ps response
+        @type dict
+        """
+        models = []
+        with contextlib.suppress(KeyError):
+            for model in response["models"]:
+                name = model["name"]
+                if name:
+                    if model["size_vram"] == 0:
+                        processor = self.tr("100% CPU")
+                    elif model["size_vram"] == model["size"]:
+                        processor = self.tr("100% GPU")
+                    elif model["size_vram"] > model["size"] or model["size"] == 0:
+                        processor = self.tr("unknown")
+                    else:
+                        sizeCpu = model["size"] - model["size_vram"]
+                        cpuPercent = round(sizeCpu / model["size"] * 100)
+                        processor = self.tr("{0}% / {1}% CPU / GPU").format(
+                            cpuPercent, 100 - cpuPercent
+                        )
+                    models.append(
+                        {
+                            "name": name,
+                            "id": model["digest"][:20],  # first 20 characters only
+                            "size": model["size"],
+                            "size_vram": model["size_vram"],
+                            "processor": processor,
+                            "expires": datetime.datetime.fromisoformat(
+                                model["expires_at"]
+                            ),
+                        }
+                    )
+        self.runningModelsList.emit(models)
+
+    def state(self):
+        """
+        Public method to get the current client state.
+
+        @return current client state
+        @rtype OllamaClientState
+        """
+        return self.__state
+
+    def __sendRequest(self, endpoint, data=None, processResponse=None):
+        """
+        Private method to send a request to the 'ollama' server and handle its
+        responses.
+
+        @param endpoint 'ollama' API endpoint to be contacted
+        @type str
+        @param data dictionary containing the data to send to the server
+            (defaults to None)
+        @type dict (optional)
+        @param processResponse function handling the received data (defaults to None)
+        @type function (optional)
+        """
+        self.__state = OllamaClientState.Requesting
+
+        ollamaUrl = QUrl(
+            "{0}://{1}:{2}/api/{3}".format(
+                self.__plugin.getPreferences("OllamaScheme"),
+                self.__plugin.getPreferences("OllamaHost"),
+                self.__plugin.getPreferences("OllamaPort"),
+                endpoint,
+            )
+        )
+        request = QNetworkRequest(ollamaUrl)
+        if data is not None:
+            request.setHeader(
+                QNetworkRequest.KnownHeaders.ContentTypeHeader, "application/json"
+            )
+            jsonData = json.dumps(data).encode("utf-8")
+            reply = self.__networkManager.post(request=request, data=jsonData)
+        else:
+            reply = self.__networkManager.get(request=request)
+
+        reply.finished.connect(lambda: self.__replyFinished(reply))
+        reply.errorOccurred.connect(lambda error: self.__errorOccurred(error, reply))
+        reply.readyRead.connect(lambda: self.__processData(reply, processResponse))
+        self.__replies.append(reply)
+
+    def __replyFinished(self, reply):
+        """
+        Private method to handle the finished signal of the reply.
+
+        @param reply reference to the finished network reply object
+        @type QNetworkReply
+        """
+        self.__state = OllamaClientState.Finished
+
+        if reply in self.__replies:
+            self.__replies.remove(reply)
+
+        reply.deleteLater()
+
+    def __errorOccurred(self, errorCode, reply):
+        """
+        Private method to handle a network error of the given reply.
+
+        @param errorCode error code reported by the reply
+        @type QNetworkReply.NetworkError
+        @param reply reference to the network reply object
+        @type QNetworkReply
+        """
+        if errorCode != QNetworkReply.NetworkError.NoError:
+            self.errorOccurred.emit(
+                self.tr("<p>A network error occurred.</p><p>Error: {0}</p>").format(
+                    reply.errorString()
+                )
+            )
+
+    def __processData(self, reply, processResponse):
+        """
+        Private method to receive data from the 'ollama' server and process it with a
+        given processing function or method.
+
+        @param reply reference to the network reply object
+        @type QNetworkReply
+        @param processResponse processing function
+        @type function
+        """
+        self.__state = OllamaClientState.Receiving
+
+        buffer = bytes(reply.readAll())
+        for line in buffer.splitlines():
+            with contextlib.suppress(json.JSONDecodeError):
+                data = json.loads(line)
+                if data and processResponse:
+                    processResponse(data)
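
For reference, a minimal usage sketch of the client added above; it is not part of the change. It assumes the module is importable as OllamaInterface.OllamaClient, that an 'ollama' server is listening on the default port 11434, and it uses a hypothetical DummyPlugin stand-in for the real PluginOllamaInterface object, providing only the getPreferences() values the client reads (OllamaScheme, OllamaHost, OllamaPort).

import sys

from PyQt6.QtCore import QCoreApplication

from OllamaInterface.OllamaClient import OllamaClient


class DummyPlugin:
    """Hypothetical stand-in for PluginOllamaInterface, used only for this sketch."""

    def getPreferences(self, key):
        return {"OllamaScheme": "http", "OllamaHost": "localhost", "OllamaPort": 11434}[key]


def printModels(names):
    # 'names' is the list of model names delivered via the modelsList signal
    print("\n".join(names))
    app.quit()


app = QCoreApplication(sys.argv)
client = OllamaClient(DummyPlugin())
client.modelsList.connect(printModels)
client.errorOccurred.connect(print)
client.list()  # GET <scheme>://<host>:<port>/api/tags
sys.exit(app.exec())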