Ollama4j
A Java library (wrapper/binding) for the Ollama server.
Loading...
Searching...
No Matches
OllamaAPI.java
Go to the documentation of this file.
1package io.github.ollama4j;
2
3import io.github.ollama4j.exceptions.OllamaBaseException;
4import io.github.ollama4j.exceptions.RoleNotFoundException;
5import io.github.ollama4j.exceptions.ToolInvocationException;
6import io.github.ollama4j.exceptions.ToolNotFoundException;
7import io.github.ollama4j.models.chat.*;
8import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
9import io.github.ollama4j.models.embeddings.OllamaEmbeddingResponseModel;
10import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
11import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
12import io.github.ollama4j.models.generate.OllamaGenerateRequest;
13import io.github.ollama4j.models.generate.OllamaStreamHandler;
14import io.github.ollama4j.models.generate.OllamaTokenHandler;
15import io.github.ollama4j.models.ps.ModelsProcessResponse;
16import io.github.ollama4j.models.request.*;
17import io.github.ollama4j.models.response.*;
18import io.github.ollama4j.tools.*;
19import io.github.ollama4j.tools.annotations.OllamaToolService;
20import io.github.ollama4j.tools.annotations.ToolProperty;
21import io.github.ollama4j.tools.annotations.ToolSpec;
22import io.github.ollama4j.utils.Options;
23import io.github.ollama4j.utils.Utils;
24import lombok.Setter;
25
26import java.io.*;
27import java.lang.reflect.InvocationTargetException;
28import java.lang.reflect.Method;
29import java.lang.reflect.Parameter;
30import java.net.URI;
31import java.net.URISyntaxException;
32import java.net.http.HttpClient;
33import java.net.http.HttpConnectTimeoutException;
34import java.net.http.HttpRequest;
35import java.net.http.HttpResponse;
36import java.nio.charset.StandardCharsets;
37import java.nio.file.Files;
38import java.time.Duration;
39import java.util.*;
40import java.util.stream.Collectors;
41
42import org.slf4j.Logger;
43import org.slf4j.LoggerFactory;
44import org.jsoup.Jsoup;
45import org.jsoup.nodes.Document;
46import org.jsoup.nodes.Element;
47import org.jsoup.select.Elements;
48
52@SuppressWarnings({"DuplicatedCode", "resource"})
53public class OllamaAPI {
54
55 private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class);
56 private final String host;
61 @Setter
62 private long requestTimeoutSeconds = 10;
67 @Setter
68 private boolean verbose = true;
69
70 @Setter
71 private int maxChatToolCallRetries = 3;
72
73 private BasicAuth basicAuth;
74
75 private final ToolRegistry toolRegistry = new ToolRegistry();
76
80 public OllamaAPI() {
81 this.host = "http://localhost:11434";
82 }
83
89 public OllamaAPI(String host) {
90 if (host.endsWith("/")) {
91 this.host = host.substring(0, host.length() - 1);
92 } else {
93 this.host = host;
94 }
95 }
96
103 public void setBasicAuth(String username, String password) {
104 this.basicAuth = new BasicAuth(username, password);
105 }
106
112 public boolean ping() {
113 String url = this.host + "/api/tags";
114 HttpClient httpClient = HttpClient.newHttpClient();
115 HttpRequest httpRequest = null;
116 try {
117 httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
118 } catch (URISyntaxException e) {
119 throw new RuntimeException(e);
120 }
121 HttpResponse<String> response = null;
122 try {
123 response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
124 } catch (HttpConnectTimeoutException e) {
125 return false;
126 } catch (IOException | InterruptedException e) {
127 throw new RuntimeException(e);
128 }
129 int statusCode = response.statusCode();
130 return statusCode == 200;
131 }
132
141 public ModelsProcessResponse ps() throws IOException, InterruptedException, OllamaBaseException {
142 String url = this.host + "/api/ps";
143 HttpClient httpClient = HttpClient.newHttpClient();
144 HttpRequest httpRequest = null;
145 try {
146 httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
147 } catch (URISyntaxException e) {
148 throw new RuntimeException(e);
149 }
150 HttpResponse<String> response = null;
151 response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
152 int statusCode = response.statusCode();
153 String responseString = response.body();
154 if (statusCode == 200) {
155 return Utils.getObjectMapper().readValue(responseString, ModelsProcessResponse.class);
156 } else {
157 throw new OllamaBaseException(statusCode + " - " + responseString);
158 }
159 }
160
170 public List<Model> listModels() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
171 String url = this.host + "/api/tags";
172 HttpClient httpClient = HttpClient.newHttpClient();
173 HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
174 HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
175 int statusCode = response.statusCode();
176 String responseString = response.body();
177 if (statusCode == 200) {
178 return Utils.getObjectMapper().readValue(responseString, ListModelsResponse.class).getModels();
179 } else {
180 throw new OllamaBaseException(statusCode + " - " + responseString);
181 }
182 }
183
194 public List<LibraryModel> listModelsFromLibrary() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
195 String url = "https://ollama.com/library";
196 HttpClient httpClient = HttpClient.newHttpClient();
197 HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
198 HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
199 int statusCode = response.statusCode();
200 String responseString = response.body();
201 List<LibraryModel> models = new ArrayList<>();
202 if (statusCode == 200) {
203 Document doc = Jsoup.parse(responseString);
204 Elements modelSections = doc.selectXpath("//*[@id='repo']/ul/li/a");
205 for (Element e : modelSections) {
206 LibraryModel model = new LibraryModel();
207 Elements names = e.select("div > h2 > div > span");
208 Elements desc = e.select("div > p");
209 Elements pullCounts = e.select("div:nth-of-type(2) > p > span:first-of-type > span:first-of-type");
210 Elements popularTags = e.select("div > div > span");
211 Elements totalTags = e.select("div:nth-of-type(2) > p > span:nth-of-type(2) > span:first-of-type");
212 Elements lastUpdatedTime = e.select("div:nth-of-type(2) > p > span:nth-of-type(3) > span:nth-of-type(2)");
213
214 if (names.first() == null || names.isEmpty()) {
215 // if name cannot be extracted, skip.
216 continue;
217 }
218 Optional.ofNullable(names.first()).map(Element::text).ifPresent(model::setName);
219 model.setDescription(Optional.ofNullable(desc.first()).map(Element::text).orElse(""));
220 model.setPopularTags(Optional.of(popularTags).map(tags -> tags.stream().map(Element::text).collect(Collectors.toList())).orElse(new ArrayList<>()));
221 model.setPullCount(Optional.ofNullable(pullCounts.first()).map(Element::text).orElse(""));
222 model.setTotalTags(Optional.ofNullable(totalTags.first()).map(Element::text).map(Integer::parseInt).orElse(0));
223 model.setLastUpdated(Optional.ofNullable(lastUpdatedTime.first()).map(Element::text).orElse(""));
224
225 models.add(model);
226 }
227 return models;
228 } else {
229 throw new OllamaBaseException(statusCode + " - " + responseString);
230 }
231 }
232
247 public LibraryModelDetail getLibraryModelDetails(LibraryModel libraryModel) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
248 String url = String.format("https://ollama.com/library/%s/tags", libraryModel.getName());
249 HttpClient httpClient = HttpClient.newHttpClient();
250 HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
251 HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
252 int statusCode = response.statusCode();
253 String responseString = response.body();
254
255 List<LibraryModelTag> libraryModelTags = new ArrayList<>();
256 if (statusCode == 200) {
257 Document doc = Jsoup.parse(responseString);
258 Elements tagSections = doc.select("html > body > main > div > section > div > div > div:nth-child(n+2) > div");
259 for (Element e : tagSections) {
260 Elements tags = e.select("div > a > div");
261 Elements tagsMetas = e.select("div > span");
262
263 LibraryModelTag libraryModelTag = new LibraryModelTag();
264
265 if (tags.first() == null || tags.isEmpty()) {
266 // if tag cannot be extracted, skip.
267 continue;
268 }
269 libraryModelTag.setName(libraryModel.getName());
270 Optional.ofNullable(tags.first()).map(Element::text).ifPresent(libraryModelTag::setTag);
271 libraryModelTag.setSize(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split("•")).filter(parts -> parts.length > 1).map(parts -> parts[1].trim()).orElse(""));
272 libraryModelTag.setLastUpdated(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split("•")).filter(parts -> parts.length > 1).map(parts -> parts[2].trim()).orElse(""));
273 libraryModelTags.add(libraryModelTag);
274 }
275 LibraryModelDetail libraryModelDetail = new LibraryModelDetail();
276 libraryModelDetail.setModel(libraryModel);
277 libraryModelDetail.setTags(libraryModelTags);
278 return libraryModelDetail;
279 } else {
280 throw new OllamaBaseException(statusCode + " - " + responseString);
281 }
282 }
283
300 public LibraryModelTag findModelTagFromLibrary(String modelName, String tag) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
301 List<LibraryModel> libraryModels = this.listModelsFromLibrary();
302 LibraryModel libraryModel = libraryModels.stream().filter(model -> model.getName().equals(modelName)).findFirst().orElseThrow(() -> new NoSuchElementException(String.format("Model by name '%s' not found", modelName)));
303 LibraryModelDetail libraryModelDetail = this.getLibraryModelDetails(libraryModel);
304 LibraryModelTag libraryModelTag = libraryModelDetail.getTags().stream().filter(tagName -> tagName.getTag().equals(tag)).findFirst().orElseThrow(() -> new NoSuchElementException(String.format("Tag '%s' for model '%s' not found", tag, modelName)));
305 return libraryModelTag;
306 }
307
318 public void pullModel(String modelName) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
319 String url = this.host + "/api/pull";
320 String jsonData = new ModelRequest(modelName).toString();
321 HttpRequest request = getRequestBuilderDefault(new URI(url)).POST(HttpRequest.BodyPublishers.ofString(jsonData)).header("Accept", "application/json").header("Content-type", "application/json").build();
322 HttpClient client = HttpClient.newHttpClient();
323 HttpResponse<InputStream> response = client.send(request, HttpResponse.BodyHandlers.ofInputStream());
324 int statusCode = response.statusCode();
325 InputStream responseBodyStream = response.body();
326 String responseString = "";
327 try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
328 String line;
329 while ((line = reader.readLine()) != null) {
330 ModelPullResponse modelPullResponse = Utils.getObjectMapper().readValue(line, ModelPullResponse.class);
331 if (verbose) {
332 logger.info(modelPullResponse.getStatus());
333 }
334 }
335 }
336 if (statusCode != 200) {
337 throw new OllamaBaseException(statusCode + " - " + responseString);
338 }
339 }
340
353 public void pullModel(LibraryModelTag libraryModelTag) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
354 String tagToPull = String.format("%s:%s", libraryModelTag.getName(), libraryModelTag.getTag());
355 pullModel(tagToPull);
356 }
357
368 public ModelDetail getModelDetails(String modelName) throws IOException, OllamaBaseException, InterruptedException, URISyntaxException {
369 String url = this.host + "/api/show";
370 String jsonData = new ModelRequest(modelName).toString();
371 HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
372 HttpClient client = HttpClient.newHttpClient();
373 HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
374 int statusCode = response.statusCode();
375 String responseBody = response.body();
376 if (statusCode == 200) {
377 return Utils.getObjectMapper().readValue(responseBody, ModelDetail.class);
378 } else {
379 throw new OllamaBaseException(statusCode + " - " + responseBody);
380 }
381 }
382
394 public void createModelWithFilePath(String modelName, String modelFilePath) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
395 String url = this.host + "/api/create";
396 String jsonData = new CustomModelFilePathRequest(modelName, modelFilePath).toString();
397 HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
398 HttpClient client = HttpClient.newHttpClient();
399 HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
400 int statusCode = response.statusCode();
401 String responseString = response.body();
402 if (statusCode != 200) {
403 throw new OllamaBaseException(statusCode + " - " + responseString);
404 }
405 // FIXME: Ollama API returns HTTP status code 200 for model creation failure cases. Correct this
406 // if the issue is fixed in the Ollama API server.
407 if (responseString.contains("error")) {
408 throw new OllamaBaseException(responseString);
409 }
410 if (verbose) {
411 logger.info(responseString);
412 }
413 }
414
426 public void createModelWithModelFileContents(String modelName, String modelFileContents) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
427 String url = this.host + "/api/create";
428 String jsonData = new CustomModelFileContentsRequest(modelName, modelFileContents).toString();
429 HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
430 HttpClient client = HttpClient.newHttpClient();
431 HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
432 int statusCode = response.statusCode();
433 String responseString = response.body();
434 if (statusCode != 200) {
435 throw new OllamaBaseException(statusCode + " - " + responseString);
436 }
437 if (responseString.contains("error")) {
438 throw new OllamaBaseException(responseString);
439 }
440 if (verbose) {
441 logger.info(responseString);
442 }
443 }
444
455 public void deleteModel(String modelName, boolean ignoreIfNotPresent) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
456 String url = this.host + "/api/delete";
457 String jsonData = new ModelRequest(modelName).toString();
458 HttpRequest request = getRequestBuilderDefault(new URI(url)).method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).header("Accept", "application/json").header("Content-type", "application/json").build();
459 HttpClient client = HttpClient.newHttpClient();
460 HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
461 int statusCode = response.statusCode();
462 String responseBody = response.body();
463 if (statusCode == 404 && responseBody.contains("model") && responseBody.contains("not found")) {
464 return;
465 }
466 if (statusCode != 200) {
467 throw new OllamaBaseException(statusCode + " - " + responseBody);
468 }
469 }
470
482 @Deprecated
483 public List<Double> generateEmbeddings(String model, String prompt) throws IOException, InterruptedException, OllamaBaseException {
484 return generateEmbeddings(new OllamaEmbeddingsRequestModel(model, prompt));
485 }
486
497 @Deprecated
498 public List<Double> generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest) throws IOException, InterruptedException, OllamaBaseException {
499 URI uri = URI.create(this.host + "/api/embeddings");
500 String jsonData = modelRequest.toString();
501 HttpClient httpClient = HttpClient.newHttpClient();
502 HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).header("Accept", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData));
503 HttpRequest request = requestBuilder.build();
504 HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
505 int statusCode = response.statusCode();
506 String responseBody = response.body();
507 if (statusCode == 200) {
508 OllamaEmbeddingResponseModel embeddingResponse = Utils.getObjectMapper().readValue(responseBody, OllamaEmbeddingResponseModel.class);
509 return embeddingResponse.getEmbedding();
510 } else {
511 throw new OllamaBaseException(statusCode + " - " + responseBody);
512 }
513 }
514
525 public OllamaEmbedResponseModel embed(String model, List<String> inputs) throws IOException, InterruptedException, OllamaBaseException {
526 return embed(new OllamaEmbedRequestModel(model, inputs));
527 }
528
538 public OllamaEmbedResponseModel embed(OllamaEmbedRequestModel modelRequest) throws IOException, InterruptedException, OllamaBaseException {
539 URI uri = URI.create(this.host + "/api/embed");
540 String jsonData = Utils.getObjectMapper().writeValueAsString(modelRequest);
541 HttpClient httpClient = HttpClient.newHttpClient();
542
543 HttpRequest request = HttpRequest.newBuilder(uri).header("Accept", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
544
545 HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
546 int statusCode = response.statusCode();
547 String responseBody = response.body();
548
549 if (statusCode == 200) {
550 return Utils.getObjectMapper().readValue(responseBody, OllamaEmbedResponseModel.class);
551 } else {
552 throw new OllamaBaseException(statusCode + " - " + responseBody);
553 }
554 }
555
571 public OllamaResult generate(String model, String prompt, boolean raw, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
572 OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
573 ollamaRequestModel.setRaw(raw);
574 ollamaRequestModel.setOptions(options.getOptionsMap());
575 return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
576 }
577
592 public OllamaResult generate(String model, String prompt, boolean raw, Options options) throws OllamaBaseException, IOException, InterruptedException {
593 return generate(model, prompt, raw, options, null);
594 }
595
608 public OllamaToolsResult generateWithTools(String model, String prompt, Options options) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
609 boolean raw = true;
610 OllamaToolsResult toolResult = new OllamaToolsResult();
611 Map<ToolFunctionCallSpec, Object> toolResults = new HashMap<>();
612
613 if(!prompt.startsWith("[AVAILABLE_TOOLS]")){
614 final Tools.PromptBuilder promptBuilder = new Tools.PromptBuilder();
615 for(Tools.ToolSpecification spec : toolRegistry.getRegisteredSpecs()) {
616 promptBuilder.withToolSpecification(spec);
617 }
618 promptBuilder.withPrompt(prompt);
619 prompt = promptBuilder.build();
620 }
621
622 OllamaResult result = generate(model, prompt, raw, options, null);
623 toolResult.setModelResult(result);
624
625 String toolsResponse = result.getResponse();
626 if (toolsResponse.contains("[TOOL_CALLS]")) {
627 toolsResponse = toolsResponse.replace("[TOOL_CALLS]", "");
628 }
629
630 List<ToolFunctionCallSpec> toolFunctionCallSpecs = Utils.getObjectMapper().readValue(toolsResponse, Utils.getObjectMapper().getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class));
631 for (ToolFunctionCallSpec toolFunctionCallSpec : toolFunctionCallSpecs) {
632 toolResults.put(toolFunctionCallSpec, invokeTool(toolFunctionCallSpec));
633 }
634 toolResult.setToolResults(toolResults);
635 return toolResult;
636 }
637
647 public OllamaAsyncResultStreamer generateAsync(String model, String prompt, boolean raw) {
648 OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
649 ollamaRequestModel.setRaw(raw);
650 URI uri = URI.create(this.host + "/api/generate");
651 OllamaAsyncResultStreamer ollamaAsyncResultStreamer = new OllamaAsyncResultStreamer(getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds);
652 ollamaAsyncResultStreamer.start();
653 return ollamaAsyncResultStreamer;
654 }
655
672 public OllamaResult generateWithImageFiles(String model, String prompt, List<File> imageFiles, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
673 List<String> images = new ArrayList<>();
674 for (File imageFile : imageFiles) {
675 images.add(encodeFileToBase64(imageFile));
676 }
677 OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt, images);
678 ollamaRequestModel.setOptions(options.getOptionsMap());
679 return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
680 }
681
691 public OllamaResult generateWithImageFiles(String model, String prompt, List<File> imageFiles, Options options) throws OllamaBaseException, IOException, InterruptedException {
692 return generateWithImageFiles(model, prompt, imageFiles, options, null);
693 }
694
712 public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
713 List<String> images = new ArrayList<>();
714 for (String imageURL : imageURLs) {
715 images.add(encodeByteArrayToBase64(Utils.loadImageBytesFromUrl(imageURL)));
716 }
717 OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt, images);
718 ollamaRequestModel.setOptions(options.getOptionsMap());
719 return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
720 }
721
732 public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs, Options options) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
733 return generateWithImageURLs(model, prompt, imageURLs, options, null);
734 }
735
750 public OllamaChatResult chat(String model, List<OllamaChatMessage> messages) throws OllamaBaseException, IOException, InterruptedException {
752 return chat(builder.withMessages(messages).build());
753 }
754
769 public OllamaChatResult chat(OllamaChatRequest request) throws OllamaBaseException, IOException, InterruptedException {
770 return chat(request, null);
771 }
772
788 public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
789 return chatStreaming(request, new OllamaChatStreamObserver(streamHandler));
790 }
791
807 public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler) throws OllamaBaseException, IOException, InterruptedException {
808 OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
809 OllamaChatResult result;
810
811 // add all registered tools to Request
812 request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
813
814 if (tokenHandler != null) {
815 request.setStream(true);
816 result = requestCaller.call(request, tokenHandler);
817 } else {
818 result = requestCaller.callSync(request);
819 }
820
821 // check if toolCallIsWanted
822 List<OllamaChatToolCalls> toolCalls = result.getResponseModel().getMessage().getToolCalls();
823 int toolCallTries = 0;
824 while(toolCalls != null && !toolCalls.isEmpty() && toolCallTries < maxChatToolCallRetries){
825 for (OllamaChatToolCalls toolCall : toolCalls){
826 String toolName = toolCall.getFunction().getName();
827 ToolFunction toolFunction = toolRegistry.getToolFunction(toolName);
828 Map<String, Object> arguments = toolCall.getFunction().getArguments();
829 Object res = toolFunction.apply(arguments);
830 request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() +") : " + res + "[/TOOL_RESULTS]"));
831 }
832
833 if (tokenHandler != null) {
834 result = requestCaller.call(request, tokenHandler);
835 } else {
836 result = requestCaller.callSync(request);
837 }
838 toolCalls = result.getResponseModel().getMessage().getToolCalls();
839 toolCallTries++;
840 }
841
842 return result;
843 }
844
845 public void registerTool(Tools.ToolSpecification toolSpecification) {
846 toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
847 }
848
849
// NOTE(review): the enclosing method's signature line is missing from this view; from the body,
// this is an entry point that registers the tools of every provider listed on the caller's
// @OllamaToolService annotation.
try {
    Class<?> callerClass = null;
    try {
        // Frame [2] is taken to be the class that called the enclosing method (frames 0/1 are
        // presumably getStackTrace and the enclosing method). Fragile if calls are wrapped — confirm.
        callerClass = Class.forName(Thread.currentThread().getStackTrace()[2].getClassName());
    } catch (ClassNotFoundException e) {
        throw new RuntimeException(e);
    }

    // The calling class must carry @OllamaToolService; otherwise registration is refused.
    OllamaToolService ollamaToolServiceAnnotation = callerClass.getDeclaredAnnotation(OllamaToolService.class);
    if (ollamaToolServiceAnnotation == null) {
        throw new IllegalStateException(callerClass + " is not annotated as " + OllamaToolService.class);
    }

    // Each provider class is instantiated through its no-arg constructor and scanned for @ToolSpec methods.
    Class<?>[] providers = ollamaToolServiceAnnotation.providers();
    for (Class<?> provider : providers) {
        registerAnnotatedTools(provider.getDeclaredConstructor().newInstance());
    }
} catch (InstantiationException | NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
    // Reflection failures (missing/inaccessible no-arg constructor, constructor throwing) surface unchecked.
    throw new RuntimeException(e);
}
}
872
/**
 * Registers every public method of {@code object} that is annotated with {@link ToolSpec} as a
 * tool, backed by a {@code ReflectionalToolFunction} that invokes the method on {@code object}.
 *
 * <p>The tool name is {@code ToolSpec.name()}, or the method name when that is blank. The
 * description is {@code ToolSpec.desc()}, or the tool name when that is blank. Parameters
 * annotated with {@link ToolProperty} become tool properties; a property is listed as required
 * when {@code ToolProperty.required()} is set.
 *
 * <p>NOTE(review): two builder lines are missing from this view (between {@code .function(} and
 * {@code .name(}, and between {@code .parameters(} and {@code .type("object")}), so this block
 * does not compile as shown. Restore them from the original source.
 *
 * @param object instance whose annotated public methods are exposed as tools
 */
