Ollama4j
A Java library (wrapper/binding) for Ollama server.
Loading...
Searching...
No Matches
Ollama.java
Go to the documentation of this file.
1/*
2 * Ollama4j - Java library for interacting with Ollama server.
3 * Copyright (c) 2025 Amith Koujalgi and contributors.
4 *
5 * Licensed under the MIT License (the "License");
6 * you may not use this file except in compliance with the License.
7 *
8*/
9package io.github.ollama4j;
10
11import com.fasterxml.jackson.annotation.JsonProperty;
12import com.fasterxml.jackson.databind.ObjectMapper;
13import io.github.ollama4j.exceptions.OllamaException;
14import io.github.ollama4j.exceptions.RoleNotFoundException;
15import io.github.ollama4j.exceptions.ToolInvocationException;
16import io.github.ollama4j.metrics.MetricsRecorder;
17import io.github.ollama4j.models.chat.*;
18import io.github.ollama4j.models.embed.OllamaEmbedRequest;
19import io.github.ollama4j.models.embed.OllamaEmbedResult;
20import io.github.ollama4j.models.generate.OllamaGenerateRequest;
21import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
22import io.github.ollama4j.models.generate.OllamaGenerateTokenHandler;
23import io.github.ollama4j.models.ps.ModelProcessesResult;
24import io.github.ollama4j.models.request.*;
25import io.github.ollama4j.models.response.*;
26import io.github.ollama4j.tools.*;
27import io.github.ollama4j.tools.annotations.OllamaToolService;
28import io.github.ollama4j.tools.annotations.ToolProperty;
29import io.github.ollama4j.tools.annotations.ToolSpec;
30import io.github.ollama4j.utils.Constants;
31import io.github.ollama4j.utils.Utils;
32import io.modelcontextprotocol.client.McpClient;
33import io.modelcontextprotocol.client.McpSyncClient;
34import io.modelcontextprotocol.client.transport.ServerParameters;
35import io.modelcontextprotocol.client.transport.StdioClientTransport;
36import io.modelcontextprotocol.json.McpJsonMapper;
37import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
38import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
39import io.modelcontextprotocol.spec.McpSchema.ListToolsResult;
40import java.io.*;
41import java.lang.reflect.InvocationTargetException;
42import java.lang.reflect.Method;
43import java.lang.reflect.Parameter;
44import java.net.URI;
45import java.net.URISyntaxException;
46import java.net.http.HttpClient;
47import java.net.http.HttpRequest;
48import java.net.http.HttpResponse;
49import java.nio.charset.StandardCharsets;
50import java.nio.file.Files;
51import java.time.Duration;
52import java.util.*;
53import java.util.stream.Collectors;
54import lombok.AllArgsConstructor;
55import lombok.Data;
56import lombok.NoArgsConstructor;
57import lombok.Setter;
58import org.slf4j.Logger;
59import org.slf4j.LoggerFactory;
60
67@SuppressWarnings({"DuplicatedCode", "resource", "SpellCheckingInspection"})
68public class Ollama {
69
// Class-wide SLF4J logger.
70 private static final Logger LOG = LoggerFactory.getLogger(Ollama.class);
71
// Base URL of the Ollama server; endpoint paths such as "/api/tags" are appended to it.
72 private final String host;
// Authentication settings; assigned by setBasicAuth / setBearerAuth, otherwise left null.
73 private Auth auth;
74
// Holds the tools registered on this client (registerTool, registerTools, registerAnnotatedTools, MCP loading).
75 private final ToolRegistry toolRegistry = new ToolRegistry();
76
// Timeout in seconds handed to the chat endpoint caller and used as the MCP client request timeout in callMCPTool.
83 @Setter private long requestTimeoutSeconds = 10;
84
// NOTE(review): not read anywhere in the visible part of this listing; presumably the read timeout
// for fetching images by URL — confirm against the rest of the file.
86 @Setter private int imageURLReadTimeoutSeconds = 10;
87
// NOTE(review): not read anywhere in the visible part of this listing; presumably the connect timeout
// for fetching images by URL — confirm against the rest of the file.
89 @Setter private int imageURLConnectTimeoutSeconds = 10;
90
// Upper bound on how many tool-call rounds chat(...) performs before returning the last result.
97 @Setter private int maxChatToolCallRetries = 3;
98
// Number of pull attempts made by pullModel; 0 means a single attempt with no retry/backoff.
107 @Setter
108 @SuppressWarnings({"FieldMayBeFinal", "FieldCanBeLocal"})
109 private int numberOfRetriesForModelPull = 0;
110
// NOTE(review): not read in the visible part of this listing; presumably toggles metrics recording — confirm.
117 @Setter private boolean metricsEnabled = false;
118
/**
 * Creates a client that talks to the default local endpoint,
 * {@code http://localhost:11434}.
 */
public Ollama() {
    this.host = "http://localhost:11434";
}
123
129 public Ollama(String host) {
130 if (host.endsWith("/")) {
131 this.host = host.substring(0, host.length() - 1);
132 } else {
133 this.host = host;
134 }
135 LOG.info("Ollama4j client initialized. Connected to Ollama server at: {}", this.host);
136 }
137
144 public void setBasicAuth(String username, String password) {
145 this.auth = new BasicAuth(username, password);
146 }
147
153 public void setBearerAuth(String bearerToken) {
154 this.auth = new BearerAuth(bearerToken);
155 }
156
/**
 * Checks server reachability by sending {@code GET /api/tags}.
 *
 * Returns {@code true} only when the HTTP status is 200. An interruption restores the
 * thread's interrupt flag and is rethrown as {@code OllamaException("Ping interrupted")};
 * any other failure is rethrown as {@code OllamaException("Ping failed")}.
 *
 * NOTE(review): the listing's own numbering skips 191 and 195 inside the {@code finally}
 * block, so two source lines are missing from this extraction (presumably the start of a
 * metrics-recording call and one of its arguments) — confirm against the repository.
 */
163 public boolean ping() throws OllamaException {
164 long startTime = System.currentTimeMillis();
165 String url = "/api/tags";
// -1 marks "no HTTP response received" for the trailing call in the finally block.
166 int statusCode = -1;
// Never assigned in this method, so the trailing call always receives null for it.
167 Object out = null;
168 try {
169 HttpClient httpClient = HttpClient.newHttpClient();
170 HttpRequest httpRequest;
171 HttpResponse<String> response;
172 httpRequest =
173 getRequestBuilderDefault(new URI(this.host + url))
174 .header(
175 Constants.HttpConstants.HEADER_KEY_ACCEPT,
176 Constants.HttpConstants.APPLICATION_JSON)
177 .header(
178 Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
179 Constants.HttpConstants.APPLICATION_JSON)
180 .GET()
181 .build();
182 response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
183 statusCode = response.statusCode();
184 return statusCode == 200;
185 } catch (InterruptedException ie) {
186 Thread.currentThread().interrupt();
187 throw new OllamaException("Ping interrupted", ie);
188 } catch (Exception e) {
189 throw new OllamaException("Ping failed", e);
190 } finally {
192 url,
193 "",
194 false,
196 false,
197 null,
198 null,
199 startTime,
200 statusCode,
201 out);
202 }
203 }
204
212 long startTime = System.currentTimeMillis();
213 String url = "/api/ps";
214 int statusCode = -1;
215 Object out = null;
216 try {
217 HttpClient httpClient = HttpClient.newHttpClient();
218 HttpRequest httpRequest = null;
219 try {
220 httpRequest =
221 getRequestBuilderDefault(new URI(this.host + url))
222 .header(
223 Constants.HttpConstants.HEADER_KEY_ACCEPT,
224 Constants.HttpConstants.APPLICATION_JSON)
225 .header(
226 Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
227 Constants.HttpConstants.APPLICATION_JSON)
228 .GET()
229 .build();
230 } catch (URISyntaxException e) {
231 throw new OllamaException(e.getMessage(), e);
232 }
233 HttpResponse<String> response = null;
234 response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
235 statusCode = response.statusCode();
236 String responseString = response.body();
237 if (statusCode == 200) {
238 return Utils.getObjectMapper()
239 .readValue(responseString, ModelProcessesResult.class);
240 } else {
241 throw new OllamaException(statusCode + " - " + responseString);
242 }
243 } catch (InterruptedException ie) {
244 Thread.currentThread().interrupt();
245 throw new OllamaException("ps interrupted", ie);
246 } catch (Exception e) {
247 throw new OllamaException("ps failed", e);
248 } finally {
250 url,
251 "",
252 false,
254 false,
255 null,
256 null,
257 startTime,
258 statusCode,
259 out);
260 }
261 }
262
/**
 * Lists the models reported by the server via {@code GET /api/tags}.
 *
 * On HTTP 200 the body is parsed as {@code ListModelsResponse} and its models are returned;
 * any other status raises an {@code OllamaException} with "status - body". Every failure,
 * including that one, passes through the catch-all below and is re-wrapped with the same message.
 *
 * NOTE(review): the listing's own numbering skips 303 and 307 inside the {@code finally}
 * block, so two source lines are missing from this extraction — confirm against the repository.
 */
