OllamaInterface/OllamaClient.py

changeset 11
3641ea6b55d5
parent 9
c471738b75b3
child 13
3fd49d7004b2
--- a/OllamaInterface/OllamaClient.py	Tue Aug 27 09:19:39 2024 +0200
+++ b/OllamaInterface/OllamaClient.py	Tue Aug 27 14:06:50 2024 +0200
@@ -47,6 +47,7 @@
         names was obtained from the 'ollama' server
     @signal pullStatus(msg:str, id:str, total:int, completed:int) emitted to indicate
         the status of a pull request as reported by the 'ollama' server
+    @signal pullError(msg:str) emitted to indicate an error during a pull operation
     @signal serverVersion(version:str) emitted after the server version was obtained
         from the 'ollama' server
     @signal finished() emitted to indicate the completion of a request
@@ -58,7 +59,8 @@
 
     replyReceived = pyqtSignal(str, str, bool)
     modelsList = pyqtSignal(list)
-    pullStatus = pyqtSignal(str, str, int, int)
+    pullStatus = pyqtSignal(str, str, "unsigned long int", "unsigned long int")
+    pullError = pyqtSignal(str)
     serverVersion = pyqtSignal(str)
     finished = pyqtSignal()
     errorOccurred = pyqtSignal(str)
@@ -77,6 +79,7 @@
 
         self.__plugin = plugin
         self.__replies = []
+        self.__pullReply = None
 
         self.__networkManager = QNetworkAccessManager(self)
         self.__networkManager.proxyAuthenticationRequired.connect(
@@ -178,9 +181,8 @@
         @param model name of the model
         @type str
         """
-        # TODO: not implemented yet
         ollamaRequest = {
-            "name": model,
+            "model": model,
         }
         self.__sendRequest(
             "pull", data=ollamaRequest, processResponse=self.__processPullResponse
@@ -193,12 +195,22 @@
         @param response dictionary containing the pull response
         @type dict
         """
-        with contextlib.suppress(KeyError):
-            status = response["status"]
-            idStr = response.get("digest", "")[:20]
-            total = response.get("total", 0)
-            completed = response.get("completed", 0)
-            self.pullStatus.emit(status, idStr, total, completed)
+        if "error" in response:
+            self.pullError.emit(response["error"])
+        else:
+            with contextlib.suppress(KeyError):
+                status = response["status"]
+                idStr = response.get("digest", "")[:20]
+                total = response.get("total", 0)
+                completed = response.get("completed", 0)
+                self.pullStatus.emit(status, idStr, total, completed)
+
+    def abortPull(self):
+        """
+        Public method to abort an ongoing pull operation.
+        """
+        if self.__pullReply is not None:
+            self.__pullReply.close()
 
     def remove(self, model):
         """
@@ -399,7 +411,10 @@
         reply = self.__getServerReply(endpoint=endpoint, data=data)
         reply.finished.connect(lambda: self.__replyFinished(reply))
         reply.readyRead.connect(lambda: self.__processData(reply, processResponse))
-        self.__replies.append(reply)
+        if endpoint == "pull":
+            self.__pullReply = reply
+        else:
+            self.__replies.append(reply)
 
     def __replyFinished(self, reply):
         """
@@ -410,11 +425,15 @@
         """
         self.__state = OllamaClientState.Finished
 
-        if reply in self.__replies:
+        if reply == self.__pullReply:
+            self.__pullReply = None
+        elif reply in self.__replies:
             self.__replies.remove(reply)
 
         reply.deleteLater()
 
+        self.finished.emit()
+
     def __errorOccurred(self, errorCode, reply):
         """
         Private method to handle a network error of the given reply.
@@ -424,7 +443,10 @@
         @param reply reference to the network reply object
         @type QNetworkReply
         """
-        if errorCode != QNetworkReply.NetworkError.NoError:
+        if errorCode not in (
+            QNetworkReply.NetworkError.NoError,
+            QNetworkReply.NetworkError.OperationCanceledError,
+        ):
             self.errorOccurred.emit(
                 self.tr("<p>A network error occurred.</p><p>Error: {0}</p>").format(
                     reply.errorString()

eric ide

mercurial