public void registerAnnotatedTools(Object object) {
    Class<?> objectClass = object.getClass();
    // getMethods() returns public methods, including inherited ones.
    Method[] methods = objectClass.getMethods();
    for(Method m : methods) {
        ToolSpec toolSpec = m.getDeclaredAnnotation(ToolSpec.class);
        if(toolSpec == null){
            continue;
        }
        String operationName = !toolSpec.name().isBlank() ? toolSpec.name() : m.getName();
        String operationDesc = !toolSpec.desc().isBlank() ? toolSpec.desc() : operationName;

        final Tools.PropsBuilder propsBuilder = new Tools.PropsBuilder();
        // Insertion-ordered map of parameter name -> type name, in declaration order; handed to
        // ReflectionalToolFunction (presumably to map JSON arguments back onto the method — confirm).
        LinkedHashMap<String,String> methodParams = new LinkedHashMap<>();
        for (Parameter parameter : m.getParameters()) {
            final ToolProperty toolPropertyAnn = parameter.getDeclaredAnnotation(ToolProperty.class);
            String propType = parameter.getType().getTypeName();
            if(toolPropertyAnn == null) {
                // Unannotated parameters keep a slot (with a null type) but are not advertised to the model.
                methodParams.put(parameter.getName(),null);
                continue;
            }
            String propName = !toolPropertyAnn.name().isBlank() ? toolPropertyAnn.name() : parameter.getName();
            methodParams.put(propName,propType);
            propsBuilder.withProperty(propName,Tools.PromptFuncDefinition.Property.builder()
                    .type(propType)
                    .description(toolPropertyAnn.desc())
                    .required(toolPropertyAnn.required())
                    .build());
        }
        final Map<String, Tools.PromptFuncDefinition.Property> params = propsBuilder.build();
        // Names of properties flagged as required, for the JSON-schema "required" list.
        List<String> reqProps = params.entrySet().stream()
                .filter(e -> e.getValue().isRequired())
                .map(Map.Entry::getKey)
                .collect(Collectors.toList());

        Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder()
                .functionName(operationName)
                .functionDescription(operationDesc)
                .toolPrompt(
                        Tools.PromptFuncDefinition.builder().type("function").function(
                                        .name(operationName)
                                        .description(operationDesc)
                                        .parameters(
                                                        .type("object")
                                                        .properties(
                                                                params
                                                        )
                                                        .required(reqProps)
                                                        .build()
                                        ).build()
                        ).build()
                )
                .build();

        ReflectionalToolFunction reflectionalToolFunction =
                new ReflectionalToolFunction(object, m, methodParams);
        toolSpecification.setToolFunction(reflectionalToolFunction);
        // A later registration with the same function name goes through the same addTool call.
        toolRegistry.addTool(toolSpecification.getFunctionName(),toolSpecification);
    }

}
935
942 public OllamaChatMessageRole addCustomRole(String roleName) {
943 return OllamaChatMessageRole.newCustomRole(roleName);
944 }
945
/**
 * Presumably lists the known chat message roles (built-in and custom) — confirm.
 *
 * <p>NOTE(review): the method body is missing from this view (only the signature and closing
 * brace are present), so as shown it does not compile. Restore it from the original source.
 *
 * @return the chat message roles
 */
