10 import contextlib |
10 import contextlib |
11 import datetime |
11 import datetime |
12 import enum |
12 import enum |
13 import json |
13 import json |
14 |
14 |
15 from PyQt6.QtCore import pyqtSignal, QObject, QUrl |
15 from PyQt6.QtCore import ( |
16 from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkRequest, QNetworkReply |
16 QCoreApplication, |
|
17 QObject, |
|
18 QThread, |
|
19 QTimer, |
|
20 QUrl, |
|
21 pyqtSignal, |
|
22 pyqtSlot, |
|
23 ) |
|
24 from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkReply, QNetworkRequest |
17 |
25 |
18 from eric7.EricNetwork.EricNetworkProxyFactory import proxyAuthenticationRequired |
26 from eric7.EricNetwork.EricNetworkProxyFactory import proxyAuthenticationRequired |
19 |
27 |
20 |
28 |
class OllamaClientState(enum.Enum):
    """
    Class enumerating the states an 'ollama' client can be in.
    """

    # Members receive the integer values 0, 1, 2 and 3 in declaration order.
    Waiting, Requesting, Receiving, Finished = range(4)
29 |
38 |
30 |
39 |
31 class OllamaClient(QObject): |
40 class OllamaClient(QObject): |
32 """ |
41 """ |
33 Class implementing the 'ollama' client. |
42 Class implementing the 'ollama' client. |
34 |
43 |
35 @signal replyReceived(content:str, role:str) emitted after a response from the |
44 @signal replyReceived(content:str, role:str) emitted after a response from the |
36 'ollama' server was received |
45 'ollama' server was received |
37 @signal modelsList(modelNames:list[str]) emitted after the list of model |
46 @signal modelsList(modelNames:list[str]) emitted after the list of model |
38 names was obtained from the 'ollama' server |
47 names was obtained from the 'ollama' server |
39 @signal detailedModelsList(models:list[dict]) emitted after the list of |
48 @signal detailedModelsList(models:list[dict]) emitted after the list of |
40 models was obtained from the 'ollama' server giving some model details |
49 models was obtained from the 'ollama' server giving some model details |
41 @signal runningModelsList(models:list[dict]) emitted after the list of |
50 @signal runningModelsList(models:list[dict]) emitted after the list of |
42 running models was obtained from the 'ollama' server giving some model |
51 running models was obtained from the 'ollama' server giving some model |
43 execution details |
52 execution details |
44 @signal pullStatus(msg:str, id:str, total:int, completed:int) emitted to indicate |
53 @signal pullStatus(msg:str, id:str, total:int, completed:int) emitted to indicate |
45 the status of a pull request as reported by the 'ollama' server |
54 the status of a pull request as reported by the 'ollama' server |
|
55 @signal serverVersion(version:str) emitted after the server version was obtained |
|
56 from the 'ollama' server |
46 @signal finished() emitted to indicate the completion of a request |
57 @signal finished() emitted to indicate the completion of a request |
47 @signal errorOccurred(error:str) emitted to indicate a network error occurred |
58 @signal errorOccurred(error:str) emitted to indicate a network error occurred |
48 while processing the request |
59 while processing the request |
|
60 @signal serverStateChanged(ok:bool) emitted to indicate a change of the server |
|
61 responsiveness |
49 """ |
62 """ |
50 |
63 |
51 replyReceived = pyqtSignal(str, str) |
64 replyReceived = pyqtSignal(str, str) |
52 modelsList = pyqtSignal(list) |
65 modelsList = pyqtSignal(list) |
53 detailedModelsList = pyqtSignal(list) |
66 detailedModelsList = pyqtSignal(list) |
54 runningModelsList = pyqtSignal(list) |
67 runningModelsList = pyqtSignal(list) |
55 pullStatus = pyqtSignal(str, str, int, int) |
68 pullStatus = pyqtSignal(str, str, int, int) |
|
69 serverVersion = pyqtSignal(str) |
56 finished = pyqtSignal() |
70 finished = pyqtSignal() |
57 errorOccurred = pyqtSignal(str) |
71 errorOccurred = pyqtSignal(str) |
|
72 serverStateChanged = pyqtSignal(bool) |
58 |
73 |
59 def __init__(self, plugin, parent=None): |
74 def __init__(self, plugin, parent=None): |
60 """ |
75 """ |
61 Constructor |
76 Constructor |
62 |
77 |
73 self.__networkManager = QNetworkAccessManager(self) |
88 self.__networkManager = QNetworkAccessManager(self) |
74 self.__networkManager.proxyAuthenticationRequired.connect( |
89 self.__networkManager.proxyAuthenticationRequired.connect( |
75 proxyAuthenticationRequired |
90 proxyAuthenticationRequired |
76 ) |
91 ) |
77 |
92 |
|
93 self.__serverResponding = False |
|
94 self.__heartbeatTimer = QTimer(self) |
|
95 self.__heartbeatTimer.timeout.connect(self.__periodicHeartbeat) |
|
96 |
78 self.__state = OllamaClientState.Waiting |
97 self.__state = OllamaClientState.Waiting |
|
98 |
|
99 self.__serverResponding = False # start with a faulty state |
|
100 |
|
101 self.__plugin.preferencesChanged.connect(self.__setHeartbeatTimer) |
|
102 self.__setHeartbeatTimer() |
79 |
103 |
80 def chat(self, model, messages): |
104 def chat(self, model, messages): |
81 """ |
105 """ |
82 Public method to request a chat completion from the 'ollama' server. |
106 Public method to request a chat completion from the 'ollama' server. |
83 |
107 |
201 models = [] |
224 models = [] |
202 with contextlib.suppress(KeyError): |
225 with contextlib.suppress(KeyError): |
203 for model in response["models"]: |
226 for model in response["models"]: |
204 name = model["name"] |
227 name = model["name"] |
205 if name: |
228 if name: |
206 models.append(name) |
229 models.append(name.replace(":latest", "")) |
207 self.modelsList.emit(models) |
230 self.modelsList.emit(models) |
208 |
231 |
def listDetails(self):
    """
    Public method to ask the 'ollama' server for the locally available models
    together with some model details.

    The request goes to the 'tags' endpoint and is processed asynchronously;
    this method itself returns nothing. The decoded server response is handed
    to __processDetailedModelsList (presumably emitting detailedModelsList -
    confirm in that method).
    """
    # TODO: not implemented yet
    # NOTE(review): the request is already issued below; confirm what this
    # TODO still refers to.
    handler = self.__processDetailedModelsList
    self.__sendRequest("tags", processResponse=handler)
