Sat, 07 Sep 2024 18:44:47 +0200
Fixed a bug introduced by the latest change.
# -*- coding: utf-8 -*-

# Copyright (c) 2024 Detlev Offenbach <detlev@die-offenbachs.de>
#

"""
Module implementing the 'ollama' client.
"""

import contextlib
import datetime
import enum
import json

from PyQt6.QtCore import (
    QCoreApplication,
    QObject,
    QThread,
    QTimer,
    QUrl,
    pyqtSignal,
    pyqtSlot,
)
from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkReply, QNetworkRequest

from eric7.EricNetwork.EricNetworkProxyFactory import proxyAuthenticationRequired


class OllamaClientState(enum.Enum):
    """
    Class defining the various client states.
    """

    Waiting = 0
    Requesting = 1
    Receiving = 2
    Finished = 3


class OllamaClient(QObject):
    """
    Class implementing the 'ollama' client.

    @signal replyReceived(content:str, role:str, done:bool) emitted after a response
        from the 'ollama' server was received
    @signal modelsList(modelNames:list[str]) emitted after the list of model
        names was obtained from the 'ollama' server
    @signal pullStatus(msg:str, id:str, total:int, completed:int) emitted to indicate
        the status of a pull request as reported by the 'ollama' server
    @signal pullError(msg:str) emitted to indicate an error during a pull operation
    @signal serverVersion(version:str) emitted after the server version was obtained
        from the 'ollama' server
    @signal finished() emitted to indicate the completion of a request
    @signal errorOccurred(error:str) emitted to indicate a network error occurred
        while processing the request
    @signal serverStateChanged(ok:bool, msg:str) emitted to indicate a change of the
        server responsiveness
    """

    replyReceived = pyqtSignal(str, str, bool)
    modelsList = pyqtSignal(list)
    # pull sizes can exceed the 32-bit signed int range of a plain 'int' signal
    pullStatus = pyqtSignal(str, str, "unsigned long int", "unsigned long int")
    pullError = pyqtSignal(str)
    serverVersion = pyqtSignal(str)
    finished = pyqtSignal()
    errorOccurred = pyqtSignal(str)
    serverStateChanged = pyqtSignal(bool, str)

    def __init__(self, plugin, parent=None):
        """
        Constructor

        @param plugin reference to the plugin object
        @type PluginOllamaInterface
        @param parent reference to the parent object (defaults to None)
        @type QObject (optional)
        """
        super().__init__(parent=parent)

        self.__plugin = plugin

        self.__replies = []
        self.__pullReply = None

        self.__networkManager = QNetworkAccessManager(self)
        self.__networkManager.proxyAuthenticationRequired.connect(
            proxyAuthenticationRequired
        )

        self.__serverResponding = None  # start with an unknown state

        self.__heartbeatTimer = QTimer(self)
        self.__heartbeatTimer.timeout.connect(self.__periodicHeartbeat)

        self.__state = OllamaClientState.Waiting
        self.__localServer = False

        self.__plugin.preferencesChanged.connect(self.__setHeartbeatTimer)
        # defer initial timer setup until the event loop is running
        QTimer.singleShot(0, self.__setHeartbeatTimer)

    def setMode(self, local):
        """
        Public method to set the client mode to local.

        @param local flag indicating to connect to a locally started ollama server
        @type bool
        """
        self.__localServer = local
        self.__serverResponding = None

        if not self.__plugin.getPreferences("OllamaHeartbeatInterval"):
            # schedule one heartbeat check giving local server some time to start
            QTimer.singleShot(
                10 * 1000 if self.__localServer else 0,
                self.__periodicHeartbeat,
            )

    def chat(self, model, messages, streaming=True):
        """
        Public method to request a chat completion from the 'ollama' server.

        @param model name of the model to be used
        @type str
        @param messages list of message objects
        @type list of dict
        @param streaming flag indicating to receive a streaming response
        @type bool
        """
        ollamaRequest = {
            "model": model,
            "messages": messages,
            "stream": streaming,
        }
        self.__sendRequest(
            "chat", data=ollamaRequest, processResponse=self.__processChatResponse
        )

    def __processChatResponse(self, response):
        """
        Private method to process the chat response of the 'ollama' server.

        @param response dictionary containing the chat response
        @type dict
        """
        with contextlib.suppress(KeyError):
            message = response["message"]
            done = response["done"]
            if message:
                self.replyReceived.emit(message["content"], message["role"], done)

    def generate(self, model, prompt, suffix=None):
        """
        Public method to request to generate a completion from the 'ollama' server.

        @param model name of the model to be used
        @type str
        @param prompt prompt to generate a response for
        @type str
        @param suffix text after the model response (defaults to None)
        @type str (optional)
        """
        ollamaRequest = {
            "model": model,
            "prompt": prompt,
        }
        if suffix is not None:
            ollamaRequest["suffix"] = suffix
        self.__sendRequest(
            "generate",
            data=ollamaRequest,
            processResponse=self.__processGenerateResponse,
        )

    def __processGenerateResponse(self, response):
        """
        Private method to process the generate response of the 'ollama' server.

        @param response dictionary containing the generate response
        @type dict
        """
        with contextlib.suppress(KeyError):
            self.replyReceived.emit(response["response"], "", response["done"])

    def pull(self, model):
        """
        Public method to ask the 'ollama' server to pull the given model.

        @param model name of the model
        @type str
        """
        ollamaRequest = {
            "model": model,
        }
        self.__sendRequest(
            "pull", data=ollamaRequest, processResponse=self.__processPullResponse
        )

    def __processPullResponse(self, response):
        """
        Private method to process a pull response of the 'ollama' server.

        @param response dictionary containing the pull response
        @type dict
        """
        if "error" in response:
            self.pullError.emit(response["error"])
        else:
            with contextlib.suppress(KeyError):
                status = response["status"]
                idStr = response.get("digest", "")[:20]
                total = response.get("total", 0)
                completed = response.get("completed", 0)
                self.pullStatus.emit(status, idStr, total, completed)

    def abortPull(self):
        """
        Public method to abort an ongoing pull operation.
        """
        if self.__pullReply is not None:
            self.__pullReply.close()

    def remove(self, model):
        """
        Public method to ask the 'ollama' server to delete the given model.

        @param model name of the model
        @type str
        @return flag indicating success
        @rtype bool
        """
        ollamaRequest = {
            "model": model,
        }
        _, status = self.__sendSyncRequest("delete", data=ollamaRequest, delete=True)
        return status == 200  # HTTP status 200 OK

    def list(self):
        """
        Public method to request a list of models available locally from the
        'ollama' server.
        """
        self.__sendRequest("tags", processResponse=self.__processModelsList)

    def __processModelsList(self, response):
        """
        Private method to process the tags response of the 'ollama' server.

        @param response dictionary containing the tags response
        @type dict
        """
        models = []
        with contextlib.suppress(KeyError):
            for model in response["models"]:
                name = model["name"]
                if name:
                    models.append(name)
        self.modelsList.emit(models)

    def listDetails(self):
        """
        Public method to request a list of models available locally from the
        'ollama' server with some model details.

        @return list of dictionaries containing the available models and related data
        @rtype list[dict[str, Any]]
        """
        response, _ = self.__sendSyncRequest("tags")

        models = []
        if response is not None:
            with contextlib.suppress(KeyError):
                for model in response["models"]:
                    name = model["name"]
                    if name:
                        models.append(
                            {
                                "name": name,
                                "id": model["digest"][:20],  # first 20 characters only
                                "size": model["size"],
                                "modified": datetime.datetime.fromisoformat(
                                    model["modified_at"]
                                ),
                            }
                        )
        return models

    def listRunning(self):
        """
        Public method to request a list of running models from the 'ollama' server.

        @return list of dictionaries containing the running models and related data
        @rtype list[dict[str, Any]]
        """
        response, _ = self.__sendSyncRequest("ps")

        models = []
        if response is not None:
            with contextlib.suppress(KeyError):
                for model in response["models"]:
                    name = model["name"]
                    if name:
                        if model["size_vram"] == 0:
                            processor = self.tr("100% CPU")
                        elif model["size_vram"] == model["size"]:
                            processor = self.tr("100% GPU")
                        elif model["size_vram"] > model["size"] or model["size"] == 0:
                            processor = self.tr("unknown")
                        else:
                            sizeCpu = model["size"] - model["size_vram"]
                            # CPU share must be a fraction of the total model size
                            # so that cpuPercent and (100 - cpuPercent) sum to 100;
                            # dividing by "size_vram" (former code) could exceed
                            # 100% and misreport the split. "size" is non-zero in
                            # this branch (guarded by the elif above).
                            cpuPercent = round(sizeCpu / model["size"] * 100)
                            processor = self.tr("{0}% / {1}% CPU / GPU").format(
                                cpuPercent, 100 - cpuPercent
                            )
                        models.append(
                            {
                                "name": name,
                                "id": model["digest"][:20],  # first 20 characters only
                                "size": model["size"],
                                "size_vram": model["size_vram"],
                                "processor": processor,
                                "expires": datetime.datetime.fromisoformat(
                                    model["expires_at"]
                                ),
                            }
                        )
        return models

    def version(self):
        """
        Public method to request the version from the 'ollama' server.
        """
        self.__sendRequest("version", processResponse=self.__processVersion)

    def __processVersion(self, response):
        """
        Private method to process the version response of the 'ollama' server.

        @param response dictionary containing the version response
        @type dict
        """
        with contextlib.suppress(KeyError):
            self.serverVersion.emit(response["version"])
            if (
                self.__plugin.getPreferences("OllamaHeartbeatInterval") == 0
                and not self.__serverResponding
            ):
                # implicit connectivity check success
                self.__serverResponding = True
                self.serverStateChanged.emit(True, "")

    def state(self):
        """
        Public method to get the current client state.

        @return current client state
        @rtype OllamaClientState
        """
        return self.__state

    def __getServerReply(self, endpoint, data=None, delete=False):
        """
        Private method to send a request to the 'ollama' server and return a reply
        object.

        @param endpoint 'ollama' API endpoint to be contacted
        @type str
        @param data dictionary containing the data to send to the server
            (defaults to None)
        @type dict (optional)
        @param delete flag indicating to send a delete request (defaults to False)
        @type bool (optional)
        @return 'ollama' server reply
        @rtype QNetworkReply
        """
        ollamaUrl = (
            QUrl(
                "http://127.0.0.1:{0}/api/{1}".format(
                    self.__plugin.getPreferences("OllamaLocalPort"),
                    endpoint,
                )
            )
            if self.__localServer
            else QUrl(
                "{0}://{1}:{2}/api/{3}".format(
                    self.__plugin.getPreferences("OllamaScheme"),
                    self.__plugin.getPreferences("OllamaHost"),
                    self.__plugin.getPreferences("OllamaPort"),
                    endpoint,
                )
            )
        )
        request = QNetworkRequest(ollamaUrl)
        if data is not None:
            request.setHeader(
                QNetworkRequest.KnownHeaders.ContentTypeHeader, "application/json"
            )
            jsonData = json.dumps(data).encode("utf-8")
            if delete:
                # QNetworkAccessManager has no body-carrying delete(); use a
                # custom verb request instead
                reply = self.__networkManager.sendCustomRequest(
                    request, b"DELETE", jsonData
                )
            else:
                reply = self.__networkManager.post(request, jsonData)
        else:
            reply = self.__networkManager.get(request)
        reply.errorOccurred.connect(lambda error: self.__errorOccurred(error, reply))
        return reply

    def __sendRequest(self, endpoint, data=None, processResponse=None):
        """
        Private method to send a request to the 'ollama' server and handle its
        responses.

        @param endpoint 'ollama' API endpoint to be contacted
        @type str
        @param data dictionary containing the data to send to the server
            (defaults to None)
        @type dict (optional)
        @param processResponse function handling the received data (defaults to None)
        @type function (optional)
        """
        self.__state = OllamaClientState.Requesting

        reply = self.__getServerReply(endpoint=endpoint, data=data)
        reply.finished.connect(lambda: self.__replyFinished(reply))
        reply.readyRead.connect(lambda: self.__processData(reply, processResponse))
        if endpoint == "pull":
            # remember the pull reply separately so abortPull() can close it
            self.__pullReply = reply
        else:
            self.__replies.append(reply)

    def __replyFinished(self, reply):
        """
        Private method to handle the finished signal of the reply.

        @param reply reference to the finished network reply object
        @type QNetworkReply
        """
        self.__state = OllamaClientState.Finished

        if reply == self.__pullReply:
            self.__pullReply = None
        elif reply in self.__replies:
            self.__replies.remove(reply)

        reply.deleteLater()

        self.finished.emit()

    def __errorOccurred(self, errorCode, reply):
        """
        Private method to handle a network error of the given reply.

        @param errorCode error code reported by the reply
        @type QNetworkReply.NetworkError
        @param reply reference to the network reply object
        @type QNetworkReply
        """
        if errorCode == QNetworkReply.NetworkError.ConnectionRefusedError:
            self.serverStateChanged.emit(False, self.__serverNotRespondingMessage())
        elif errorCode not in (
            QNetworkReply.NetworkError.NoError,
            QNetworkReply.NetworkError.OperationCanceledError,
        ):
            self.errorOccurred.emit(
                self.tr("<p>A network error occurred.</p><p>Error: {0}</p>").format(
                    reply.errorString()
                )
            )

    def __processData(self, reply, processResponse):
        """
        Private method to receive data from the 'ollama' server and process it with
        a given processing function or method.

        @param reply reference to the network reply object
        @type QNetworkReply
        @param processResponse processing function
        @type function
        """
        self.__state = OllamaClientState.Receiving

        buffer = bytes(reply.readAll())
        if buffer:
            # streamed chunks are assumed to each hold one JSON document;
            # malformed chunks are silently skipped
            with contextlib.suppress(json.JSONDecodeError):
                data = json.loads(buffer)
                if data and processResponse:
                    processResponse(data)

    def __sendSyncRequest(self, endpoint, data=None, delete=False):
        """
        Private method to send a request to the 'ollama' server and handle its
        responses.

        @param endpoint 'ollama' API endpoint to be contacted
        @type str
        @param data dictionary containing the data to send to the server
            (defaults to None)
        @type dict (optional)
        @param delete flag indicating to send a delete request (defaults to False)
        @type bool (optional)
        @return tuple containing the data sent by the 'ollama' server and the HTTP
            status code
        @rtype tuple of (Any, int)
        """
        self.__state = OllamaClientState.Requesting

        reply = self.__getServerReply(endpoint=endpoint, data=data, delete=delete)
        while not reply.isFinished():
            QCoreApplication.processEvents()
            QThread.msleep(100)

        reply.deleteLater()

        self.__state = OllamaClientState.Finished

        statusCode = reply.attribute(QNetworkRequest.Attribute.HttpStatusCodeAttribute)
        if reply.error() == QNetworkReply.NetworkError.NoError:
            buffer = bytes(reply.readAll())
            # decode into a fresh variable; the former code reused the 'data'
            # parameter, returning the *request* payload if decoding failed
            responseData = None
            with contextlib.suppress(json.JSONDecodeError):
                responseData = json.loads(buffer)
            return responseData, statusCode

        return None, statusCode

    def __getHeartbeatUrl(self):
        """
        Private method to get the current heartbeat URL.

        @return URL to be contacted by the heartbeat check
        @rtype str
        """
        return (
            "http://127.0.0.1:{0}".format(
                self.__plugin.getPreferences("OllamaLocalPort"),
            )
            if self.__localServer
            else "{0}://{1}:{2}/".format(
                self.__plugin.getPreferences("OllamaScheme"),
                self.__plugin.getPreferences("OllamaHost"),
                self.__plugin.getPreferences("OllamaPort"),
            )
        )

    def heartbeat(self):
        """
        Public method to check, if the 'ollama' server has started and is responsive.

        @return flag indicating a responsive 'ollama' server
        @rtype bool
        """
        request = QNetworkRequest(QUrl(self.__getHeartbeatUrl()))
        reply = self.__networkManager.head(request)
        while not reply.isFinished():
            QCoreApplication.processEvents()
            QThread.msleep(100)

        reply.deleteLater()

        return reply.error() == QNetworkReply.NetworkError.NoError

    @pyqtSlot()
    def __setHeartbeatTimer(self):
        """
        Private slot to configure the heartbeat timer.
        """
        interval = self.__plugin.getPreferences("OllamaHeartbeatInterval")
        if interval:
            self.__heartbeatTimer.setInterval(interval * 1000)  # interval in ms
            self.__heartbeatTimer.start()
        else:
            self.__heartbeatTimer.stop()
            self.serverStateChanged.emit(True, "")

    @pyqtSlot()
    def __periodicHeartbeat(self):
        """
        Private slot to do a periodic check of the 'ollama' server responsiveness.
        """
        responding = self.heartbeat()
        if responding != self.__serverResponding:
            msg = "" if responding else self.__serverNotRespondingMessage()
            self.serverStateChanged.emit(responding, msg)
        self.__serverResponding = responding

    def __serverNotRespondingMessage(self):
        """
        Private method to assemble and return a message for a non-responsive server.

        @return error message
        @rtype str
        """
        return (
            self.tr("<p>Error: The local server at <b>{0}</b> is not responding.</p>")
            if self.__localServer
            else self.tr(
                "<p>Error: The configured server at <b>{0}</b> is not"
                " responding.</p>"
            )
        ).format(self.__getHeartbeatUrl())