public List<OllamaChatMessageRole> listRoles() {
}
954
962 public OllamaChatMessageRole getRole(String roleName) throws RoleNotFoundException {
963 return OllamaChatMessageRole.getRole(roleName);
964 }
965
966
967 // technical private methods //
968
969 private static String encodeFileToBase64(File file) throws IOException {
970 return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath()));
971 }
972
973 private static String encodeByteArrayToBase64(byte[] bytes) {
974 return Base64.getEncoder().encodeToString(bytes);
975 }
976
977 private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
978 OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
979 OllamaResult result;
980 if (streamHandler != null) {
981 ollamaRequestModel.setStream(true);
982 result = requestCaller.call(ollamaRequestModel, streamHandler);
983 } else {
984 result = requestCaller.callSync(ollamaRequestModel);
985 }
986 return result;
987 }
988
995 private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
996 HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).header("Content-Type", "application/json").timeout(Duration.ofSeconds(requestTimeoutSeconds));
997 if (isBasicAuthCredentialsSet()) {
998 requestBuilder.header("Authorization", getBasicAuthHeaderValue());
999 }
1000 return requestBuilder;
1001 }
1002
1008 private String getBasicAuthHeaderValue() {
1009 String credentialsToEncode = basicAuth.getUsername() + ":" + basicAuth.getPassword();
1010 return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
1011 }
1012
1018 private boolean isBasicAuthCredentialsSet() {
1019 return basicAuth != null;
1020 }
1021
1022 private Object invokeTool(ToolFunctionCallSpec toolFunctionCallSpec) throws ToolInvocationException {
1023 try {
1024 String methodName = toolFunctionCallSpec.getName();
1025 Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
1026 ToolFunction function = toolRegistry.getToolFunction(methodName);
1027 if (verbose) {
1028 logger.debug("Invoking function {} with arguments {}", methodName, arguments);
1029 }
1030 if (function == null) {
1031 throw new ToolNotFoundException("No such tool: " + methodName);
1032 }
1033 return function.apply(arguments);
1034 } catch (Exception e) {
1035 throw new ToolInvocationException("Failed to invoke tool: " + toolFunctionCallSpec.getName(), e);
1036 }
1037 }
1038}
OllamaResult generateWithImageFiles(String model, String prompt, List< File > imageFiles, Options options, OllamaStreamHandler streamHandler)
OllamaChatResult chat(OllamaChatRequest request)
ModelDetail getModelDetails(String modelName)
List< Double > generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest)
OllamaAsyncResultStreamer generateAsync(String model, String prompt, boolean raw)
OllamaChatResult chat(String model, List< OllamaChatMessage > messages)
List< Double > generateEmbeddings(String model, String prompt)
void pullModel(String modelName)
void setBasicAuth(String username, String password)
OllamaResult generate(String model, String prompt, boolean raw, Options options, OllamaStreamHandler streamHandler)
OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler)
void deleteModel(String modelName, boolean ignoreIfNotPresent)
void createModelWithFilePath(String modelName, String modelFilePath)
OllamaResult generate(String model, String prompt, boolean raw, Options options)
ModelsProcessResponse ps()
LibraryModelTag findModelTagFromLibrary(String modelName, String tag)
OllamaChatMessageRole getRole(String roleName)
OllamaChatMessageRole addCustomRole(String roleName)
void pullModel(LibraryModelTag libraryModelTag)
List< OllamaChatMessageRole > listRoles()
LibraryModelDetail getLibraryModelDetails(LibraryModel libraryModel)
OllamaResult generateWithImageURLs(String model, String prompt, List< String > imageURLs, Options options, OllamaStreamHandler streamHandler)
OllamaEmbedResponseModel embed(OllamaEmbedRequestModel modelRequest)
void registerTool(Tools.ToolSpecification toolSpecification)
void registerAnnotatedTools(Object object)
List< LibraryModel > listModelsFromLibrary()
OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler)
OllamaToolsResult generateWithTools(String model, String prompt, Options options)
OllamaResult generateWithImageURLs(String model, String prompt, List< String > imageURLs, Options options)
OllamaEmbedResponseModel embed(String model, List< String > inputs)
void createModelWithModelFileContents(String modelName, String modelFileContents)
OllamaResult generateWithImageFiles(String model, String prompt, List< File > imageFiles, Options options)
static OllamaChatMessageRole newCustomRole(String roleName)
static OllamaChatMessageRole getRole(String roleName)
OllamaChatRequestBuilder withMessages(List< OllamaChatMessage > messages)
static OllamaChatRequestBuilder getInstance(String model)
OllamaChatResult call(OllamaChatRequest body, OllamaTokenHandler tokenHandler)
OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
ToolFunction getToolFunction(String name)
static byte[] loadImageBytesFromUrl(String imageUrl)
Definition Utils.java:25
static ObjectMapper getObjectMapper()
Definition Utils.java:17
Object apply(Map< String, Object > arguments)