219 |
239 |
220 def __processDetailedModelsList(self, response): |
240 def __processDetailedModelsList(self, response): |
324 if data is not None: |
360 if data is not None: |
325 request.setHeader( |
361 request.setHeader( |
326 QNetworkRequest.KnownHeaders.ContentTypeHeader, "application/json" |
362 QNetworkRequest.KnownHeaders.ContentTypeHeader, "application/json" |
327 ) |
363 ) |
328 jsonData = json.dumps(data).encode("utf-8") |
364 jsonData = json.dumps(data).encode("utf-8") |
329 reply = self.__networkManager.post(request=request, data=jsonData) |
365 reply = self.__networkManager.post(request, jsonData) |
330 else: |
366 else: |
331 reply = self.__networkManager.get(request=request) |
367 reply = self.__networkManager.get(request) |
332 |
368 |
333 reply.finished.connect(lambda: self.__replyFinished(reply)) |
369 reply.finished.connect(lambda: self.__replyFinished(reply)) |
334 reply.errorOccurred.connect(lambda error: self.__errorOccurred(error, reply)) |
370 reply.errorOccurred.connect(lambda error: self.__errorOccurred(error, reply)) |
335 reply.readyRead.connect(lambda: self.__processData(reply, processResponse)) |
371 reply.readyRead.connect(lambda: self.__processData(reply, processResponse)) |
336 self.__replies.append(reply) |
372 self.__replies.append(reply) |
381 if buffer: |
417 if buffer: |
382 with contextlib.suppress(json.JSONDecodeError): |
418 with contextlib.suppress(json.JSONDecodeError): |
383 data = json.loads(buffer) |
419 data = json.loads(buffer) |
384 if data and processResponse: |
420 if data and processResponse: |
385 processResponse(data) |
421 processResponse(data) |
|
422 |
|
def heartbeat(self, timeout=5000):
    """
    Public method to check, if the 'ollama' server has started and is responsive.

    A HEAD request is sent to the server root URL built from the 'OllamaScheme',
    'OllamaHost' and 'OllamaPort' preferences. The method waits for the reply
    while keeping the event loop running.

    @param timeout maximum time in milliseconds without any data transfer
        before the request is aborted and the server is reported as not
        responsive; 0 or None waits without limit (defaults to 5000)
    @type int (optional)
    @return flag indicating a responsive 'ollama' server
    @rtype bool
    """
    ollamaUrl = QUrl(
        "{0}://{1}:{2}/".format(
            self.__plugin.getPreferences("OllamaScheme"),
            self.__plugin.getPreferences("OllamaHost"),
            self.__plugin.getPreferences("OllamaPort"),
        )
    )
    request = QNetworkRequest(ollamaUrl)
    if timeout:
        # Without a transfer timeout an unreachable host can keep the wait
        # loop below spinning for a long time (e.g. until the OS-level
        # connection timeout); an aborted reply finishes with an error.
        request.setTransferTimeout(timeout)
    reply = self.__networkManager.head(request)

    # Poll for completion while processing events so the GUI stays usable.
    # NOTE(review): each msleep() still blocks the calling thread for 100 ms.
    while not reply.isFinished():
        QCoreApplication.processEvents()
        QThread.msleep(100)

    # Evaluate the result before scheduling the reply object for deletion.
    responsive = reply.error() == QNetworkReply.NetworkError.NoError
    reply.deleteLater()

    return responsive
|
446 |
|
@pyqtSlot()
def __setHeartbeatTimer(self):
    """
    Private slot to (re)configure the heartbeat timer from the
    'OllamaHeartbeatInterval' preference.
    """
    seconds = self.__plugin.getPreferences("OllamaHeartbeatInterval")
    timer = self.__heartbeatTimer
    if not seconds:
        # A falsy interval (e.g. 0) disables the periodic server check.
        timer.stop()
        return

    # The preference is given in seconds, QTimer expects milliseconds.
    timer.setInterval(seconds * 1000)
    timer.start()
|
458 |
|
@pyqtSlot()
def __periodicHeartbeat(self):
    """
    Private slot to check the 'ollama' server periodically and to emit
    serverStateChanged whenever its responsiveness differs from the last
    known state.
    """
    current = self.heartbeat()
    changed = current != self.__serverResponding
    if changed:
        self.serverStateChanged.emit(current)
    # Remember the latest result for the next comparison.
    self.__serverResponding = current