269 public List<Model> listModels() throws OllamaException {
270 long startTime = System.currentTimeMillis();
271 String url = "/api/tags";
// -1 marks "no HTTP response received".
272 int statusCode = -1;
// Never assigned in this method.
273 Object out = null;
274 try {
275 HttpClient httpClient = HttpClient.newHttpClient();
276 HttpRequest httpRequest =
277 getRequestBuilderDefault(new URI(this.host + url))
278 .header(
279 Constants.HttpConstants.HEADER_KEY_ACCEPT,
280 Constants.HttpConstants.APPLICATION_JSON)
281 .header(
282 Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
283 Constants.HttpConstants.APPLICATION_JSON)
284 .GET()
285 .build();
286 HttpResponse<String> response =
287 httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
288 statusCode = response.statusCode();
289 String responseString = response.body();
290 if (statusCode == 200) {
291 return Utils.getObjectMapper()
292 .readValue(responseString, ListModelsResponse.class)
293 .getModels();
294 } else {
295 throw new OllamaException(statusCode + " - " + responseString);
296 }
297 } catch (InterruptedException ie) {
// Restore the interrupt flag before translating to the library exception.
298 Thread.currentThread().interrupt();
299 throw new OllamaException("listModels interrupted", ie);
300 } catch (Exception e) {
301 throw new OllamaException(e.getMessage(), e);
302 } finally {
304 url,
305 "",
306 false,
308 false,
309 null,
310 null,
311 startTime,
312 statusCode,
313 out);
314 }
315 }
316
326 private void handlePullRetry(
327 String modelName, int currentRetry, int maxRetries, long baseDelayMillis)
328 throws InterruptedException {
329 int attempt = currentRetry + 1;
330 if (attempt < maxRetries) {
331 long backoffMillis = baseDelayMillis * (1L << currentRetry);
332 LOG.error(
333 "Failed to pull model {}, retrying in {}s... (attempt {}/{})",
334 modelName,
335 backoffMillis / 1000,
336 attempt,
337 maxRetries);
338 try {
339 Thread.sleep(backoffMillis);
340 } catch (InterruptedException ie) {
341 Thread.currentThread().interrupt();
342 throw ie;
343 }
344 } else {
345 LOG.error(
346 "Failed to pull model {} after {} attempts, no more retries.",
347 modelName,
348 maxRetries);
349 }
350 }
351
/**
 * Performs one pull attempt via {@code POST /api/pull} and consumes the streamed response.
 *
 * Each response line is parsed as a {@code ModelPullResponse} and passed to
 * {@code processModelPullResponse}; the pull counts as successful once any line reports success.
 * Fails with {@code OllamaException} when no line reported success, when the HTTP status is
 * not 200, on interruption (interrupt flag restored), or on any other error.
 *
 * NOTE(review): {@code responseString} is initialised to "" and never reassigned, so the
 * "status - body" message for a non-200 status always has an empty body part.
 * NOTE(review): the listing's own numbering skips 406 and 410 inside the {@code finally}
 * block, so two source lines are missing from this extraction — confirm against the repository.
 */
