# -*- coding: utf-8 -*-

# Copyright (c) 2024 Detlev Offenbach <detlev@die-offenbachs.de>
#

"""
Module implementing the 'ollama' client.
"""

import contextlib
import datetime
import enum
import json

from PyQt6.QtCore import pyqtSignal, QObject, QUrl
from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkRequest, QNetworkReply

from eric7.EricNetwork.EricNetworkProxyFactory import proxyAuthenticationRequired


class OllamaClientState(enum.Enum):
    """
    Class defining the various client states.
    """
    Waiting = 0
    Requesting = 1
    Receiving = 2
    Finished = 3


class OllamaClient(QObject):
    """
    Class implementing the 'ollama' client.

    @signal replyReceived(content:str, role:str) emitted after a response from the
        'ollama' server was received
    @signal modelsList(modelNames:list[str]) emitted after the list of model
        names was obtained from the 'ollama' server
    @signal detailedModelsList(models:list[dict]) emitted after the list of
        models was obtained from the 'ollama' server giving some model details
    @signal runningModelsList(models:list[dict]) emitted after the list of
        running models was obtained from the 'ollama' server giving some model
        execution details
    @signal pullStatus(msg:str, id:str, total:int, completed:int) emitted to indicate
        the status of a pull request as reported by the 'ollama' server
    @signal finished() emitted to indicate the completion of a request
    @signal errorOccurred(error:str) emitted to indicate a network error occurred
        while processing the request
    """

    replyReceived = pyqtSignal(str, str)
    modelsList = pyqtSignal(list)
    detailedModelsList = pyqtSignal(list)
    runningModelsList = pyqtSignal(list)
    pullStatus = pyqtSignal(str, str, int, int)
    finished = pyqtSignal()
    errorOccurred = pyqtSignal(str)

    def __init__(self, plugin, parent=None):
        """
        Constructor

        @param plugin reference to the plugin object
        @type PluginOllamaInterface
        @param parent reference to the parent object (defaults to None)
        @type QObject (optional)
        """
        super().__init__(parent=parent)

        self.__plugin = plugin
        self.__replies = []

        self.__networkManager = QNetworkAccessManager(self)
        self.__networkManager.proxyAuthenticationRequired.connect(
            proxyAuthenticationRequired
        )

        self.__state = OllamaClientState.Waiting

    def chat(self, model, messages):
        """
        Public method to request a chat completion from the 'ollama' server.

        @param model name of the model to be used
        @type str
        @param messages list of message objects
        @type list of dict
        """
        # TODO: not implemented yet
        ollamaRequest = {
            "model": model,
            "messages": messages,
        }
        self.__sendRequest(
            "chat", data=ollamaRequest, processResponse=self.__processChatResponse
        )

    def __processChatResponse(self, response):
        """
        Private method to process the chat response of the 'ollama' server.

        @param response dictionary containing the chat response
        @type dict
        """
        with contextlib.suppress(KeyError):
            message = response["message"]
            if message:
                self.replyReceived.emit(message["content"], message["role"])

    def generate(self, model, prompt, suffix=None):
        """
        Public method to request to generate a completion from the 'ollama' server.

        @param model name of the model to be used
        @type str
        @param prompt prompt to generate a response for
        @type str
        @param suffix text after the model response (defaults to None)
        @type str (optional)
        """
        # TODO: not implemented yet
        ollamaRequest = {
            "model": model,
            "prompt": prompt,
        }
        if suffix is not None:
            ollamaRequest["suffix"] = suffix
        self.__sendRequest(
            "generate",
            data=ollamaRequest,
            processResponse=self.__processGenerateResponse,
        )

    def __processGenerateResponse(self, response):
        """
        Private method to process the generate response of the 'ollama' server.

        @param response dictionary containing the generate response
        @type dict
        """
        with contextlib.suppress(KeyError):
            self.replyReceived.emit(response["response"], "")

    def pull(self, model):
        """
        Public method to ask the 'ollama' server to pull the given model.

        @param model name of the model
        @type str
        """
        # TODO: not implemented yet
        ollamaRequest = {
            "name": model,
        }
        self.__sendRequest(
            "pull", data=ollamaRequest, processResponse=self.__processPullResponse
        )

    def __processPullResponse(self, response):
        """
        Private method to process a pull response of the 'ollama' server.

        @param response dictionary containing the pull response
        @type dict
        """
        with contextlib.suppress(KeyError):
            status = response["status"]
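            # use only the first 20 characters of the digest as the ID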
            idStr = response.get("digest", "")[:20]
            total = response.get("total", 0)
            completed = response.get("completed", 0)
            self.pullStatus.emit(status, idStr, total, completed)

    def remove(self, model):
        """
        Public method to ask the 'ollama' server to delete the given model.

        @param model name of the model
        @type str
        """
        # TODO: not implemented yet
        ollamaRequest = {
            "name": model,
        }
        self.__sendRequest("delete", data=ollamaRequest)

    def list(self):
        """
        Public method to request a list of models available locally from the 'ollama'
        server.
        """
        # TODO: not implemented yet
        self.__sendRequest("tags", processResponse=self.__processModelsList)

    def __processModelsList(self, response):
        """
        Private method to process the tags response of the 'ollama' server.

        @param response dictionary containing the tags response
        @type dict
        """
        models = []
        with contextlib.suppress(KeyError):
            for model in response["models"]:
                name = model["name"]
                if name:
                    models.append(name)
        self.modelsList.emit(models)

    def listDetails(self):
        """
        Public method to request a list of models available locally from the 'ollama'
        server with some model details.
        """
        # TODO: not implemented yet
        self.__sendRequest("tags", processResponse=self.__processDetailedModelsList)

    def __processDetailedModelsList(self, response):
        """
        Private method to process the tags response of the 'ollama' server extracting
        some model details.

        @param response dictionary containing the tags response
        @type dict
        """
        models = []
        with contextlib.suppress(KeyError):
            for model in response["models"]:
                name = model["name"]
                if name:
                    models.append(
                        {
                            "name": name,
                            "id": model["digest"][:20],  # first 20 characters only
                            "size": model["size"],
                            "modified": datetime.datetime.fromisoformat(
                                model["modified_at"]
                            ),
                        }
                    )
        self.detailedModelsList.emit(models)

    def listRunning(self):
        """
        Public method to request a list of running models from the 'ollama' server.
        """
        # TODO: not implemented yet
        self.__sendRequest("ps", processResponse=self.__processRunningModelsList)

    def __processRunningModelsList(self, response):
        """
        Private method to process the ps response of the 'ollama' server extracting
        some model execution details.

        @param response dictionary containing the ps response
        @type dict
        """
        models = []
        with contextlib.suppress(KeyError):
            for model in response["models"]:
                name = model["name"]
                if name:
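                    # Determine whether the model is processed on the CPU, on the
                    # GPU or split between both, based on the reported VRAM usage.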
                    if model["size_vram"] == 0:
                        processor = self.tr("100% CPU")
                    elif model["size_vram"] == model["size"]:
                        processor = self.tr("100% GPU")
                    elif model["size_vram"] > model["size"] or model["size"] == 0:
                        processor = self.tr("unknown")
                    else:
                        sizeCpu = model["size"] - model["size_vram"]
                        cpuPercent = round(sizeCpu / model["size"] * 100)
                        processor = self.tr("{0}% / {1}% CPU / GPU").format(
                            cpuPercent, 100 - cpuPercent
                        )
                    models.append(
                        {
                            "name": name,
                            "id": model["digest"][:20],  # first 20 characters only
                            "size": model["size"],
                            "size_vram": model["size_vram"],
                            "processor": processor,
                            "expires": datetime.datetime.fromisoformat(
                                model["expires_at"]
                            ),
                        }
                    )
        self.runningModelsList.emit(models)

    def state(self):
        """
        Public method to get the current client state.

        @return current client state
        @rtype OllamaClientState
        """
        return self.__state

    def __sendRequest(self, endpoint, data=None, processResponse=None):
        """
        Private method to send a request to the 'ollama' server and handle its
        responses.

        @param endpoint 'ollama' API endpoint to be contacted
        @type str
        @param data dictionary containing the data to send to the server
            (defaults to None)
        @type dict (optional)
        @param processResponse function handling the received data (defaults to None)
        @type function (optional)
        """
        self.__state = OllamaClientState.Requesting

        ollamaUrl = QUrl(
            "{0}://{1}:{2}/api/{3}".format(
                self.__plugin.getPreferences("OllamaScheme"),
                self.__plugin.getPreferences("OllamaHost"),
                self.__plugin.getPreferences("OllamaPort"),
                endpoint,
            )
        )
        request = QNetworkRequest(ollamaUrl)
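        # Send a JSON encoded POST request, if payload data was given; otherwise
        # send a plain GET request.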
        if data is not None:
            request.setHeader(
                QNetworkRequest.KnownHeaders.ContentTypeHeader, "application/json"
            )
            jsonData = json.dumps(data).encode("utf-8")
            reply = self.__networkManager.post(request=request, data=jsonData)
        else:
            reply = self.__networkManager.get(request=request)

        reply.finished.connect(lambda: self.__replyFinished(reply))
        reply.errorOccurred.connect(lambda error: self.__errorOccurred(error, reply))
        reply.readyRead.connect(lambda: self.__processData(reply, processResponse))
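        # remember the reply object until its request has finished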
        self.__replies.append(reply)

    def __replyFinished(self, reply):
        """
        Private method to handle the finished signal of the reply.

        @param reply reference to the finished network reply object
        @type QNetworkReply
        """
        self.__state = OllamaClientState.Finished

        if reply in self.__replies:
            self.__replies.remove(reply)

        reply.deleteLater()

    def __errorOccurred(self, errorCode, reply):
        """
        Private method to handle a network error of the given reply.

        @param errorCode error code reported by the reply
        @type QNetworkReply.NetworkError
        @param reply reference to the network reply object
        @type QNetworkReply
        """
        if errorCode != QNetworkReply.NetworkError.NoError:
            self.errorOccurred.emit(
                self.tr("<p>A network error occurred.</p><p>Error: {0}</p>").format(
                    reply.errorString()
                )
            )

    def __processData(self, reply, processResponse):
        """
        Private method to receive data from the 'ollama' server and process it with a
        given processing function or method.

        @param reply reference to the network reply object
        @type QNetworkReply
        @param processResponse processing function
        @type function
        """
        self.__state = OllamaClientState.Receiving

        buffer = bytes(reply.readAll())
        if buffer:
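            # The response data is JSON encoded; data that cannot be decoded is
            # silently ignored.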
            with contextlib.suppress(json.JSONDecodeError):
                data = json.loads(buffer)
                if data and processResponse:
                    processResponse(data)