OllamaInterface/OllamaClient.py

changeset 3:ca28466a186d (parent 2:fee250704d3d, child 4:7dd1b9cd3150)

# -*- coding: utf-8 -*-

# Copyright (c) 2024 Detlev Offenbach <detlev@die-offenbachs.de>
#

"""
Module implementing the 'ollama' client.
"""

import contextlib
import datetime
import enum
import json

from PyQt6.QtCore import pyqtSignal, QObject, QUrl
from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkRequest, QNetworkReply

from eric7.EricNetwork.EricNetworkProxyFactory import proxyAuthenticationRequired


class OllamaClientState(enum.Enum):
    """
    Class defining the various client states.
    """
    Waiting = 0
    Requesting = 1
    Receiving = 2
    Finished = 3


class OllamaClient(QObject):
    """
    Class implementing the 'ollama' client.

    @signal replyReceived(content:str, role:str) emitted after a response from the
        'ollama' server was received
    @signal modelsList(modelNames:list[str]) emitted after the list of model
        names was obtained from the 'ollama' server
    @signal detailedModelsList(models:list[dict]) emitted after the list of
        models was obtained from the 'ollama' server giving some model details
    @signal runningModelsList(models:list[dict]) emitted after the list of
        running models was obtained from the 'ollama' server giving some model
        execution details
    @signal pullStatus(msg:str, id:str, total:int, completed:int) emitted to indicate
        the status of a pull request as reported by the 'ollama' server
    @signal finished() emitted to indicate the completion of a request
    @signal errorOccurred(error:str) emitted to indicate that a network error
        occurred while processing the request
    """

    replyReceived = pyqtSignal(str, str)
    modelsList = pyqtSignal(list)
    detailedModelsList = pyqtSignal(list)
    runningModelsList = pyqtSignal(list)
    pullStatus = pyqtSignal(str, str, int, int)
    finished = pyqtSignal()
    errorOccurred = pyqtSignal(str)

    def __init__(self, plugin, parent=None):
        """
        Constructor

        @param plugin reference to the plugin object
        @type PluginOllamaInterface
        @param parent reference to the parent object (defaults to None)
        @type QObject (optional)
        """
        super().__init__(parent=parent)

        self.__plugin = plugin
        self.__replies = []

        self.__networkManager = QNetworkAccessManager(self)
        self.__networkManager.proxyAuthenticationRequired.connect(
            proxyAuthenticationRequired
        )

        self.__state = OllamaClientState.Waiting

    def chat(self, model, messages):
        """
        Public method to request a chat completion from the 'ollama' server.

        @param model name of the model to be used
        @type str
        @param messages list of message objects
        @type list of dict
        """
        # TODO: not implemented yet
        ollamaRequest = {
            "model": model,
            "messages": messages,
        }
        self.__sendRequest(
            "chat", data=ollamaRequest, processResponse=self.__processChatResponse
        )

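    # Note (sketch, not part of the original file): each entry of 'messages'
    # is expected to follow the 'ollama' chat API message format, i.e. a
    # dictionary with 'role' and 'content' keys, e.g.
    #     {"role": "user", "content": "Why is the sky blue?"}
    # with 'role' being one of "system", "user" or "assistant".
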
    def __processChatResponse(self, response):
        """
        Private method to process the chat response of the 'ollama' server.

        @param response dictionary containing the chat response
        @type dict
        """
        with contextlib.suppress(KeyError):
            message = response["message"]
            if message:
                self.replyReceived.emit(message["content"], message["role"])

    def generate(self, model, prompt, suffix=None):
        """
        Public method to request the 'ollama' server to generate a completion
        for the given prompt.

        @param model name of the model to be used
        @type str
        @param prompt prompt to generate a response for
        @type str
        @param suffix text after the model response (defaults to None)
        @type str (optional)
        """
        # TODO: not implemented yet
        ollamaRequest = {
            "model": model,
            "prompt": prompt,
        }
        if suffix is not None:
            ollamaRequest["suffix"] = suffix
        self.__sendRequest(
            "generate",
            data=ollamaRequest,
            processResponse=self.__processGenerateResponse,
        )

    def __processGenerateResponse(self, response):
        """
        Private method to process the generate response of the 'ollama' server.

        @param response dictionary containing the generate response
        @type dict
        """
        with contextlib.suppress(KeyError):
            self.replyReceived.emit(response["response"], "")

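    # Note (sketch, not part of the original file): the 'generate' endpoint
    # streams objects carrying the generated text in a 'response' key, e.g.
    #     {"response": "The sky is blue because ...", "done": false}
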
    def pull(self, model):
        """
        Public method to ask the 'ollama' server to pull the given model.

        @param model name of the model
        @type str
        """
        # TODO: not implemented yet
        ollamaRequest = {
            "name": model,
        }
        self.__sendRequest(
            "pull", data=ollamaRequest, processResponse=self.__processPullResponse
        )

    def __processPullResponse(self, response):
        """
        Private method to process a pull response of the 'ollama' server.

        @param response dictionary containing the pull response
        @type dict
        """
        with contextlib.suppress(KeyError):
            status = response["status"]
            idStr = response.get("digest", "")[:20]
            total = response.get("total", 0)
            completed = response.get("completed", 0)
            self.pullStatus.emit(status, idStr, total, completed)

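    # Note (sketch, not part of the original file): pull progress is reported
    # by the server as a stream of objects like
    #     {"status": "downloading sha256:...", "digest": "sha256:...",
    #      "total": 1630342656, "completed": 241590272}
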
    def remove(self, model):
        """
        Public method to ask the 'ollama' server to delete the given model.

        @param model name of the model
        @type str
        """
        # TODO: not implemented yet
        ollamaRequest = {
            "name": model,
        }
        self.__sendRequest("delete", data=ollamaRequest)

    def list(self):
        """
        Public method to request a list of models available locally from the 'ollama'
        server.
        """
        # TODO: not implemented yet
        self.__sendRequest("tags", processResponse=self.__processModelsList)

    def __processModelsList(self, response):
        """
        Private method to process the tags response of the 'ollama' server.

        @param response dictionary containing the tags response
        @type dict
        """
        models = []
        with contextlib.suppress(KeyError):
            for model in response["models"]:
                name = model["name"]
                if name:
                    models.append(name)
        self.modelsList.emit(models)

    def listDetails(self):
        """
        Public method to request a list of models available locally from the 'ollama'
        server with some model details. The result is reported via the
        'detailedModelsList' signal.
        """
        # TODO: not implemented yet
        self.__sendRequest("tags", processResponse=self.__processDetailedModelsList)

    def __processDetailedModelsList(self, response):
        """
        Private method to process the tags response of the 'ollama' server extracting
        some model details.

        @param response dictionary containing the tags response
        @type dict
        """
        models = []
        with contextlib.suppress(KeyError):
            for model in response["models"]:
                name = model["name"]
                if name:
                    models.append(
                        {
                            "name": name,
                            "id": model["digest"][:20],  # first 20 characters only
                            "size": model["size"],
                            "modified": datetime.datetime.fromisoformat(
                                model["modified_at"]
                            ),
                        }
                    )
        self.detailedModelsList.emit(models)

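    # Note (sketch, not part of the original file): the 'tags' response
    # processed above is expected to look like
    #     {"models": [{"name": "llama3:latest", "digest": "...",
    #                  "size": 4661224676,
    #                  "modified_at": "2024-05-04T14:56:49+02:00"},
    #                 ...]}
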
    def listRunning(self):
        """
        Public method to request a list of running models from the 'ollama' server.
        """
        # TODO: not implemented yet
        self.__sendRequest("ps", processResponse=self.__processRunningModelsList)

    def __processRunningModelsList(self, response):
        """
        Private method to process the ps response of the 'ollama' server extracting
        some model execution details.

        @param response dictionary containing the ps response
        @type dict
        """
        models = []
        with contextlib.suppress(KeyError):
            for model in response["models"]:
                name = model["name"]
                if name:
                    if model["size_vram"] == 0:
                        processor = self.tr("100% CPU")
                    elif model["size_vram"] == model["size"]:
                        processor = self.tr("100% GPU")
                    elif model["size_vram"] > model["size"] or model["size"] == 0:
                        processor = self.tr("unknown")
                    else:
                        # percentage of the model held in CPU memory relative
                        # to its total size
                        sizeCpu = model["size"] - model["size_vram"]
                        cpuPercent = round(sizeCpu / model["size"] * 100)
                        processor = self.tr("{0}% / {1}% CPU / GPU").format(
                            cpuPercent, 100 - cpuPercent
                        )
                    models.append(
                        {
                            "name": name,
                            "id": model["digest"][:20],  # first 20 characters only
                            "size": model["size"],
                            "size_vram": model["size_vram"],
                            "processor": processor,
                            "expires": datetime.datetime.fromisoformat(
                                model["expires_at"]
                            ),
                        }
                    )
        self.runningModelsList.emit(models)

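    # Worked example for the CPU/GPU split above (sketch, not part of the
    # original file): size = 10 GB and size_vram = 6 GB give sizeCpu = 4 GB
    # and round(4 / 10 * 100) = 40, i.e. "40% / 60% CPU / GPU".
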
    def state(self):
        """
        Public method to get the current client state.

        @return current client state
        @rtype OllamaClientState
        """
        return self.__state

    def __sendRequest(self, endpoint, data=None, processResponse=None):
        """
        Private method to send a request to the 'ollama' server and handle its
        responses.

        @param endpoint 'ollama' API endpoint to be contacted
        @type str
        @param data dictionary containing the data to send to the server
            (defaults to None)
        @type dict (optional)
        @param processResponse function handling the received data (defaults to None)
        @type function (optional)
        """
        self.__state = OllamaClientState.Requesting

        ollamaUrl = QUrl(
            "{0}://{1}:{2}/api/{3}".format(
                self.__plugin.getPreferences("OllamaScheme"),
                self.__plugin.getPreferences("OllamaHost"),
                self.__plugin.getPreferences("OllamaPort"),
                endpoint,
            )
        )
        request = QNetworkRequest(ollamaUrl)
        if data is not None:
            request.setHeader(
                QNetworkRequest.KnownHeaders.ContentTypeHeader, "application/json"
            )
            jsonData = json.dumps(data).encode("utf-8")
            reply = self.__networkManager.post(request=request, data=jsonData)
        else:
            reply = self.__networkManager.get(request=request)

        reply.finished.connect(lambda: self.__replyFinished(reply))
        reply.errorOccurred.connect(lambda error: self.__errorOccurred(error, reply))
        reply.readyRead.connect(lambda: self.__processData(reply, processResponse))
        self.__replies.append(reply)

    def __replyFinished(self, reply):
        """
        Private method to handle the finished signal of the reply.

        @param reply reference to the finished network reply object
        @type QNetworkReply
        """
        self.__state = OllamaClientState.Finished

        if reply in self.__replies:
            self.__replies.remove(reply)

        reply.deleteLater()

        # report the completion of the request as documented in the class
        # docstring
        self.finished.emit()

    def __errorOccurred(self, errorCode, reply):
        """
        Private method to handle a network error of the given reply.

        @param errorCode error code reported by the reply
        @type QNetworkReply.NetworkError
        @param reply reference to the network reply object
        @type QNetworkReply
        """
        if errorCode != QNetworkReply.NetworkError.NoError:
            self.errorOccurred.emit(
                self.tr("<p>A network error occurred.</p><p>Error: {0}</p>").format(
                    reply.errorString()
                )
            )

    def __processData(self, reply, processResponse):
        """
        Private method to receive data from the 'ollama' server and process it with a
        given processing function or method.

        @param reply reference to the network reply object
        @type QNetworkReply
        @param processResponse processing function
        @type function
        """
        self.__state = OllamaClientState.Receiving

        buffer = bytes(reply.readAll())
        if buffer:
            # The 'ollama' server streams newline-delimited JSON objects, so
            # each line is decoded separately; incomplete trailing fragments
            # are silently skipped.
            for line in buffer.splitlines():
                with contextlib.suppress(json.JSONDecodeError):
                    data = json.loads(line)
                    if data and processResponse:
                        processResponse(data)
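
# A minimal usage sketch (assumption, not part of the original file; 'plugin'
# must provide the getPreferences() calls used above):
#
#     client = OllamaClient(plugin)
#     client.modelsList.connect(print)
#     client.list()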