358 private void doPullModel(String modelName) throws OllamaException {
359 long startTime = System.currentTimeMillis();
360 String url = "/api/pull";
// -1 marks "no HTTP response received".
361 int statusCode = -1;
// Never assigned in this method.
362 Object out = null;
363 try {
364 String jsonData = new ModelRequest(modelName).toString();
365 HttpRequest request =
366 getRequestBuilderDefault(new URI(this.host + url))
367 .POST(HttpRequest.BodyPublishers.ofString(jsonData))
368 .header(
369 Constants.HttpConstants.HEADER_KEY_ACCEPT,
370 Constants.HttpConstants.APPLICATION_JSON)
371 .header(
372 Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
373 Constants.HttpConstants.APPLICATION_JSON)
374 .build();
375 HttpClient client = HttpClient.newHttpClient();
376 HttpResponse<InputStream> response =
377 client.send(request, HttpResponse.BodyHandlers.ofInputStream());
378 statusCode = response.statusCode();
379 InputStream responseBodyStream = response.body();
380 String responseString = "";
381 boolean success = false; // Flag to check the pull success.
382
// The body is read line by line; each line is parsed as one JSON status object.
383 try (BufferedReader reader =
384 new BufferedReader(
385 new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
386 String line;
387 while ((line = reader.readLine()) != null) {
388 ModelPullResponse modelPullResponse =
389 Utils.getObjectMapper().readValue(line, ModelPullResponse.class);
// The call is placed on the left of || so that every line is processed even after success is already true.
390 success = processModelPullResponse(modelPullResponse, modelName) || success;
391 }
392 }
393 if (!success) {
394 LOG.error("Model pull failed or returned invalid status.");
395 throw new OllamaException("Model pull failed or returned invalid status.");
396 }
397 if (statusCode != 200) {
398 throw new OllamaException(statusCode + " - " + responseString);
399 }
400 } catch (InterruptedException ie) {
401 Thread.currentThread().interrupt();
402 throw new OllamaException("Thread was interrupted during model pull.", ie);
403 } catch (Exception e) {
// Also re-wraps the OllamaExceptions thrown above, keeping their message.
404 throw new OllamaException(e.getMessage(), e);
405 } finally {
407 url,
408 "",
409 false,
411 false,
412 null,
413 null,
414 startTime,
415 statusCode,
416 out);
417 }
418 }
419
429 @SuppressWarnings("RedundantIfStatement")
430 private boolean processModelPullResponse(ModelPullResponse modelPullResponse, String modelName)
431 throws OllamaException {
432 if (modelPullResponse == null) {
433 LOG.error("Received null response for model pull.");
434 return false;
435 }
436 String error = modelPullResponse.getError();
437 if (error != null && !error.trim().isEmpty()) {
438 throw new OllamaException("Model pull failed: " + error);
439 }
440 String status = modelPullResponse.getStatus();
441 if (status != null) {
442 LOG.debug("{}: {}", modelName, status);
443 if ("success".equalsIgnoreCase(status)) {
444 return true;
445 }
446 }
447 return false;
448 }
449
/**
 * Fetches the server version via {@code GET /api/version}.
 *
 * On HTTP 200 the body is parsed as {@code OllamaVersion} and its version string is returned;
 * any other status raises an {@code OllamaException} with "status - body". Interruption
 * restores the interrupt flag; every other failure is re-wrapped with the original message.
 *
 * NOTE(review): the listing's own numbering skips 490 and 494 inside the {@code finally}
 * block, so two source lines are missing from this extraction — confirm against the repository.
 */
456 public String getVersion() throws OllamaException {
457 String url = "/api/version";
458 long startTime = System.currentTimeMillis();
// -1 marks "no HTTP response received".
459 int statusCode = -1;
// Never assigned in this method.
460 Object out = null;
461 try {
462 HttpClient httpClient = HttpClient.newHttpClient();
463 HttpRequest httpRequest =
464 getRequestBuilderDefault(new URI(this.host + url))
465 .header(
466 Constants.HttpConstants.HEADER_KEY_ACCEPT,
467 Constants.HttpConstants.APPLICATION_JSON)
468 .header(
469 Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
470 Constants.HttpConstants.APPLICATION_JSON)
471 .GET()
472 .build();
473 HttpResponse<String> response =
474 httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
475 statusCode = response.statusCode();
476 String responseString = response.body();
477 if (statusCode == 200) {
478 return Utils.getObjectMapper()
479 .readValue(responseString, OllamaVersion.class)
480 .getVersion();
481 } else {
482 throw new OllamaException(statusCode + " - " + responseString);
483 }
484 } catch (InterruptedException ie) {
485 Thread.currentThread().interrupt();
486 throw new OllamaException("Thread was interrupted", ie);
487 } catch (Exception e) {
488 throw new OllamaException(e.getMessage(), e);
489 } finally {
491 url,
492 "",
493 false,
495 false,
496 null,
497 null,
498 startTime,
499 statusCode,
500 out);
501 }
502 }
503
512 public void pullModel(String modelName) throws OllamaException {
513 try {
514 if (numberOfRetriesForModelPull == 0) {
515 this.doPullModel(modelName);
516 return;
517 }
518 int numberOfRetries = 0;
519 long baseDelayMillis = 3000L; // 3 seconds base delay
520 while (numberOfRetries < numberOfRetriesForModelPull) {
521 try {
522 this.doPullModel(modelName);
523 return;
524 } catch (OllamaException e) {
525 handlePullRetry(
526 modelName,
527 numberOfRetries,
528 numberOfRetriesForModelPull,
529 baseDelayMillis);
530 numberOfRetries++;
531 }
532 }
533 throw new OllamaException(
534 "Failed to pull model "
535 + modelName
536 + " after "
537 + numberOfRetriesForModelPull
538 + " retries");
539 } catch (InterruptedException ie) {
540 Thread.currentThread().interrupt();
541 throw new OllamaException("Thread was interrupted", ie);
542 } catch (Exception e) {
543 throw new OllamaException(e.getMessage(), e);
544 }
545 }
546
/**
 * Retrieves details of a model via {@code POST /api/show}.
 *
 * The request body is the string form of a {@code ModelRequest} for {@code modelName}.
 * On HTTP 200 the body is parsed as {@code ModelDetail}; any other status raises an
 * {@code OllamaException} with "status - body". Interruption restores the interrupt flag;
 * every other failure is re-wrapped with the original message.
 *
 * NOTE(review): the listing's own numbering skips 587 and 591 inside the {@code finally}
 * block, so two source lines are missing from this extraction — confirm against the repository.
 */
