Ollama4j
A Java library (wrapper/binding) for Ollama server.
Loading...
Searching...
No Matches
Ollama.java
Go to the documentation of this file.
1/*
2 * Ollama4j - Java library for interacting with Ollama server.
3 * Copyright (c) 2025 Amith Koujalgi and contributors.
4 *
5 * Licensed under the MIT License (the "License");
6 * you may not use this file except in compliance with the License.
7 *
8*/
9package io.github.ollama4j;
10
11import com.fasterxml.jackson.databind.ObjectMapper;
12import io.github.ollama4j.exceptions.OllamaException;
13import io.github.ollama4j.exceptions.RoleNotFoundException;
14import io.github.ollama4j.exceptions.ToolInvocationException;
15import io.github.ollama4j.metrics.MetricsRecorder;
16import io.github.ollama4j.models.chat.*;
17import io.github.ollama4j.models.chat.OllamaChatTokenHandler;
18import io.github.ollama4j.models.embed.OllamaEmbedRequest;
19import io.github.ollama4j.models.embed.OllamaEmbedResult;
20import io.github.ollama4j.models.generate.OllamaGenerateRequest;
21import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
22import io.github.ollama4j.models.generate.OllamaGenerateTokenHandler;
23import io.github.ollama4j.models.ps.ModelProcessesResult;
24import io.github.ollama4j.models.request.*;
25import io.github.ollama4j.models.response.*;
26import io.github.ollama4j.tools.*;
27import io.github.ollama4j.tools.annotations.OllamaToolService;
28import io.github.ollama4j.tools.annotations.ToolProperty;
29import io.github.ollama4j.tools.annotations.ToolSpec;
30import io.github.ollama4j.utils.Constants;
31import io.github.ollama4j.utils.Utils;
32import java.io.*;
33import java.lang.reflect.InvocationTargetException;
34import java.lang.reflect.Method;
35import java.lang.reflect.Parameter;
36import java.net.URI;
37import java.net.URISyntaxException;
38import java.net.http.HttpClient;
39import java.net.http.HttpRequest;
40import java.net.http.HttpResponse;
41import java.nio.charset.StandardCharsets;
42import java.nio.file.Files;
43import java.time.Duration;
44import java.util.*;
45import java.util.stream.Collectors;
46import lombok.Setter;
47import org.slf4j.Logger;
48import org.slf4j.LoggerFactory;
49
55@SuppressWarnings({"DuplicatedCode", "resource", "SpellCheckingInspection"})
56public class Ollama {
57
58 private static final Logger LOG = LoggerFactory.getLogger(Ollama.class);
59
60 private final String host;
61 private Auth auth;
62
63 private final ToolRegistry toolRegistry = new ToolRegistry();
64
71 @Setter private long requestTimeoutSeconds = 10;
72
76 @Setter private int imageURLReadTimeoutSeconds = 10;
77
81 @Setter private int imageURLConnectTimeoutSeconds = 10;
82
89 @Setter private int maxChatToolCallRetries = 3;
90
99 @Setter
100 @SuppressWarnings({"FieldMayBeFinal", "FieldCanBeLocal"})
101 private int numberOfRetriesForModelPull = 0;
102
109 @Setter private boolean metricsEnabled = false;
110
/**
 * Creates a client for an Ollama server running locally on the default port
 * ({@code http://localhost:11434}).
 */
public Ollama() {
    host = "http://localhost:11434";
}
117
123 public Ollama(String host) {
124 if (host.endsWith("/")) {
125 this.host = host.substring(0, host.length() - 1);
126 } else {
127 this.host = host;
128 }
129 LOG.info("Ollama4j client initialized. Connected to Ollama server at: {}", this.host);
130 }
131
138 public void setBasicAuth(String username, String password) {
139 this.auth = new BasicAuth(username, password);
140 }
141
147 public void setBearerAuth(String bearerToken) {
148 this.auth = new BearerAuth(bearerToken);
149 }
150
157 public boolean ping() throws OllamaException {
158 long startTime = System.currentTimeMillis();
159 String url = "/api/tags";
160 int statusCode = -1;
161 Object out = null;
162 try {
163 HttpClient httpClient = HttpClient.newHttpClient();
164 HttpRequest httpRequest;
165 HttpResponse<String> response;
166 httpRequest =
167 getRequestBuilderDefault(new URI(this.host + url))
168 .header(
169 Constants.HttpConstants.HEADER_KEY_ACCEPT,
170 Constants.HttpConstants.APPLICATION_JSON)
171 .header(
172 Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
173 Constants.HttpConstants.APPLICATION_JSON)
174 .GET()
175 .build();
176 response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
177 statusCode = response.statusCode();
178 return statusCode == 200;
179 } catch (InterruptedException ie) {
180 Thread.currentThread().interrupt();
181 throw new OllamaException("Ping interrupted", ie);
182 } catch (Exception e) {
183 throw new OllamaException("Ping failed", e);
184 } finally {
186 url,
187 "",
188 false,
190 false,
191 null,
192 null,
193 startTime,
194 statusCode,
195 out);
196 }
197 }
198
206 long startTime = System.currentTimeMillis();
207 String url = "/api/ps";
208 int statusCode = -1;
209 Object out = null;
210 try {
211 HttpClient httpClient = HttpClient.newHttpClient();
212 HttpRequest httpRequest = null;
213 try {
214 httpRequest =
215 getRequestBuilderDefault(new URI(this.host + url))
216 .header(
217 Constants.HttpConstants.HEADER_KEY_ACCEPT,
218 Constants.HttpConstants.APPLICATION_JSON)
219 .header(
220 Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
221 Constants.HttpConstants.APPLICATION_JSON)
222 .GET()
223 .build();
224 } catch (URISyntaxException e) {
225 throw new OllamaException(e.getMessage(), e);
226 }
227 HttpResponse<String> response = null;
228 response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
229 statusCode = response.statusCode();
230 String responseString = response.body();
231 if (statusCode == 200) {
232 return Utils.getObjectMapper()
233 .readValue(responseString, ModelProcessesResult.class);
234 } else {
235 throw new OllamaException(statusCode + " - " + responseString);
236 }
237 } catch (InterruptedException ie) {
238 Thread.currentThread().interrupt();
239 throw new OllamaException("ps interrupted", ie);
240 } catch (Exception e) {
241 throw new OllamaException("ps failed", e);
242 } finally {
244 url,
245 "",
246 false,
248 false,
249 null,
250 null,
251 startTime,
252 statusCode,
253 out);
254 }
255 }
256
263 public List<Model> listModels() throws OllamaException {
264 long startTime = System.currentTimeMillis();
265 String url = "/api/tags";
266 int statusCode = -1;
267 Object out = null;
268 try {
269 HttpClient httpClient = HttpClient.newHttpClient();
270 HttpRequest httpRequest =
271 getRequestBuilderDefault(new URI(this.host + url))
272 .header(
273 Constants.HttpConstants.HEADER_KEY_ACCEPT,
274 Constants.HttpConstants.APPLICATION_JSON)
275 .header(
276 Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
277 Constants.HttpConstants.APPLICATION_JSON)
278 .GET()
279 .build();
280 HttpResponse<String> response =
281 httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
282 statusCode = response.statusCode();
283 String responseString = response.body();
284 if (statusCode == 200) {
285 return Utils.getObjectMapper()
286 .readValue(responseString, ListModelsResponse.class)
287 .getModels();
288 } else {
289 throw new OllamaException(statusCode + " - " + responseString);
290 }
291 } catch (InterruptedException ie) {
292 Thread.currentThread().interrupt();
293 throw new OllamaException("listModels interrupted", ie);
294 } catch (Exception e) {
295 throw new OllamaException(e.getMessage(), e);
296 } finally {
298 url,
299 "",
300 false,
302 false,
303 null,
304 null,
305 startTime,
306 statusCode,
307 out);
308 }
309 }
310
320 private void handlePullRetry(
321 String modelName, int currentRetry, int maxRetries, long baseDelayMillis)
322 throws InterruptedException {
323 int attempt = currentRetry + 1;
324 if (attempt < maxRetries) {
325 long backoffMillis = baseDelayMillis * (1L << currentRetry);
326 LOG.error(
327 "Failed to pull model {}, retrying in {}s... (attempt {}/{})",
328 modelName,
329 backoffMillis / 1000,
330 attempt,
331 maxRetries);
332 try {
333 Thread.sleep(backoffMillis);
334 } catch (InterruptedException ie) {
335 Thread.currentThread().interrupt();
336 throw ie;
337 }
338 } else {
339 LOG.error(
340 "Failed to pull model {} after {} attempts, no more retries.",
341 modelName,
342 maxRetries);
343 }
344 }
345
352 private void doPullModel(String modelName) throws OllamaException {
353 long startTime = System.currentTimeMillis();
354 String url = "/api/pull";
355 int statusCode = -1;
356 Object out = null;
357 try {
358 String jsonData = new ModelRequest(modelName).toString();
359 HttpRequest request =
360 getRequestBuilderDefault(new URI(this.host + url))
361 .POST(HttpRequest.BodyPublishers.ofString(jsonData))
362 .header(
363 Constants.HttpConstants.HEADER_KEY_ACCEPT,
364 Constants.HttpConstants.APPLICATION_JSON)
365 .header(
366 Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
367 Constants.HttpConstants.APPLICATION_JSON)
368 .build();
369 HttpClient client = HttpClient.newHttpClient();
370 HttpResponse<InputStream> response =
371 client.send(request, HttpResponse.BodyHandlers.ofInputStream());
372 statusCode = response.statusCode();
373 InputStream responseBodyStream = response.body();
374 String responseString = "";
375 boolean success = false; // Flag to check the pull success.
376
377 try (BufferedReader reader =
378 new BufferedReader(
379 new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
380 String line;
381 while ((line = reader.readLine()) != null) {
382 ModelPullResponse modelPullResponse =
383 Utils.getObjectMapper().readValue(line, ModelPullResponse.class);
384 success = processModelPullResponse(modelPullResponse, modelName) || success;
385 }
386 }
387 if (!success) {
388 LOG.error("Model pull failed or returned invalid status.");
389 throw new OllamaException("Model pull failed or returned invalid status.");
390 }
391 if (statusCode != 200) {
392 throw new OllamaException(statusCode + " - " + responseString);
393 }
394 } catch (InterruptedException ie) {
395 Thread.currentThread().interrupt();
396 throw new OllamaException("Thread was interrupted during model pull.", ie);
397 } catch (Exception e) {
398 throw new OllamaException(e.getMessage(), e);
399 } finally {
401 url,
402 "",
403 false,
405 false,
406 null,
407 null,
408 startTime,
409 statusCode,
410 out);
411 }
412 }
413
423 @SuppressWarnings("RedundantIfStatement")
424 private boolean processModelPullResponse(ModelPullResponse modelPullResponse, String modelName)
425 throws OllamaException {
426 if (modelPullResponse == null) {
427 LOG.error("Received null response for model pull.");
428 return false;
429 }
430 String error = modelPullResponse.getError();
431 if (error != null && !error.trim().isEmpty()) {
432 throw new OllamaException("Model pull failed: " + error);
433 }
434 String status = modelPullResponse.getStatus();
435 if (status != null) {
436 LOG.debug("{}: {}", modelName, status);
437 if ("success".equalsIgnoreCase(status)) {
438 return true;
439 }
440 }
441 return false;
442 }
443
450 public String getVersion() throws OllamaException {
451 String url = "/api/version";
452 long startTime = System.currentTimeMillis();
453 int statusCode = -1;
454 Object out = null;
455 try {
456 HttpClient httpClient = HttpClient.newHttpClient();
457 HttpRequest httpRequest =
458 getRequestBuilderDefault(new URI(this.host + url))
459 .header(
460 Constants.HttpConstants.HEADER_KEY_ACCEPT,
461 Constants.HttpConstants.APPLICATION_JSON)
462 .header(
463 Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
464 Constants.HttpConstants.APPLICATION_JSON)
465 .GET()
466 .build();
467 HttpResponse<String> response =
468 httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
469 statusCode = response.statusCode();
470 String responseString = response.body();
471 if (statusCode == 200) {
472 return Utils.getObjectMapper()
473 .readValue(responseString, OllamaVersion.class)
474 .getVersion();
475 } else {
476 throw new OllamaException(statusCode + " - " + responseString);
477 }
478 } catch (InterruptedException ie) {
479 Thread.currentThread().interrupt();
480 throw new OllamaException("Thread was interrupted", ie);
481 } catch (Exception e) {
482 throw new OllamaException(e.getMessage(), e);
483 } finally {
485 url,
486 "",
487 false,
489 false,
490 null,
491 null,
492 startTime,
493 statusCode,
494 out);
495 }
496 }
497
506 public void pullModel(String modelName) throws OllamaException {
507 try {
508 if (numberOfRetriesForModelPull == 0) {
509 this.doPullModel(modelName);
510 return;
511 }
512 int numberOfRetries = 0;
513 long baseDelayMillis = 3000L; // 3 seconds base delay
514 while (numberOfRetries < numberOfRetriesForModelPull) {
515 try {
516 this.doPullModel(modelName);
517 return;
518 } catch (OllamaException e) {
519 handlePullRetry(
520 modelName,
521 numberOfRetries,
522 numberOfRetriesForModelPull,
523 baseDelayMillis);
524 numberOfRetries++;
525 }
526 }
527 throw new OllamaException(
528 "Failed to pull model "
529 + modelName
530 + " after "
531 + numberOfRetriesForModelPull
532 + " retries");
533 } catch (InterruptedException ie) {
534 Thread.currentThread().interrupt();
535 throw new OllamaException("Thread was interrupted", ie);
536 } catch (Exception e) {
537 throw new OllamaException(e.getMessage(), e);
538 }
539 }
540
548 public ModelDetail getModelDetails(String modelName) throws OllamaException {
549 long startTime = System.currentTimeMillis();
550 String url = "/api/show";
551 int statusCode = -1;
552 Object out = null;
553 try {
554 String jsonData = new ModelRequest(modelName).toString();
555 HttpRequest request =
556 getRequestBuilderDefault(new URI(this.host + url))
557 .header(
558 Constants.HttpConstants.HEADER_KEY_ACCEPT,
559 Constants.HttpConstants.APPLICATION_JSON)
560 .header(
561 Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
562 Constants.HttpConstants.APPLICATION_JSON)
563 .POST(HttpRequest.BodyPublishers.ofString(jsonData))
564 .build();
565 HttpClient client = HttpClient.newHttpClient();
566 HttpResponse<String> response =
567 client.send(request, HttpResponse.BodyHandlers.ofString());
568 statusCode = response.statusCode();
569 String responseBody = response.body();
570 if (statusCode == 200) {
571 return Utils.getObjectMapper().readValue(responseBody, ModelDetail.class);
572 } else {
573 throw new OllamaException(statusCode + " - " + responseBody);
574 }
575 } catch (InterruptedException ie) {
576 Thread.currentThread().interrupt();
577 throw new OllamaException("Thread was interrupted", ie);
578 } catch (Exception e) {
579 throw new OllamaException(e.getMessage(), e);
580 } finally {
582 url,
583 "",
584 false,
586 false,
587 null,
588 null,
589 startTime,
590 statusCode,
591 out);
592 }
593 }
594
602 public void createModel(CustomModelRequest customModelRequest) throws OllamaException {
603 long startTime = System.currentTimeMillis();
604 String url = "/api/create";
605 int statusCode = -1;
606 Object out = null;
607 try {
608 String jsonData = customModelRequest.toString();
609 HttpRequest request =
610 getRequestBuilderDefault(new URI(this.host + url))
611 .header(
612 Constants.HttpConstants.HEADER_KEY_ACCEPT,
613 Constants.HttpConstants.APPLICATION_JSON)
614 .header(
615 Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
616 Constants.HttpConstants.APPLICATION_JSON)
617 .POST(
618 HttpRequest.BodyPublishers.ofString(
619 jsonData, StandardCharsets.UTF_8))
620 .build();
621 HttpClient client = HttpClient.newHttpClient();
622 HttpResponse<InputStream> response =
623 client.send(request, HttpResponse.BodyHandlers.ofInputStream());
624 statusCode = response.statusCode();
625 if (statusCode != 200) {
626 String errorBody =
627 new String(response.body().readAllBytes(), StandardCharsets.UTF_8);
628 out = errorBody;
629 throw new OllamaException(statusCode + " - " + errorBody);
630 }
631 try (BufferedReader reader =
632 new BufferedReader(
633 new InputStreamReader(response.body(), StandardCharsets.UTF_8))) {
634 String line;
635 StringBuilder lines = new StringBuilder();
636 while ((line = reader.readLine()) != null) {
638 Utils.getObjectMapper().readValue(line, ModelPullResponse.class);
639 lines.append(line);
640 LOG.debug(res.getStatus());
641 if (res.getError() != null) {
642 out = res.getError();
643 throw new OllamaException(res.getError());
644 }
645 }
646 out = lines;
647 }
648 } catch (InterruptedException e) {
649 Thread.currentThread().interrupt();
650 throw new OllamaException("Thread was interrupted", e);
651 } catch (Exception e) {
652 throw new OllamaException(e.getMessage(), e);
653 } finally {
655 url,
656 "",
657 false,
659 false,
660 null,
661 null,
662 startTime,
663 statusCode,
664 out);
665 }
666 }
667
675 public void deleteModel(String modelName, boolean ignoreIfNotPresent) throws OllamaException {
676 long startTime = System.currentTimeMillis();
677 String url = "/api/delete";
678 int statusCode = -1;
679 Object out = null;
680 try {
681 String jsonData = new ModelRequest(modelName).toString();
682 HttpRequest request =
683 getRequestBuilderDefault(new URI(this.host + url))
684 .method(
685 "DELETE",
686 HttpRequest.BodyPublishers.ofString(
687 jsonData, StandardCharsets.UTF_8))
688 .header(
689 Constants.HttpConstants.HEADER_KEY_ACCEPT,
690 Constants.HttpConstants.APPLICATION_JSON)
691 .header(
692 Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
693 Constants.HttpConstants.APPLICATION_JSON)
694 .build();
695 HttpClient client = HttpClient.newHttpClient();
696 HttpResponse<String> response =
697 client.send(request, HttpResponse.BodyHandlers.ofString());
698 statusCode = response.statusCode();
699 String responseBody = response.body();
700 out = responseBody;
701 if (statusCode == 404
702 && responseBody.contains("model")
703 && responseBody.contains("not found")) {
704 return;
705 }
706 if (statusCode != 200) {
707 throw new OllamaException(statusCode + " - " + responseBody);
708 }
709 } catch (InterruptedException e) {
710 Thread.currentThread().interrupt();
711 throw new OllamaException("Thread was interrupted", e);
712 } catch (Exception e) {
713 throw new OllamaException(statusCode + " - " + out, e);
714 } finally {
716 url,
717 "",
718 false,
720 false,
721 null,
722 null,
723 startTime,
724 statusCode,
725 out);
726 }
727 }
728
738 public void unloadModel(String modelName) throws OllamaException {
739 long startTime = System.currentTimeMillis();
740 String url = "/api/generate";
741 int statusCode = -1;
742 Object out = null;
743 try {
744 ObjectMapper objectMapper = new ObjectMapper();
745 Map<String, Object> jsonMap = new java.util.HashMap<>();
746 jsonMap.put("model", modelName);
747 jsonMap.put("keep_alive", 0);
748 String jsonData = objectMapper.writeValueAsString(jsonMap);
749 HttpRequest request =
750 getRequestBuilderDefault(new URI(this.host + url))
751 .method(
752 "POST",
753 HttpRequest.BodyPublishers.ofString(
754 jsonData, StandardCharsets.UTF_8))
755 .header(
756 Constants.HttpConstants.HEADER_KEY_ACCEPT,
757 Constants.HttpConstants.APPLICATION_JSON)
758 .header(
759 Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
760 Constants.HttpConstants.APPLICATION_JSON)
761 .build();
762 LOG.debug("Unloading model with request: {}", jsonData);
763 HttpClient client = HttpClient.newHttpClient();
764 HttpResponse<String> response =
765 client.send(request, HttpResponse.BodyHandlers.ofString());
766 statusCode = response.statusCode();
767 String responseBody = response.body();
768 if (statusCode == 404
769 && responseBody.contains("model")
770 && responseBody.contains("not found")) {
771 LOG.debug("Unload response: {} - {}", statusCode, responseBody);
772 return;
773 }
774 if (statusCode != 200) {
775 LOG.debug("Unload response: {} - {}", statusCode, responseBody);
776 throw new OllamaException(statusCode + " - " + responseBody);
777 }
778 } catch (InterruptedException e) {
779 Thread.currentThread().interrupt();
780 LOG.debug("Unload interrupted: {} - {}", statusCode, out);
781 throw new OllamaException(statusCode + " - " + out, e);
782 } catch (Exception e) {
783 LOG.debug("Unload failed: {} - {}", statusCode, out);
784 throw new OllamaException(statusCode + " - " + out, e);
785 } finally {
787 url,
788 "",
789 false,
791 false,
792 null,
793 null,
794 startTime,
795 statusCode,
796 out);
797 }
798 }
799
808 long startTime = System.currentTimeMillis();
809 String url = "/api/embed";
810 int statusCode = -1;
811 Object out = null;
812 try {
813 String jsonData = Utils.getObjectMapper().writeValueAsString(modelRequest);
814 HttpClient httpClient = HttpClient.newHttpClient();
815 HttpRequest request =
816 HttpRequest.newBuilder(new URI(this.host + url))
817 .header(
818 Constants.HttpConstants.HEADER_KEY_ACCEPT,
819 Constants.HttpConstants.APPLICATION_JSON)
820 .POST(HttpRequest.BodyPublishers.ofString(jsonData))
821 .build();
822 HttpResponse<String> response =
823 httpClient.send(request, HttpResponse.BodyHandlers.ofString());
824 statusCode = response.statusCode();
825 String responseBody = response.body();
826 if (statusCode == 200) {
827 return Utils.getObjectMapper().readValue(responseBody, OllamaEmbedResult.class);
828 } else {
829 throw new OllamaException(statusCode + " - " + responseBody);
830 }
831 } catch (InterruptedException e) {
832 Thread.currentThread().interrupt();
833 throw new OllamaException("Thread was interrupted", e);
834 } catch (Exception e) {
835 throw new OllamaException(e.getMessage(), e);
836 } finally {
838 url,
839 "",
840 false,
842 false,
843 null,
844 null,
845 startTime,
846 statusCode,
847 out);
848 }
849 }
850
862 throws OllamaException {
863 try {
864 if (request.isUseTools()) {
865 return generateWithToolsInternal(request, streamObserver);
866 }
867
868 if (streamObserver != null) {
869 if (!request.getThink().equals(ThinkMode.DISABLED)) {
870 return generateSyncForOllamaRequestModel(
871 request,
872 streamObserver.getThinkingStreamHandler(),
873 streamObserver.getResponseStreamHandler());
874 } else {
875 return generateSyncForOllamaRequestModel(
876 request, null, streamObserver.getResponseStreamHandler());
877 }
878 }
879 return generateSyncForOllamaRequestModel(request, null, null);
880 } catch (Exception e) {
881 throw new OllamaException(e.getMessage(), e);
882 }
883 }
884
885 // (No javadoc for private helper, as is standard)
886 private OllamaResult generateWithToolsInternal(
888 throws OllamaException {
889 ArrayList<OllamaChatMessage> msgs = new ArrayList<>();
890 OllamaChatRequest chatRequest = new OllamaChatRequest();
891 chatRequest.setModel(request.getModel());
893 ocm.setRole(OllamaChatMessageRole.USER);
894 ocm.setResponse(request.getPrompt());
895 chatRequest.setMessages(msgs);
896 msgs.add(ocm);
897
898 // Merge request's tools and globally registered tools into a new list to avoid mutating the
899 // original request
900 List<Tools.Tool> allTools = new ArrayList<>();
901 if (request.getTools() != null) {
902 allTools.addAll(request.getTools());
903 }
904 List<Tools.Tool> registeredTools = this.getRegisteredTools();
905 if (registeredTools != null) {
906 allTools.addAll(registeredTools);
907 }
908
909 OllamaChatTokenHandler hdlr = null;
910 chatRequest.setUseTools(true);
911 chatRequest.setTools(allTools);
912 if (streamObserver != null) {
913 chatRequest.setStream(true);
914 if (streamObserver.getResponseStreamHandler() != null) {
915 hdlr =
916 chatResponseModel ->
917 streamObserver
918 .getResponseStreamHandler()
919 .accept(chatResponseModel.getMessage().getResponse());
920 }
921 }
922 OllamaChatResult res = chat(chatRequest, hdlr);
923 return new OllamaResult(
924 res.getResponseModel().getMessage().getResponse(),
925 res.getResponseModel().getMessage().getThinking(),
926 res.getResponseModel().getTotalDuration(),
927 -1);
928 }
929
941 String model, String prompt, boolean raw, ThinkMode think) throws OllamaException {
942 long startTime = System.currentTimeMillis();
943 String url = "/api/generate";
944 int statusCode = -1;
945 try {
946 OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
947 ollamaRequestModel.setRaw(raw);
948 ollamaRequestModel.setThink(think);
949 OllamaAsyncResultStreamer ollamaAsyncResultStreamer =
951 getRequestBuilderDefault(new URI(this.host + url)),
952 ollamaRequestModel,
953 requestTimeoutSeconds);
954 ollamaAsyncResultStreamer.start();
955 statusCode = ollamaAsyncResultStreamer.getHttpStatusCode();
956 return ollamaAsyncResultStreamer;
957 } catch (Exception e) {
958 throw new OllamaException(e.getMessage(), e);
959 } finally {
961 url, model, raw, think, true, null, null, startTime, statusCode, null);
962 }
963 }
964
978 throws OllamaException {
979 try {
980 OllamaChatEndpointCaller requestCaller =
981 new OllamaChatEndpointCaller(host, auth, requestTimeoutSeconds);
982 OllamaChatResult result;
983
984 // only add tools if tools flag is set
985 if (request.isUseTools()) {
986 // add all registered tools to request
987 request.getTools().addAll(toolRegistry.getRegisteredTools());
988 }
989
990 if (tokenHandler != null) {
991 request.setStream(true);
992 result = requestCaller.call(request, tokenHandler);
993 } else {
994 result = requestCaller.callSync(request);
995 }
996
997 // check if toolCallIsWanted
998 List<OllamaChatToolCalls> toolCalls =
999 result.getResponseModel().getMessage().getToolCalls();
1000 int toolCallTries = 0;
1001 while (toolCalls != null
1002 && !toolCalls.isEmpty()
1003 && toolCallTries < maxChatToolCallRetries) {
1004 for (OllamaChatToolCalls toolCall : toolCalls) {
1005 String toolName = toolCall.getFunction().getName();
1006 for (Tools.Tool t : request.getTools()) {
1007 if (t.getToolSpec().getName().equals(toolName)) {
1008 ToolFunction toolFunction = t.getToolFunction();
1009 if (toolFunction == null) {
1010 throw new ToolInvocationException(
1011 "Tool function not found: " + toolName);
1012 }
1013 LOG.debug(
1014 "Invoking tool {} with arguments: {}",
1015 toolCall.getFunction().getName(),
1016 toolCall.getFunction().getArguments());
1017 Map<String, Object> arguments = toolCall.getFunction().getArguments();
1018 Object res = toolFunction.apply(arguments);
1019 String argumentKeys =
1020 arguments.keySet().stream()
1021 .map(Object::toString)
1022 .collect(Collectors.joining(", "));
1023 request.getMessages()
1024 .add(
1027 "[TOOL_RESULTS] "
1028 + toolName
1029 + "("
1030 + argumentKeys
1031 + "): "
1032 + res
1033 + " [/TOOL_RESULTS]"));
1034 }
1035 }
1036 }
1037 if (tokenHandler != null) {
1038 result = requestCaller.call(request, tokenHandler);
1039 } else {
1040 result = requestCaller.callSync(request);
1041 }
1042 toolCalls = result.getResponseModel().getMessage().getToolCalls();
1043 toolCallTries++;
1044 }
1045 return result;
1046 } catch (InterruptedException e) {
1047 Thread.currentThread().interrupt();
1048 throw new OllamaException("Thread was interrupted", e);
1049 } catch (Exception e) {
1050 throw new OllamaException(e.getMessage(), e);
1051 }
1052 }
1053
1059 public void registerTool(Tools.Tool tool) {
1060 toolRegistry.addTool(tool);
1061 LOG.debug("Registered tool: {}", tool.getToolSpec().getName());
1062 }
1063
1070 public void registerTools(List<Tools.Tool> tools) {
1071 toolRegistry.addTools(tools);
1072 }
1073
1075 return toolRegistry.getRegisteredTools();
1076 }
1077
1082 public void deregisterTools() {
1083 toolRegistry.clear();
1084 LOG.debug("All tools have been deregistered.");
1085 }
1086
1096 try {
1097 Class<?> callerClass = null;
1098 try {
1099 callerClass =
1100 Class.forName(Thread.currentThread().getStackTrace()[2].getClassName());
1101 } catch (ClassNotFoundException e) {
1102 throw new OllamaException(e.getMessage(), e);
1103 }
1104
1105 OllamaToolService ollamaToolServiceAnnotation =
1106 callerClass.getDeclaredAnnotation(OllamaToolService.class);
1107 if (ollamaToolServiceAnnotation == null) {
1108 throw new IllegalStateException(
1109 callerClass + " is not annotated as " + OllamaToolService.class);
1110 }
1111
1112 Class<?>[] providers = ollamaToolServiceAnnotation.providers();
1113 for (Class<?> provider : providers) {
1114 registerAnnotatedTools(provider.getDeclaredConstructor().newInstance());
1115 }
1116 } catch (InstantiationException
1117 | NoSuchMethodException
1118 | IllegalAccessException
1119 | InvocationTargetException e) {
1120 throw new OllamaException(e.getMessage());
1121 }
1122 }
1123
1133 public void registerAnnotatedTools(Object object) {
1134 Class<?> objectClass = object.getClass();
1135 Method[] methods = objectClass.getMethods();
1136 for (Method m : methods) {
1137 ToolSpec toolSpec = m.getDeclaredAnnotation(ToolSpec.class);
1138 if (toolSpec == null) {
1139 continue;
1140 }
1141 String operationName = !toolSpec.name().isBlank() ? toolSpec.name() : m.getName();
1142 String operationDesc = !toolSpec.desc().isBlank() ? toolSpec.desc() : operationName;
1143
1144 final Map<String, Tools.Property> params = new HashMap<String, Tools.Property>() {};
1145 LinkedHashMap<String, String> methodParams = new LinkedHashMap<>();
1146 for (Parameter parameter : m.getParameters()) {
1147 final ToolProperty toolPropertyAnn =
1148 parameter.getDeclaredAnnotation(ToolProperty.class);
1149 String propType = parameter.getType().getTypeName();
1150 if (toolPropertyAnn == null) {
1151 methodParams.put(parameter.getName(), null);
1152 continue;
1153 }
1154 String propName =
1155 !toolPropertyAnn.name().isBlank()
1156 ? toolPropertyAnn.name()
1157 : parameter.getName();
1158 methodParams.put(propName, propType);
1159 params.put(
1160 propName,
1161 Tools.Property.builder()
1162 .type(propType)
1163 .description(toolPropertyAnn.desc())
1164 .required(toolPropertyAnn.required())
1165 .build());
1166 }
1167 Tools.ToolSpec toolSpecification =
1168 Tools.ToolSpec.builder()
1169 .name(operationName)
1170 .description(operationDesc)
1171 .parameters(Tools.Parameters.of(params))
1172 .build();
1173 ReflectionalToolFunction reflectionalToolFunction =
1174 new ReflectionalToolFunction(object, m, methodParams);
1175 toolRegistry.addTool(
1176 Tools.Tool.builder()
1177 .toolFunction(reflectionalToolFunction)
1178 .toolSpec(toolSpecification)
1179 .build());
1180 }
1181 }
1182
1189 public OllamaChatMessageRole addCustomRole(String roleName) {
1190 return OllamaChatMessageRole.newCustomRole(roleName);
1191 }
1192
1198 public List<OllamaChatMessageRole> listRoles() {
1200 }
1201
1209 public OllamaChatMessageRole getRole(String roleName) throws RoleNotFoundException {
1210 return OllamaChatMessageRole.getRole(roleName);
1211 }
1212
1213 // technical private methods //
1214
1222 private static String encodeFileToBase64(File file) throws IOException {
1223 return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath()));
1224 }
1225
1232 private static String encodeByteArrayToBase64(byte[] bytes) {
1233 return Base64.getEncoder().encodeToString(bytes);
1234 }
1235
1247 private OllamaResult generateSyncForOllamaRequestModel(
1248 OllamaGenerateRequest ollamaRequestModel,
1249 OllamaGenerateTokenHandler thinkingStreamHandler,
1250 OllamaGenerateTokenHandler responseStreamHandler)
1251 throws OllamaException {
1252 long startTime = System.currentTimeMillis();
1253 int statusCode = -1;
1254 Object out = null;
1255 try {
1256 OllamaGenerateEndpointCaller requestCaller =
1257 new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds);
1258 OllamaResult result;
1259 if (responseStreamHandler != null) {
1260 ollamaRequestModel.setStream(true);
1261 result =
1262 requestCaller.call(
1263 ollamaRequestModel, thinkingStreamHandler, responseStreamHandler);
1264 } else {
1265 result = requestCaller.callSync(ollamaRequestModel);
1266 }
1267 statusCode = result.getHttpStatusCode();
1268 out = result;
1269 return result;
1270 } catch (InterruptedException e) {
1271 Thread.currentThread().interrupt();
1272 throw new OllamaException("Thread was interrupted", e);
1273 } catch (Exception e) {
1274 throw new OllamaException(e.getMessage(), e);
1275 } finally {
1278 ollamaRequestModel.getModel(),
1279 ollamaRequestModel.isRaw(),
1280 ollamaRequestModel.getThink(),
1281 ollamaRequestModel.isStream(),
1282 ollamaRequestModel.getOptions(),
1283 ollamaRequestModel.getFormat(),
1284 startTime,
1285 statusCode,
1286 out);
1287 }
1288 }
1289
1296 private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
1297 HttpRequest.Builder requestBuilder =
1298 HttpRequest.newBuilder(uri)
1299 .header(
1300 Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
1301 Constants.HttpConstants.APPLICATION_JSON)
1302 .timeout(Duration.ofSeconds(requestTimeoutSeconds));
1303 if (isAuthSet()) {
1304 requestBuilder.header("Authorization", auth.getAuthHeaderValue());
1305 }
1306 return requestBuilder;
1307 }
1308
1314 private boolean isAuthSet() {
1315 return auth != null;
1316 }
1317}
void setBasicAuth(String username, String password)
Definition Ollama.java:138
OllamaChatMessageRole getRole(String roleName)
Definition Ollama.java:1209
OllamaEmbedResult embed(OllamaEmbedRequest modelRequest)
Definition Ollama.java:807
OllamaAsyncResultStreamer generateAsync(String model, String prompt, boolean raw, ThinkMode think)
Definition Ollama.java:940
List< Tools.Tool > getRegisteredTools()
Definition Ollama.java:1074
OllamaResult generate(OllamaGenerateRequest request, OllamaGenerateStreamObserver streamObserver)
Definition Ollama.java:860
void unloadModel(String modelName)
Definition Ollama.java:738
void registerTool(Tools.Tool tool)
Definition Ollama.java:1059
void pullModel(String modelName)
Definition Ollama.java:506
OllamaChatResult chat(OllamaChatRequest request, OllamaChatTokenHandler tokenHandler)
Definition Ollama.java:977
List< Model > listModels()
Definition Ollama.java:263
List< OllamaChatMessageRole > listRoles()
Definition Ollama.java:1198
OllamaChatMessageRole addCustomRole(String roleName)
Definition Ollama.java:1189
void setBearerAuth(String bearerToken)
Definition Ollama.java:147
ModelProcessesResult ps()
Definition Ollama.java:205
void createModel(CustomModelRequest customModelRequest)
Definition Ollama.java:602
void registerAnnotatedTools(Object object)
Definition Ollama.java:1133
void registerTools(List< Tools.Tool > tools)
Definition Ollama.java:1070
void deleteModel(String modelName, boolean ignoreIfNotPresent)
Definition Ollama.java:675
ModelDetail getModelDetails(String modelName)
Definition Ollama.java:548
static void record(String endpoint, String model, boolean raw, ThinkMode thinkMode, boolean streaming, Map< String, Object > options, Object format, long startTime, int responseHttpStatus, Object response)
static OllamaChatMessageRole newCustomRole(String roleName)
static OllamaChatMessageRole getRole(String roleName)
OllamaChatResult call(OllamaChatRequest body, OllamaChatTokenHandler tokenHandler)
OllamaResult call(OllamaRequestBody body, OllamaGenerateTokenHandler thinkingStreamHandler, OllamaGenerateTokenHandler responseStreamHandler)
void addTools(List< Tools.Tool > tools)
static ObjectMapper getObjectMapper()
Definition Utils.java:32
Object apply(Map< String, Object > arguments)