554 public ModelDetail getModelDetails(String modelName) throws OllamaException {
555 long startTime = System.currentTimeMillis();
556 String url = "/api/show";
// -1 marks "no HTTP response received".
557 int statusCode = -1;
// Never assigned in this method.
558 Object out = null;
559 try {
560 String jsonData = new ModelRequest(modelName).toString();
561 HttpRequest request =
562 getRequestBuilderDefault(new URI(this.host + url))
563 .header(
564 Constants.HttpConstants.HEADER_KEY_ACCEPT,
565 Constants.HttpConstants.APPLICATION_JSON)
566 .header(
567 Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
568 Constants.HttpConstants.APPLICATION_JSON)
569 .POST(HttpRequest.BodyPublishers.ofString(jsonData))
570 .build();
571 HttpClient client = HttpClient.newHttpClient();
572 HttpResponse<String> response =
573 client.send(request, HttpResponse.BodyHandlers.ofString());
574 statusCode = response.statusCode();
575 String responseBody = response.body();
576 if (statusCode == 200) {
577 return Utils.getObjectMapper().readValue(responseBody, ModelDetail.class);
578 } else {
579 throw new OllamaException(statusCode + " - " + responseBody);
580 }
581 } catch (InterruptedException ie) {
582 Thread.currentThread().interrupt();
583 throw new OllamaException("Thread was interrupted", ie);
584 } catch (Exception e) {
585 throw new OllamaException(e.getMessage(), e);
586 } finally {
588 url,
589 "",
590 false,
592 false,
593 null,
594 null,
595 startTime,
596 statusCode,
597 out);
598 }
599 }
600
/**
 * Creates a custom model via {@code POST /api/create} and consumes the streamed status lines.
 *
 * The request body is the string form of {@code customModelRequest}. A non-200 status reads
 * the whole body and raises an {@code OllamaException} with "status - body". Otherwise each
 * response line is parsed as a {@code ModelPullResponse}; the first line carrying an error
 * aborts with an {@code OllamaException} holding that error.
 *
 * NOTE(review): the listing's own numbering skips 643, so the declaration of {@code res}
 * (used below) is missing from this extraction; lines 660 and 664 inside the {@code finally}
 * block are missing as well — confirm against the repository.
 */
608 public void createModel(CustomModelRequest customModelRequest) throws OllamaException {
609 long startTime = System.currentTimeMillis();
610 String url = "/api/create";
// -1 marks "no HTTP response received".
611 int statusCode = -1;
// Receives the error body, the error text of a failing line, or the accumulated response lines.
612 Object out = null;
613 try {
614 String jsonData = customModelRequest.toString();
615 HttpRequest request =
616 getRequestBuilderDefault(new URI(this.host + url))
617 .header(
618 Constants.HttpConstants.HEADER_KEY_ACCEPT,
619 Constants.HttpConstants.APPLICATION_JSON)
620 .header(
621 Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
622 Constants.HttpConstants.APPLICATION_JSON)
623 .POST(
624 HttpRequest.BodyPublishers.ofString(
625 jsonData, StandardCharsets.UTF_8))
626 .build();
627 HttpClient client = HttpClient.newHttpClient();
628 HttpResponse<InputStream> response =
629 client.send(request, HttpResponse.BodyHandlers.ofInputStream());
630 statusCode = response.statusCode();
631 if (statusCode != 200) {
632 String errorBody =
633 new String(response.body().readAllBytes(), StandardCharsets.UTF_8);
634 out = errorBody;
635 throw new OllamaException(statusCode + " - " + errorBody);
636 }
637 try (BufferedReader reader =
638 new BufferedReader(
639 new InputStreamReader(response.body(), StandardCharsets.UTF_8))) {
640 String line;
641 StringBuilder lines = new StringBuilder();
642 while ((line = reader.readLine()) != null) {
644 Utils.getObjectMapper().readValue(line, ModelPullResponse.class);
// Lines are concatenated without a separator.
645 lines.append(line);
646 LOG.debug(res.getStatus());
647 if (res.getError() != null) {
648 out = res.getError();
649 throw new OllamaException(res.getError());
650 }
651 }
652 out = lines;
653 }
654 } catch (InterruptedException e) {
655 Thread.currentThread().interrupt();
656 throw new OllamaException("Thread was interrupted", e);
657 } catch (Exception e) {
// Also re-wraps the OllamaExceptions thrown above, keeping their message.
658 throw new OllamaException(e.getMessage(), e);
659 } finally {
661 url,
662 "",
663 false,
665 false,
666 null,
667 null,
668 startTime,
669 statusCode,
670 out);
671 }
672 }
673
682 public void deleteModel(String modelName, boolean ignoreIfNotPresent) throws OllamaException {
683 long startTime = System.currentTimeMillis();
684 String url = "/api/delete";
685 int statusCode = -1;
686 Object out = null;
687 try {
688 String jsonData = new ModelRequest(modelName).toString();
689 HttpRequest request =
690 getRequestBuilderDefault(new URI(this.host + url))
691 .method(
692 "DELETE",
693 HttpRequest.BodyPublishers.ofString(
694 jsonData, StandardCharsets.UTF_8))
695 .header(
696 Constants.HttpConstants.HEADER_KEY_ACCEPT,
697 Constants.HttpConstants.APPLICATION_JSON)
698 .header(
699 Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
700 Constants.HttpConstants.APPLICATION_JSON)
701 .build();
702 HttpClient client = HttpClient.newHttpClient();
703 HttpResponse<String> response =
704 client.send(request, HttpResponse.BodyHandlers.ofString());
705 statusCode = response.statusCode();
706 String responseBody = response.body();
707 out = responseBody;
708 if (statusCode == 404
709 && responseBody.contains("model")
710 && responseBody.contains("not found")) {
711 return;
712 }
713 if (statusCode != 200) {
714 throw new OllamaException(statusCode + " - " + responseBody);
715 }
716 } catch (InterruptedException e) {
717 Thread.currentThread().interrupt();
718 throw new OllamaException("Thread was interrupted", e);
719 } catch (Exception e) {
720 throw new OllamaException(statusCode + " - " + out, e);
721 } finally {
723 url,
724 "",
725 false,
727 false,
728 null,
729 null,
730 startTime,
731 statusCode,
732 out);
733 }
734 }
735
745 public void unloadModel(String modelName) throws OllamaException {
746 long startTime = System.currentTimeMillis();
747 String url = "/api/generate";
748 int statusCode = -1;
749 Object out = null;
750 try {
751 ObjectMapper objectMapper = new ObjectMapper();
752 Map<String, Object> jsonMap = new java.util.HashMap<>();
753 jsonMap.put("model", modelName);
754 jsonMap.put("keep_alive", 0);
755 String jsonData = objectMapper.writeValueAsString(jsonMap);
756 HttpRequest request =
757 getRequestBuilderDefault(new URI(this.host + url))
758 .method(
759 "POST",
760 HttpRequest.BodyPublishers.ofString(
761 jsonData, StandardCharsets.UTF_8))
762 .header(
763 Constants.HttpConstants.HEADER_KEY_ACCEPT,
764 Constants.HttpConstants.APPLICATION_JSON)
765 .header(
766 Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
767 Constants.HttpConstants.APPLICATION_JSON)
768 .build();
769 LOG.debug("Unloading model with request: {}", jsonData);
770 HttpClient client = HttpClient.newHttpClient();
771 HttpResponse<String> response =
772 client.send(request, HttpResponse.BodyHandlers.ofString());
773 statusCode = response.statusCode();
774 String responseBody = response.body();
775 if (statusCode == 404
776 && responseBody.contains("model")
777 && responseBody.contains("not found")) {
778 LOG.debug("Unload response: {} - {}", statusCode, responseBody);
779 return;
780 }
781 if (statusCode != 200) {
782 LOG.debug("Unload response: {} - {}", statusCode, responseBody);
783 throw new OllamaException(statusCode + " - " + responseBody);
784 }
785 } catch (InterruptedException e) {
786 Thread.currentThread().interrupt();
787 LOG.debug("Unload interrupted: {} - {}", statusCode, out);
788 throw new OllamaException(statusCode + " - " + out, e);
789 } catch (Exception e) {
790 LOG.debug("Unload failed: {} - {}", statusCode, out);
791 throw new OllamaException(statusCode + " - " + out, e);
792 } finally {
794 url,
795 "",
796 false,
798 false,
799 null,
800 null,
801 startTime,
802 statusCode,
803 out);
804 }
805 }
806
815 long startTime = System.currentTimeMillis();
816 String url = "/api/embed";
817 int statusCode = -1;
818 Object out = null;
819 try {
820 String jsonData = Utils.getObjectMapper().writeValueAsString(modelRequest);
821 HttpClient httpClient = HttpClient.newHttpClient();
822 HttpRequest request =
823 HttpRequest.newBuilder(new URI(this.host + url))
824 .header(
825 Constants.HttpConstants.HEADER_KEY_ACCEPT,
826 Constants.HttpConstants.APPLICATION_JSON)
827 .POST(HttpRequest.BodyPublishers.ofString(jsonData))
828 .build();
829 HttpResponse<String> response =
830 httpClient.send(request, HttpResponse.BodyHandlers.ofString());
831 statusCode = response.statusCode();
832 String responseBody = response.body();
833 if (statusCode == 200) {
834 return Utils.getObjectMapper().readValue(responseBody, OllamaEmbedResult.class);
835 } else {
836 throw new OllamaException(statusCode + " - " + responseBody);
837 }
838 } catch (InterruptedException e) {
839 Thread.currentThread().interrupt();
840 throw new OllamaException("Thread was interrupted", e);
841 } catch (Exception e) {
842 throw new OllamaException(e.getMessage(), e);
843 } finally {
845 url,
846 "",
847 false,
849 false,
850 null,
851 null,
852 startTime,
853 statusCode,
854 out);
855 }
856 }
857
869 throws OllamaException {
870 try {
871 if (request.isUseTools()) {
872 return generateWithToolsInternal(request, streamObserver);
873 }
874
875 if (streamObserver != null) {
876 if (!request.getThink().equals(ThinkMode.DISABLED)) {
877 return generateSyncForOllamaRequestModel(
878 request,
879 streamObserver.getThinkingStreamHandler(),
880 streamObserver.getResponseStreamHandler());
881 } else {
882 return generateSyncForOllamaRequestModel(
883 request, null, streamObserver.getResponseStreamHandler());
884 }
885 }
886 return generateSyncForOllamaRequestModel(request, null, null);
887 } catch (Exception e) {
888 throw new OllamaException(e.getMessage(), e);
889 }
890 }
891
892 // (No javadoc for private helper, as is standard)
// Tool-enabled generation: wraps the prompt of `request` in a single USER chat message,
// attaches the request's tools plus the globally registered tools, delegates to chat(...),
// and converts the chat result into an OllamaResult.
// NOTE(review): the listing's own numbering skips 894 (the parameter list; the body uses
// `request` and `streamObserver`) and 899 (the declaration of `ocm`), so both lines are
// missing from this extraction — confirm against the repository.
893 private OllamaResult generateWithToolsInternal(
895 throws OllamaException {
896 ArrayList<OllamaChatMessage> msgs = new ArrayList<>();
897 OllamaChatRequest chatRequest = new OllamaChatRequest();
898 chatRequest.setModel(request.getModel());
900 ocm.setRole(OllamaChatMessageRole.USER);
901 ocm.setResponse(request.getPrompt());
// The list is handed over before the message is added; this works only if the request
// keeps the same list reference rather than a copy.
902 chatRequest.setMessages(msgs);
903 msgs.add(ocm);
904
905 // Merge request's tools and globally registered tools into a new list to avoid
906 // mutating the
907 // original request
908 List<Tools.Tool> allTools = new ArrayList<>();
909 if (request.getTools() != null) {
910 allTools.addAll(request.getTools());
911 }
912 List<Tools.Tool> registeredTools = this.getRegisteredTools();
913 if (registeredTools != null) {
914 allTools.addAll(registeredTools);
915 }
916
// Stays null (non-streaming chat call) unless the observer supplies a response handler.
917 OllamaChatTokenHandler hdlr = null;
918 chatRequest.setUseTools(true);
919 chatRequest.setTools(allTools);
920 if (streamObserver != null) {
921 chatRequest.setStream(true);
922 if (streamObserver.getResponseStreamHandler() != null) {
// Adapt the generate-style handler: forward only the response text of each chat chunk.
923 hdlr =
924 chatResponseModel ->
925 streamObserver
926 .getResponseStreamHandler()
927 .accept(chatResponseModel.getMessage().getResponse());
928 }
929 }
930 OllamaChatResult res = chat(chatRequest, hdlr);
// NOTE(review): the meaning of the trailing -1 depends on the OllamaResult constructor,
// which is not visible here.
931 return new OllamaResult(
932 res.getResponseModel().getMessage().getResponse(),
933 res.getResponseModel().getMessage().getThinking(),
934 res.getResponseModel().getTotalDuration(),
935 -1);
936 }
937
949 String model, String prompt, boolean raw, ThinkMode think) throws OllamaException {
950 long startTime = System.currentTimeMillis();
951 String url = "/api/generate";
952 int statusCode = -1;
953 try {
954 OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
955 ollamaRequestModel.setRaw(raw);
956 ollamaRequestModel.setThink(think);
957 OllamaAsyncResultStreamer ollamaAsyncResultStreamer =
959 getRequestBuilderDefault(new URI(this.host + url)),
960 ollamaRequestModel,
961 requestTimeoutSeconds);
962 ollamaAsyncResultStreamer.start();
963 statusCode = ollamaAsyncResultStreamer.getHttpStatusCode();
964 return ollamaAsyncResultStreamer;
965 } catch (Exception e) {
966 throw new OllamaException(e.getMessage(), e);
967 } finally {
969 url, model, raw, think, true, null, null, startTime, statusCode, null);
970 }
971 }
972
986 throws OllamaException {
987 try {
988 OllamaChatEndpointCaller requestCaller =
989 new OllamaChatEndpointCaller(host, auth, requestTimeoutSeconds);
990 OllamaChatResult result;
991
992 // only add tools if tools flag is set
993 if (request.isUseTools()) {
994 // add all registered tools to request
995 request.getTools().addAll(toolRegistry.getRegisteredTools());
996 }
997
998 if (tokenHandler != null) {
999 request.setStream(true);
1000 result = requestCaller.call(request, tokenHandler);
1001 } else {
1002 result = requestCaller.callSync(request);
1003 }
1004
1005 // check if toolCallIsWanted
1006 List<OllamaChatToolCalls> toolCalls =
1007 result.getResponseModel().getMessage().getToolCalls();
1008
1009 int toolCallTries = 0;
1010 while (toolCalls != null
1011 && !toolCalls.isEmpty()
1012 && toolCallTries < maxChatToolCallRetries) {
1013 for (OllamaChatToolCalls toolCall : toolCalls) {
1014 String toolName = toolCall.getFunction().getName();
1015 for (Tools.Tool t : request.getTools()) {
1016 if (t.getToolSpec().getName().equals(toolName)) {
1017 ToolFunction toolFunction = t.getToolFunction();
1018 if (toolFunction == null) {
1019 throw new ToolInvocationException(
1020 "Tool function not found: " + toolName);
1021 }
1022 LOG.debug(
1023 "Invoking tool {} with arguments: {}",
1024 toolCall.getFunction().getName(),
1025 toolCall.getFunction().getArguments());
1026 Map<String, Object> arguments = toolCall.getFunction().getArguments();
1027 Object res = toolFunction.apply(arguments);
1028 String argumentKeys =
1029 arguments.keySet().stream()
1030 .map(Object::toString)
1031 .collect(Collectors.joining(", "));
1032 request.getMessages()
1033 .add(
1036 "[TOOL_RESULTS] "
1037 + toolName
1038 + "("
1039 + argumentKeys
1040 + "): "
1041 + res
1042 + " [/TOOL_RESULTS]"));
1043 }
1044 }
1045 }
1046 if (tokenHandler != null) {
1047 result = requestCaller.call(request, tokenHandler);
1048 } else {
1049 result = requestCaller.callSync(request);
1050 }
1051 toolCalls = result.getResponseModel().getMessage().getToolCalls();
1052 toolCallTries++;
1053 }
1054 return result;
1055 } catch (InterruptedException e) {
1056 Thread.currentThread().interrupt();
1057 throw new OllamaException("Thread was interrupted", e);
1058 } catch (Exception e) {
1059 throw new OllamaException(e.getMessage(), e);
1060 }
1061 }
1062
1068 public void registerTool(Tools.Tool tool) {
1069 toolRegistry.addTool(tool);
1070 LOG.debug("Registered tool: {}", tool.getToolSpec().getName());
1071 }
1072
1079 public void registerTools(List<Tools.Tool> tools) {
1080 toolRegistry.addTools(tools);
1081 }
1082
1084 return toolRegistry.getRegisteredTools();
1085 }
1086
1091 public void deregisterTools() {
1092 toolRegistry.clear();
1093 LOG.debug("All tools have been deregistered.");
1094 }
1095
1105 try {
1106 Class<?> callerClass = null;
1107 try {
1108 callerClass =
1109 Class.forName(Thread.currentThread().getStackTrace()[2].getClassName());
1110 } catch (ClassNotFoundException e) {
1111 throw new OllamaException(e.getMessage(), e);
1112 }
1113
1114 OllamaToolService ollamaToolServiceAnnotation =
1115 callerClass.getDeclaredAnnotation(OllamaToolService.class);
1116 if (ollamaToolServiceAnnotation == null) {
1117 throw new IllegalStateException(
1118 callerClass + " is not annotated as " + OllamaToolService.class);
1119 }
1120
1121 Class<?>[] providers = ollamaToolServiceAnnotation.providers();
1122 for (Class<?> provider : providers) {
1123 registerAnnotatedTools(provider.getDeclaredConstructor().newInstance());
1124 }
1125 } catch (InstantiationException
1126 | NoSuchMethodException
1127 | IllegalAccessException
1128 | InvocationTargetException e) {
1129 throw new OllamaException(e.getMessage());
1130 }
1131 }
1132
1142 public void registerAnnotatedTools(Object object) {
1143 Class<?> objectClass = object.getClass();
1144 Method[] methods = objectClass.getMethods();
1145 for (Method m : methods) {
1146 ToolSpec toolSpec = m.getDeclaredAnnotation(ToolSpec.class);
1147 if (toolSpec == null) {
1148 continue;
1149 }
1150 String operationName = !toolSpec.name().isBlank() ? toolSpec.name() : m.getName();
1151 String operationDesc = !toolSpec.desc().isBlank() ? toolSpec.desc() : operationName;
1152
1153 final Map<String, Tools.Property> params = new HashMap<String, Tools.Property>() {};
1154 LinkedHashMap<String, String> methodParams = new LinkedHashMap<>();
1155 for (Parameter parameter : m.getParameters()) {
1156 final ToolProperty toolPropertyAnn =
1157 parameter.getDeclaredAnnotation(ToolProperty.class);
1158 String propType = parameter.getType().getTypeName();
1159 if (toolPropertyAnn == null) {
1160 methodParams.put(parameter.getName(), null);
1161 continue;
1162 }
1163 String propName =
1164 !toolPropertyAnn.name().isBlank()
1165 ? toolPropertyAnn.name()
1166 : parameter.getName();
1167 methodParams.put(propName, propType);
1168 params.put(
1169 propName,
1170 Tools.Property.builder()
1171 .type(propType)
1172 .description(toolPropertyAnn.desc())
1173 .required(toolPropertyAnn.required())
1174 .build());
1175 }
1176 Tools.ToolSpec toolSpecification =
1177 Tools.ToolSpec.builder()
1178 .name(operationName)
1179 .description(operationDesc)
1180 .parameters(Tools.Parameters.of(params))
1181 .build();
1182 ReflectionalToolFunction reflectionalToolFunction =
1183 new ReflectionalToolFunction(object, m, methodParams);
1184 toolRegistry.addTool(
1185 Tools.Tool.builder()
1186 .toolFunction(reflectionalToolFunction)
1187 .toolSpec(toolSpecification)
1188 .build());
1189 }
1190 }
1191
1198 public OllamaChatMessageRole addCustomRole(String roleName) {
1199 return OllamaChatMessageRole.newCustomRole(roleName);
1200 }
1201
1207 public List<OllamaChatMessageRole> listRoles() {
1209 }
1210
1218 public OllamaChatMessageRole getRole(String roleName) throws RoleNotFoundException {
1219 return OllamaChatMessageRole.getRole(roleName);
1220 }
1221
1222 public void loadMCPToolsFromJson(String mcpConfigJsonFilePath) throws IOException {
1223 String jsonContent =
1224 java.nio.file.Files.readString(java.nio.file.Paths.get(mcpConfigJsonFilePath));
1225 MCPToolsConfig config =
1226 McpJsonMapper.getDefault().readValue(jsonContent, MCPToolsConfig.class);
1227
1228 if (config.mcpServers != null && !config.mcpServers.isEmpty()) {
1229 for (Map.Entry<String, MCPToolConfig> tool : config.mcpServers.entrySet()) {
1230 ServerParameters.Builder serverParamsBuilder =
1231 ServerParameters.builder(tool.getValue().command);
1232 if (tool.getValue().args != null && !tool.getValue().args.isEmpty()) {
1233 LOG.debug(
1234 "Runnable MCP Tool command: \n\n\t{} {}\n\n",
1235 tool.getValue().command,
1236 String.join(" ", tool.getValue().args));
1237 serverParamsBuilder.args(tool.getValue().args.toArray(new String[0]));
1238 }
1239 ServerParameters serverParameters = serverParamsBuilder.build();
1240 StdioClientTransport transport =
1241 new StdioClientTransport(serverParameters, McpJsonMapper.getDefault());
1242
1243 int mcpToolRequestTimeoutSeconds = 30;
1244 try {
1245 McpSyncClient client =
1246 McpClient.sync(transport)
1247 .requestTimeout(
1248 Duration.ofSeconds(mcpToolRequestTimeoutSeconds))
1249 .build();
1250 client.initialize();
1251
1252 ListToolsResult result = client.listTools();
1253 for (io.modelcontextprotocol.spec.McpSchema.Tool mcpTool : result.tools()) {
1254 Tools.Tool mcpToolAsOllama4jTool =
1255 createOllamaToolFromMCPTool(
1256 tool.getKey(), mcpTool, serverParameters);
1257 toolRegistry.addTool(mcpToolAsOllama4jTool);
1258 }
1259 client.close();
1260 } finally {
1261 transport.close();
1262 }
1263 }
1264 }
1265 }
1266
1279 private CallToolResult callMCPTool(
1280 String mcpServerName, String toolName, Map<String, Object> arguments) {
1281 for (Tools.Tool tool : getRegisteredTools()) {
1282 if (tool.isMCPTool() && tool.getMcpServerName().equals(mcpServerName)) {
1283 if (tool.getToolSpec().getName().equals(toolName)) {
1284 ServerParameters serverParameters = tool.getMcpServerParameters();
1285 StdioClientTransport stdioTransport =
1286 new StdioClientTransport(serverParameters, McpJsonMapper.getDefault());
1287 try {
1288 LOG.info(
1289 "Calling MCP Tool: '{}.{}' with arguments: {}",
1290 mcpServerName,
1291 toolName,
1292 arguments);
1293 try (McpSyncClient client =
1294 McpClient.sync(stdioTransport)
1295 .requestTimeout(Duration.ofSeconds(requestTimeoutSeconds))
1296 .build()) {
1297 client.initialize();
1298 CallToolRequest request = new CallToolRequest(toolName, arguments);
1299 return client.callTool(request);
1300 }
1301 } finally {
1302 stdioTransport.close();
1303 }
1304 }
1305 }
1306 }
1307 throw new IllegalArgumentException(
1308 "No MCP tool found for server name: "
1309 + mcpServerName
1310 + " and tool name: "
1311 + toolName);
1312 }
1313
1314 // technical private methods //
1315
1323 private static String encodeFileToBase64(File file) throws IOException {
1324 return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath()));
1325 }
1326
1333 private static String encodeByteArrayToBase64(byte[] bytes) {
1334 return Base64.getEncoder().encodeToString(bytes);
1335 }
1336
1350 private OllamaResult generateSyncForOllamaRequestModel(
1351 OllamaGenerateRequest ollamaRequestModel,
1352 OllamaGenerateTokenHandler thinkingStreamHandler,
1353 OllamaGenerateTokenHandler responseStreamHandler)
1354 throws OllamaException {
1355 long startTime = System.currentTimeMillis();
1356 int statusCode = -1;
1357 Object out = null;
1358 try {
1359 OllamaGenerateEndpointCaller requestCaller =
1360 new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds);
1361 OllamaResult result;
1362 if (responseStreamHandler != null) {
1363 ollamaRequestModel.setStream(true);
1364 result =
1365 requestCaller.call(
1366 ollamaRequestModel, thinkingStreamHandler, responseStreamHandler);
1367 } else {
1368 result = requestCaller.callSync(ollamaRequestModel);
1369 }
1370 statusCode = result.getHttpStatusCode();
1371 out = result;
1372 return result;
1373 } catch (InterruptedException e) {
1374 Thread.currentThread().interrupt();
1375 throw new OllamaException("Thread was interrupted", e);
1376 } catch (Exception e) {
1377 throw new OllamaException(e.getMessage(), e);
1378 } finally {
1381 ollamaRequestModel.getModel(),
1382 ollamaRequestModel.isRaw(),
1383 ollamaRequestModel.getThink(),
1384 ollamaRequestModel.isStream(),
1385 ollamaRequestModel.getOptions(),
1386 ollamaRequestModel.getFormat(),
1387 startTime,
1388 statusCode,
1389 out);
1390 }
1391 }
1392
1399 private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
1400 HttpRequest.Builder requestBuilder =
1401 HttpRequest.newBuilder(uri)
1402 .header(
1403 Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
1404 Constants.HttpConstants.APPLICATION_JSON)
1405 .timeout(Duration.ofSeconds(requestTimeoutSeconds));
1406 if (isAuthSet()) {
1407 requestBuilder.header("Authorization", auth.getAuthHeaderValue());
1408 }
1409 return requestBuilder;
1410 }
1411
1417 private boolean isAuthSet() {
1418 return auth != null;
1419 }
1420
1421 @Data
1422 @NoArgsConstructor
1423 @AllArgsConstructor
1424 public static class OllamaMCPTool {
1425 private String mcpServerName;
1426 private List<MCPToolInfo> toolInfos;
1427 private StdioClientTransport transport;
1428 }
1429
1430 @Data
1431 @NoArgsConstructor
1432 @AllArgsConstructor
1433 public static class MCPToolInfo {
1434 private String toolName;
1435 private String toolDescription;
1436 }
1437
1438 public static class MCPToolConfig {
1439 @JsonProperty("command")
1440 public String command;
1441
1442 @JsonProperty("args")
1443 public List<String> args;
1444 }
1445
1446 public static class MCPToolsConfig {
1447 @JsonProperty("mcpServers")
1448 public Map<String, MCPToolConfig> mcpServers;
1449 }
1450
1451 @Data
1452 @NoArgsConstructor
1453 @AllArgsConstructor
1454 public static class OllamaMCPToolMatchResponse {
1455 @JsonProperty("mcpServerName")
1456 public String mcpServerName;
1457
1458 @JsonProperty("toolName")
1459 public String toolName;
1460
1461 @JsonProperty("arguments")
1462 public Map<String, Object> arguments;
1463 }
1464
1475 private Tools.Tool createOllamaToolFromMCPTool(
1476 String mcpServerName,
1477 io.modelcontextprotocol.spec.McpSchema.Tool mcpTool,
1478 ServerParameters serverParameters) {
1479 Map<String, Tools.Property> properties = new java.util.HashMap<>();
1480 java.util.List<String> requiredList = new java.util.ArrayList<>();
1481
1482 if (mcpTool.inputSchema() != null && mcpTool.inputSchema().properties() != null) {
1483 // Prepare set for fast required lookup (since original is List<String>)
1484 java.util.Set<String> requiredSet = new java.util.HashSet<>();
1485 if (mcpTool.inputSchema().required() != null) {
1486 requiredSet.addAll(mcpTool.inputSchema().required());
1487 }
1488 for (Map.Entry<String, Object> entry : mcpTool.inputSchema().properties().entrySet()) {
1489 String propName = entry.getKey();
1490 Object propertyValue = entry.getValue();
1491 Map<String, Object> propertyMap = null;
1492
1493 if (propertyValue instanceof Map) {
1494 propertyMap = (Map<String, Object>) propertyValue;
1495 } else {
1496 // Defensive fallback, unexpected schema
1497 continue;
1498 }
1499
1500 // Extract standard fields; fallback to empty/defaults
1501 String type =
1502 propertyMap.get("type") != null ? propertyMap.get("type").toString() : null;
1503
1504 String description = null;
1505 if (propertyMap.get("description") != null) {
1506 description = propertyMap.get("description").toString();
1507 } else if (propertyMap.get("title") != null) {
1508 // Use 'title' as fallback for description if 'description' is missing
1509 description = propertyMap.get("title").toString();
1510 }
1511
1512 // 'required' is determined from the parent 'required' list
1513 boolean propRequired = requiredSet.contains(propName);
1514
1515 Tools.Property property =
1516 Tools.Property.builder()
1517 .type(type)
1518 .description(description)
1519 .required(propRequired)
1520 .build();
1521
1522 properties.put(propName, property);
1523 if (propRequired) {
1524 requiredList.add(propName);
1525 }
1526 }
1527 }
1528
1529 Tools.Parameters params = new Tools.Parameters();
1530 params.setProperties(properties);
1531 params.setRequired(requiredList);
1532
1533 return Tools.Tool.builder()
1534 .toolSpec(
1535 Tools.ToolSpec.builder()
1536 .name(mcpTool.name())
1537 .description(mcpTool.description())
1538 .parameters(params)
1539 .build())
1540 .toolFunction(
1541 arguments -> {
1542 CallToolResult result =
1543 this.callMCPTool(mcpServerName, mcpTool.name(), arguments);
1544 return result.toString();
1545 })
1546 .isMCPTool(true)
1547 .mcpServerName(mcpServerName)
1548 .mcpServerParameters(serverParameters)
1549 .build();
1550 }
1551}
void setBasicAuth(String username, String password)
Definition Ollama.java:144
OllamaChatMessageRole getRole(String roleName)
Definition Ollama.java:1218
OllamaEmbedResult embed(OllamaEmbedRequest modelRequest)
Definition Ollama.java:814
OllamaAsyncResultStreamer generateAsync(String model, String prompt, boolean raw, ThinkMode think)
Definition Ollama.java:948
void loadMCPToolsFromJson(String mcpConfigJsonFilePath)
Definition Ollama.java:1222
List< Tools.Tool > getRegisteredTools()
Definition Ollama.java:1083
OllamaResult generate(OllamaGenerateRequest request, OllamaGenerateStreamObserver streamObserver)
Definition Ollama.java:867
void unloadModel(String modelName)
Definition Ollama.java:745
void registerTool(Tools.Tool tool)
Definition Ollama.java:1068
void pullModel(String modelName)
Definition Ollama.java:512
OllamaChatResult chat(OllamaChatRequest request, OllamaChatTokenHandler tokenHandler)
Definition Ollama.java:985
List< Model > listModels()
Definition Ollama.java:269
List< OllamaChatMessageRole > listRoles()
Definition Ollama.java:1207
OllamaChatMessageRole addCustomRole(String roleName)
Definition Ollama.java:1198
void setBearerAuth(String bearerToken)
Definition Ollama.java:153
ModelProcessesResult ps()
Definition Ollama.java:211
void createModel(CustomModelRequest customModelRequest)
Definition Ollama.java:608
void registerAnnotatedTools(Object object)
Definition Ollama.java:1142
void registerTools(List< Tools.Tool > tools)
Definition Ollama.java:1079
void deleteModel(String modelName, boolean ignoreIfNotPresent)
Definition Ollama.java:682
ModelDetail getModelDetails(String modelName)
Definition Ollama.java:554
static void record(String endpoint, String model, boolean raw, ThinkMode thinkMode, boolean streaming, Map< String, Object > options, Object format, long startTime, int responseHttpStatus, Object response)
static OllamaChatMessageRole newCustomRole(String roleName)
static OllamaChatMessageRole getRole(String roleName)
OllamaChatResult call(OllamaChatRequest body, OllamaChatTokenHandler tokenHandler)
OllamaResult call(OllamaRequestBody body, OllamaGenerateTokenHandler thinkingStreamHandler, OllamaGenerateTokenHandler responseStreamHandler)
void addTools(List< Tools.Tool > tools)
static ObjectMapper getObjectMapper()
Definition Utils.java:32
Object apply(Map< String, Object > arguments)