Ollama4j
A Java library (wrapper/binding) for the Ollama server.
Loading...
Searching...
No Matches
OllamaAPI.java
Go to the documentation of this file.
1package io.github.ollama4j;
2
3import com.fasterxml.jackson.core.JsonParseException;
4import com.fasterxml.jackson.databind.JsonNode;
5import com.fasterxml.jackson.databind.ObjectMapper;
6import io.github.ollama4j.exceptions.OllamaBaseException;
7import io.github.ollama4j.exceptions.RoleNotFoundException;
8import io.github.ollama4j.exceptions.ToolInvocationException;
9import io.github.ollama4j.exceptions.ToolNotFoundException;
10import io.github.ollama4j.models.chat.*;
11import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
12import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
13import io.github.ollama4j.models.embeddings.OllamaEmbeddingResponseModel;
14import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
15import io.github.ollama4j.models.generate.OllamaGenerateRequest;
16import io.github.ollama4j.models.generate.OllamaStreamHandler;
17import io.github.ollama4j.models.generate.OllamaTokenHandler;
18import io.github.ollama4j.models.ps.ModelsProcessResponse;
19import io.github.ollama4j.models.request.*;
20import io.github.ollama4j.models.response.*;
21import io.github.ollama4j.tools.*;
22import io.github.ollama4j.tools.annotations.OllamaToolService;
23import io.github.ollama4j.tools.annotations.ToolProperty;
24import io.github.ollama4j.tools.annotations.ToolSpec;
25import io.github.ollama4j.utils.Options;
26import io.github.ollama4j.utils.Utils;
27import lombok.Setter;
28import org.jsoup.Jsoup;
29import org.jsoup.nodes.Document;
30import org.jsoup.nodes.Element;
31import org.jsoup.select.Elements;
32import org.slf4j.Logger;
33import org.slf4j.LoggerFactory;
34
35import java.io.*;
36import java.lang.reflect.InvocationTargetException;
37import java.lang.reflect.Method;
38import java.lang.reflect.Parameter;
39import java.net.URI;
40import java.net.URISyntaxException;
41import java.net.http.HttpClient;
42import java.net.http.HttpConnectTimeoutException;
43import java.net.http.HttpRequest;
44import java.net.http.HttpResponse;
45import java.nio.charset.StandardCharsets;
46import java.nio.file.Files;
47import java.time.Duration;
48import java.util.*;
49import java.util.stream.Collectors;
50
54@SuppressWarnings({ "DuplicatedCode", "resource" })
55public class OllamaAPI {
56
57 private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class);
58 private final String host;
63 @Setter
64 private long requestTimeoutSeconds = 10;
69 @Setter
70 private boolean verbose = true;
71
72 @Setter
73 private int maxChatToolCallRetries = 3;
74
75 private Auth auth;
76
77 private int numberOfRetriesForModelPull = 0;
78
79 public void setNumberOfRetriesForModelPull(int numberOfRetriesForModelPull) {
80 this.numberOfRetriesForModelPull = numberOfRetriesForModelPull;
81 }
82
83 private final ToolRegistry toolRegistry = new ToolRegistry();
84
/**
 * Creates a client for an Ollama server at {@code http://localhost:11434}.
 */
public OllamaAPI() {
    host = "http://localhost:11434";
}
92
/**
 * Creates a client for the Ollama server at the given base URL.
 *
 * <p>One trailing slash, if present, is removed so endpoint paths can be appended directly.
 *
 * @param host base URL of the Ollama server
 */
public OllamaAPI(String host) {
    this.host = host.endsWith("/") ? host.substring(0, host.length() - 1) : host;
    if (verbose) {
        logger.info("Ollama API initialized with host: " + this.host);
    }
}
108
116 public void setBasicAuth(String username, String password) {
117 this.auth = new BasicAuth(username, password);
118 }
119
126 public void setBearerAuth(String bearerToken) {
127 this.auth = new BearerAuth(bearerToken);
128 }
129
135 public boolean ping() {
136 String url = this.host + "/api/tags";
137 HttpClient httpClient = HttpClient.newHttpClient();
138 HttpRequest httpRequest = null;
139 try {
140 httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
141 .header("Content-type", "application/json").GET().build();
142 } catch (URISyntaxException e) {
143 throw new RuntimeException(e);
144 }
145 HttpResponse<String> response = null;
146 try {
147 response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
148 } catch (HttpConnectTimeoutException e) {
149 return false;
150 } catch (IOException | InterruptedException e) {
151 throw new RuntimeException(e);
152 }
153 int statusCode = response.statusCode();
154 return statusCode == 200;
155 }
156
166 public ModelsProcessResponse ps() throws IOException, InterruptedException, OllamaBaseException {
167 String url = this.host + "/api/ps";
168 HttpClient httpClient = HttpClient.newHttpClient();
169 HttpRequest httpRequest = null;
170 try {
171 httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
172 .header("Content-type", "application/json").GET().build();
173 } catch (URISyntaxException e) {
174 throw new RuntimeException(e);
175 }
176 HttpResponse<String> response = null;
177 response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
178 int statusCode = response.statusCode();
179 String responseString = response.body();
180 if (statusCode == 200) {
181 return Utils.getObjectMapper().readValue(responseString, ModelsProcessResponse.class);
182 } else {
183 throw new OllamaBaseException(statusCode + " - " + responseString);
184 }
185 }
186
196 public List<Model> listModels() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
197 String url = this.host + "/api/tags";
198 HttpClient httpClient = HttpClient.newHttpClient();
199 HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
200 .header("Content-type", "application/json").GET().build();
201 HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
202 int statusCode = response.statusCode();
203 String responseString = response.body();
204 if (statusCode == 200) {
205 return Utils.getObjectMapper().readValue(responseString, ListModelsResponse.class).getModels();
206 } else {
207 throw new OllamaBaseException(statusCode + " - " + responseString);
208 }
209 }
210
228 public List<LibraryModel> listModelsFromLibrary()
229 throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
230 String url = "https://ollama.com/library";
231 HttpClient httpClient = HttpClient.newHttpClient();
232 HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
233 .header("Content-type", "application/json").GET().build();
234 HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
235 int statusCode = response.statusCode();
236 String responseString = response.body();
237 List<LibraryModel> models = new ArrayList<>();
238 if (statusCode == 200) {
239 Document doc = Jsoup.parse(responseString);
240 Elements modelSections = doc.selectXpath("//*[@id='repo']/ul/li/a");
241 for (Element e : modelSections) {
242 LibraryModel model = new LibraryModel();
243 Elements names = e.select("div > h2 > div > span");
244 Elements desc = e.select("div > p");
245 Elements pullCounts = e.select("div:nth-of-type(2) > p > span:first-of-type > span:first-of-type");
246 Elements popularTags = e.select("div > div > span");
247 Elements totalTags = e.select("div:nth-of-type(2) > p > span:nth-of-type(2) > span:first-of-type");
248 Elements lastUpdatedTime = e
249 .select("div:nth-of-type(2) > p > span:nth-of-type(3) > span:nth-of-type(2)");
250
251 if (names.first() == null || names.isEmpty()) {
252 // if name cannot be extracted, skip.
253 continue;
254 }
255 Optional.ofNullable(names.first()).map(Element::text).ifPresent(model::setName);
256 model.setDescription(Optional.ofNullable(desc.first()).map(Element::text).orElse(""));
257 model.setPopularTags(Optional.of(popularTags)
258 .map(tags -> tags.stream().map(Element::text).collect(Collectors.toList()))
259 .orElse(new ArrayList<>()));
260 model.setPullCount(Optional.ofNullable(pullCounts.first()).map(Element::text).orElse(""));
261 model.setTotalTags(
262 Optional.ofNullable(totalTags.first()).map(Element::text).map(Integer::parseInt).orElse(0));
263 model.setLastUpdated(Optional.ofNullable(lastUpdatedTime.first()).map(Element::text).orElse(""));
264
265 models.add(model);
266 }
267 return models;
268 } else {
269 throw new OllamaBaseException(statusCode + " - " + responseString);
270 }
271 }
272
296 throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
297 String url = String.format("https://ollama.com/library/%s/tags", libraryModel.getName());
298 HttpClient httpClient = HttpClient.newHttpClient();
299 HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
300 .header("Content-type", "application/json").GET().build();
301 HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
302 int statusCode = response.statusCode();
303 String responseString = response.body();
304
305 List<LibraryModelTag> libraryModelTags = new ArrayList<>();
306 if (statusCode == 200) {
307 Document doc = Jsoup.parse(responseString);
308 Elements tagSections = doc
309 .select("html > body > main > div > section > div > div > div:nth-child(n+2) > div");
310 for (Element e : tagSections) {
311 Elements tags = e.select("div > a > div");
312 Elements tagsMetas = e.select("div > span");
313
314 LibraryModelTag libraryModelTag = new LibraryModelTag();
315
316 if (tags.first() == null || tags.isEmpty()) {
317 // if tag cannot be extracted, skip.
318 continue;
319 }
320 libraryModelTag.setName(libraryModel.getName());
321 Optional.ofNullable(tags.first()).map(Element::text).ifPresent(libraryModelTag::setTag);
322 libraryModelTag.setSize(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split("•"))
323 .filter(parts -> parts.length > 1).map(parts -> parts[1].trim()).orElse(""));
324 libraryModelTag
325 .setLastUpdated(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split("•"))
326 .filter(parts -> parts.length > 1).map(parts -> parts[2].trim()).orElse(""));
327 libraryModelTags.add(libraryModelTag);
328 }
329 LibraryModelDetail libraryModelDetail = new LibraryModelDetail();
330 libraryModelDetail.setModel(libraryModel);
331 libraryModelDetail.setTags(libraryModelTags);
332 return libraryModelDetail;
333 } else {
334 throw new OllamaBaseException(statusCode + " - " + responseString);
335 }
336 }
337
359 public LibraryModelTag findModelTagFromLibrary(String modelName, String tag)
360 throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
361 List<LibraryModel> libraryModels = this.listModelsFromLibrary();
362 LibraryModel libraryModel = libraryModels.stream().filter(model -> model.getName().equals(modelName))
363 .findFirst().orElseThrow(
364 () -> new NoSuchElementException(String.format("Model by name '%s' not found", modelName)));
365 LibraryModelDetail libraryModelDetail = this.getLibraryModelDetails(libraryModel);
366 LibraryModelTag libraryModelTag = libraryModelDetail.getTags().stream()
367 .filter(tagName -> tagName.getTag().equals(tag)).findFirst()
368 .orElseThrow(() -> new NoSuchElementException(
369 String.format("Tag '%s' for model '%s' not found", tag, modelName)));
370 return libraryModelTag;
371 }
372
383 public void pullModel(String modelName)
384 throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
385 if (numberOfRetriesForModelPull == 0) {
386 this.doPullModel(modelName);
387 } else {
388 int numberOfRetries = 0;
389 while (numberOfRetries < numberOfRetriesForModelPull) {
390 try {
391 this.doPullModel(modelName);
392 return;
393 } catch (OllamaBaseException e) {
394 logger.error("Failed to pull model " + modelName + ", retrying...");
395 numberOfRetries++;
396 }
397 }
398 throw new OllamaBaseException(
399 "Failed to pull model " + modelName + " after " + numberOfRetriesForModelPull + " retries");
400 }
401 }
402
403 private void doPullModel(String modelName)
404 throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
405 String url = this.host + "/api/pull";
406 String jsonData = new ModelRequest(modelName).toString();
407 HttpRequest request = getRequestBuilderDefault(new URI(url))
408 .POST(HttpRequest.BodyPublishers.ofString(jsonData))
409 .header("Accept", "application/json")
410 .header("Content-type", "application/json")
411 .build();
412 HttpClient client = HttpClient.newHttpClient();
413 HttpResponse<InputStream> response = client.send(request, HttpResponse.BodyHandlers.ofInputStream());
414 int statusCode = response.statusCode();
415 InputStream responseBodyStream = response.body();
416 String responseString = "";
417 boolean success = false; // Flag to check the pull success.
418 try (BufferedReader reader = new BufferedReader(
419 new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
420 String line;
421 while ((line = reader.readLine()) != null) {
422 ModelPullResponse modelPullResponse = Utils.getObjectMapper().readValue(line, ModelPullResponse.class);
423 if (modelPullResponse != null && modelPullResponse.getStatus() != null) {
424 if (verbose) {
425 logger.info(modelName + ": " + modelPullResponse.getStatus());
426 }
427 // Check if status is "success" and set success flag to true.
428 if ("success".equalsIgnoreCase(modelPullResponse.getStatus())) {
429 success = true;
430 }
431 } else {
432 logger.error("Received null or invalid status for model pull.");
433 }
434 }
435 }
436 if (!success) {
437 logger.error("Model pull failed or returned invalid status.");
438 throw new OllamaBaseException("Model pull failed or returned invalid status.");
439 }
440 if (statusCode != 200) {
441 throw new OllamaBaseException(statusCode + " - " + responseString);
442 }
443 }
444
445 public String getVersion() throws URISyntaxException, IOException, InterruptedException, OllamaBaseException {
446 String url = this.host + "/api/version";
447 HttpClient httpClient = HttpClient.newHttpClient();
448 HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
449 .header("Content-type", "application/json").GET().build();
450 HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
451 int statusCode = response.statusCode();
452 String responseString = response.body();
453 if (statusCode == 200) {
454 return Utils.getObjectMapper().readValue(responseString, OllamaVersion.class).getVersion();
455 } else {
456 throw new OllamaBaseException(statusCode + " - " + responseString);
457 }
458 }
459
474 public void pullModel(LibraryModelTag libraryModelTag)
475 throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
476 String tagToPull = String.format("%s:%s", libraryModelTag.getName(), libraryModelTag.getTag());
477 pullModel(tagToPull);
478 }
479
490 public ModelDetail getModelDetails(String modelName)
491 throws IOException, OllamaBaseException, InterruptedException, URISyntaxException {
492 String url = this.host + "/api/show";
493 String jsonData = new ModelRequest(modelName).toString();
494 HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
495 .header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
496 HttpClient client = HttpClient.newHttpClient();
497 HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
498 int statusCode = response.statusCode();
499 String responseBody = response.body();
500 if (statusCode == 200) {
501 return Utils.getObjectMapper().readValue(responseBody, ModelDetail.class);
502 } else {
503 throw new OllamaBaseException(statusCode + " - " + responseBody);
504 }
505 }
506
520 @Deprecated
521 public void createModelWithFilePath(String modelName, String modelFilePath)
522 throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
523 String url = this.host + "/api/create";
524 String jsonData = new CustomModelFilePathRequest(modelName, modelFilePath).toString();
525 HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
526 .header("Content-Type", "application/json")
527 .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
528 HttpClient client = HttpClient.newHttpClient();
529 HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
530 int statusCode = response.statusCode();
531 String responseString = response.body();
532 if (statusCode != 200) {
533 throw new OllamaBaseException(statusCode + " - " + responseString);
534 }
535 // FIXME: Ollama API returns HTTP status code 200 for model creation failure
536 // cases. Correct this
537 // if the issue is fixed in the Ollama API server.
538 if (responseString.contains("error")) {
539 throw new OllamaBaseException(responseString);
540 }
541 if (verbose) {
542 logger.info(responseString);
543 }
544 }
545
560 @Deprecated
561 public void createModelWithModelFileContents(String modelName, String modelFileContents)
562 throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
563 String url = this.host + "/api/create";
564 String jsonData = new CustomModelFileContentsRequest(modelName, modelFileContents).toString();
565 HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
566 .header("Content-Type", "application/json")
567 .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
568 HttpClient client = HttpClient.newHttpClient();
569 HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
570 int statusCode = response.statusCode();
571 String responseString = response.body();
572 if (statusCode != 200) {
573 throw new OllamaBaseException(statusCode + " - " + responseString);
574 }
575 if (responseString.contains("error")) {
576 throw new OllamaBaseException(responseString);
577 }
578 if (verbose) {
579 logger.info(responseString);
580 }
581 }
582
594 public void createModel(CustomModelRequest customModelRequest)
595 throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
596 String url = this.host + "/api/create";
597 String jsonData = customModelRequest.toString();
598 HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
599 .header("Content-Type", "application/json")
600 .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
601 HttpClient client = HttpClient.newHttpClient();
602 HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
603 int statusCode = response.statusCode();
604 String responseString = response.body();
605 if (statusCode != 200) {
606 throw new OllamaBaseException(statusCode + " - " + responseString);
607 }
608 if (responseString.contains("error")) {
609 throw new OllamaBaseException(responseString);
610 }
611 if (verbose) {
612 logger.info(responseString);
613 }
614 }
615
627 public void deleteModel(String modelName, boolean ignoreIfNotPresent)
628 throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
629 String url = this.host + "/api/delete";
630 String jsonData = new ModelRequest(modelName).toString();
631 HttpRequest request = getRequestBuilderDefault(new URI(url))
632 .method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
633 .header("Accept", "application/json").header("Content-type", "application/json").build();
634 HttpClient client = HttpClient.newHttpClient();
635 HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
636 int statusCode = response.statusCode();
637 String responseBody = response.body();
638 if (statusCode == 404 && responseBody.contains("model") && responseBody.contains("not found")) {
639 return;
640 }
641 if (statusCode != 200) {
642 throw new OllamaBaseException(statusCode + " - " + responseBody);
643 }
644 }
645
657 @Deprecated
658 public List<Double> generateEmbeddings(String model, String prompt)
659 throws IOException, InterruptedException, OllamaBaseException {
660 return generateEmbeddings(new OllamaEmbeddingsRequestModel(model, prompt));
661 }
662
/**
 * Requests an embedding vector from the {@code POST /api/embeddings} endpoint.
 *
 * @param modelRequest request model; its {@code toString()} output is sent as the request body
 * @return the embedding taken from the parsed response
 * @throws IOException if the request fails or the body cannot be parsed
 * @throws InterruptedException if interrupted while waiting for the response
 * @throws OllamaBaseException if the server answers with a non-200 status ("status - body")
 * @deprecated NOTE(review): presumably superseded by embed(OllamaEmbedRequestModel) — confirm.
 */
673 @Deprecated
674 public List<Double> generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest)
675 throws IOException, InterruptedException, OllamaBaseException {
676 URI uri = URI.create(this.host + "/api/embeddings");
677 String jsonData = modelRequest.toString();
678 HttpClient httpClient = HttpClient.newHttpClient();
// NOTE(review): unlike getModelDetails, no explicit Content-type header is added here;
// getRequestBuilderDefault may already set one — confirm.
679 HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).header("Accept", "application/json")
680 .POST(HttpRequest.BodyPublishers.ofString(jsonData));
681 HttpRequest request = requestBuilder.build();
682 HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
683 int statusCode = response.statusCode();
684 String responseBody = response.body();
685 if (statusCode == 200) {
686 OllamaEmbeddingResponseModel embeddingResponse = Utils.getObjectMapper().readValue(responseBody,
// NOTE(review): source line 687 (the second readValue argument) is missing from this listing;
// presumably OllamaEmbeddingResponseModel.class — confirm against the original file.
688 return embeddingResponse.getEmbedding();
689 } else {
690 throw new OllamaBaseException(statusCode + " - " + responseBody);
691 }
692 }
693
704 public OllamaEmbedResponseModel embed(String model, List<String> inputs)
705 throws IOException, InterruptedException, OllamaBaseException {
706 return embed(new OllamaEmbedRequestModel(model, inputs));
707 }
708
719 throws IOException, InterruptedException, OllamaBaseException {
720 URI uri = URI.create(this.host + "/api/embed");
721 String jsonData = Utils.getObjectMapper().writeValueAsString(modelRequest);
722 HttpClient httpClient = HttpClient.newHttpClient();
723
724 HttpRequest request = HttpRequest.newBuilder(uri).header("Accept", "application/json")
725 .POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
726
727 HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
728 int statusCode = response.statusCode();
729 String responseBody = response.body();
730
731 if (statusCode == 200) {
732 return Utils.getObjectMapper().readValue(responseBody, OllamaEmbedResponseModel.class);
733 } else {
734 throw new OllamaBaseException(statusCode + " - " + responseBody);
735 }
736 }
737
757 public OllamaResult generate(String model, String prompt, boolean raw, Options options,
758 OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
759 OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
760 ollamaRequestModel.setRaw(raw);
761 ollamaRequestModel.setOptions(options.getOptionsMap());
762 return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
763 }
764
779 public OllamaResult generate(String model, String prompt, Map<String, Object> format)
780 throws OllamaBaseException, IOException, InterruptedException {
781 URI uri = URI.create(this.host + "/api/generate");
782
783 Map<String, Object> requestBody = new HashMap<>();
784 requestBody.put("model", model);
785 requestBody.put("prompt", prompt);
786 requestBody.put("stream", false);
787 requestBody.put("format", format);
788
789 String jsonData = Utils.getObjectMapper().writeValueAsString(requestBody);
790 HttpClient httpClient = HttpClient.newHttpClient();
791
792 HttpRequest request = HttpRequest.newBuilder(uri)
793 .header("Content-Type", "application/json")
794 .POST(HttpRequest.BodyPublishers.ofString(jsonData))
795 .build();
796
797 HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
798 int statusCode = response.statusCode();
799 String responseBody = response.body();
800
801 if (statusCode == 200) {
802 OllamaStructuredResult structuredResult = Utils.getObjectMapper().readValue(responseBody,
804 OllamaResult ollamaResult = new OllamaResult(structuredResult.getResponse(),
805 structuredResult.getResponseTime(), statusCode);
806 return ollamaResult;
807 } else {
808 throw new OllamaBaseException(statusCode + " - " + responseBody);
809 }
810 }
811
832 public OllamaResult generate(String model, String prompt, boolean raw, Options options)
833 throws OllamaBaseException, IOException, InterruptedException {
834 return generate(model, prompt, raw, options, null);
835 }
836
854 public OllamaToolsResult generateWithTools(String model, String prompt, Options options)
855 throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
856 boolean raw = true;
857 OllamaToolsResult toolResult = new OllamaToolsResult();
858 Map<ToolFunctionCallSpec, Object> toolResults = new HashMap<>();
859
860 if (!prompt.startsWith("[AVAILABLE_TOOLS]")) {
861 final Tools.PromptBuilder promptBuilder = new Tools.PromptBuilder();
862 for (Tools.ToolSpecification spec : toolRegistry.getRegisteredSpecs()) {
863 promptBuilder.withToolSpecification(spec);
864 }
865 promptBuilder.withPrompt(prompt);
866 prompt = promptBuilder.build();
867 }
868
869 OllamaResult result = generate(model, prompt, raw, options, null);
870 toolResult.setModelResult(result);
871
872 String toolsResponse = result.getResponse();
873 if (toolsResponse.contains("[TOOL_CALLS]")) {
874 toolsResponse = toolsResponse.replace("[TOOL_CALLS]", "");
875 }
876
877 List<ToolFunctionCallSpec> toolFunctionCallSpecs = new ArrayList<>();
878 ObjectMapper objectMapper = Utils.getObjectMapper();
879
880 if (!toolsResponse.isEmpty()) {
881 try {
882 // Try to parse the string to see if it's a valid JSON
883 JsonNode jsonNode = objectMapper.readTree(toolsResponse);
884 } catch (JsonParseException e) {
885 logger.warn("Response from model does not contain any tool calls. Returning the response as is.");
886 return toolResult;
887 }
888 toolFunctionCallSpecs = objectMapper.readValue(
889 toolsResponse,
890 objectMapper.getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class));
891 }
892 for (ToolFunctionCallSpec toolFunctionCallSpec : toolFunctionCallSpecs) {
893 toolResults.put(toolFunctionCallSpec, invokeTool(toolFunctionCallSpec));
894 }
895 toolResult.setToolResults(toolResults);
896 return toolResult;
897 }
898
910 public OllamaAsyncResultStreamer generateAsync(String model, String prompt, boolean raw) {
911 OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
912 ollamaRequestModel.setRaw(raw);
913 URI uri = URI.create(this.host + "/api/generate");
914 OllamaAsyncResultStreamer ollamaAsyncResultStreamer = new OllamaAsyncResultStreamer(
915 getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds);
916 ollamaAsyncResultStreamer.start();
917 return ollamaAsyncResultStreamer;
918 }
919
940 public OllamaResult generateWithImageFiles(String model, String prompt, List<File> imageFiles, Options options,
941 OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
942 List<String> images = new ArrayList<>();
943 for (File imageFile : imageFiles) {
944 images.add(encodeFileToBase64(imageFile));
945 }
946 OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt, images);
947 ollamaRequestModel.setOptions(options.getOptionsMap());
948 return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
949 }
950
961 public OllamaResult generateWithImageFiles(String model, String prompt, List<File> imageFiles, Options options)
962 throws OllamaBaseException, IOException, InterruptedException {
963 return generateWithImageFiles(model, prompt, imageFiles, options, null);
964 }
965
987 public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs, Options options,
988 OllamaStreamHandler streamHandler)
989 throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
990 List<String> images = new ArrayList<>();
991 for (String imageURL : imageURLs) {
992 images.add(encodeByteArrayToBase64(Utils.loadImageBytesFromUrl(imageURL)));
993 }
994 OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt, images);
995 ollamaRequestModel.setOptions(options.getOptionsMap());
996 return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
997 }
998
1010 public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs, Options options)
1011 throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
1012 return generateWithImageURLs(model, prompt, imageURLs, options, null);
1013 }
1014
1030 public OllamaResult generateWithImages(String model, String prompt, List<byte[]> images, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
1031 List<String> encodedImages = new ArrayList<>();
1032 for (byte[] image : images) {
1033 encodedImages.add(encodeByteArrayToBase64(image));
1034 }
1035 OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt, encodedImages);
1036 ollamaRequestModel.setOptions(options.getOptionsMap());
1037 return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
1038 }
1039
1049 public OllamaResult generateWithImages(String model, String prompt, List<byte[]> images, Options options) throws OllamaBaseException, IOException, InterruptedException {
1050 return generateWithImages(model, prompt, images, options, null);
1051 }
1052
/**
 * Sends a chat request built from the given model and message history.
 *
 * @param model    model name
 * @param messages conversation history to send
 * @return the chat result
 * @throws OllamaBaseException if the server reports an error
 * @throws IOException if the request fails
 * @throws InterruptedException if interrupted while waiting for the response
 * @throws ToolInvocationException if a requested tool call cannot be invoked
 */
1071 public OllamaChatResult chat(String model, List<OllamaChatMessage> messages)
1072 throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
// NOTE(review): `builder` is declared on source line 1073, which is missing from this listing;
// presumably a chat request builder created for `model` — confirm against the original file.
1074 return chat(builder.withMessages(messages).build());
1075 }
1076
// NOTE(review): the declaration line preceding this `throws` clause is not present in this listing.
// Sends `request` without a stream handler by delegating to chat(request, null).
1095 throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
1096 return chat(request, null);
1097 }
1098
// NOTE(review): the declaration line preceding this `throws` clause is not present in this listing.
// Wraps `streamHandler` in an OllamaChatStreamObserver and delegates to chatStreaming.
1120 throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
1121 return chatStreaming(request, new OllamaChatStreamObserver(streamHandler));
1122 }
1123
// NOTE(review): the declaration line preceding this `throws` clause is not present in this listing.
// Sends a chat request and services model-initiated tool calls:
//  - the request's tool list is replaced with the prompts of all registered tools;
//  - with a non-null tokenHandler the request is switched to streaming and sent via call(...),
//    otherwise via callSync(...);
//  - while the reply contains tool calls and fewer than maxChatToolCallRetries rounds have run,
//    each call is executed and the request is re-sent.
// When the round limit is reached, the returned result may still contain unexecuted tool calls.
1144 throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
1145 OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, auth, requestTimeoutSeconds,
1146 verbose);
1147 OllamaChatResult result;
1148
1149 // add all registered tools to Request
1150 request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt)
1151 .collect(Collectors.toList()));
1152
1153 if (tokenHandler != null) {
1154 request.setStream(true);
1155 result = requestCaller.call(request, tokenHandler);
1156 } else {
1157 result = requestCaller.callSync(request);
1158 }
1159
1160 // check if toolCallIsWanted
1161 List<OllamaChatToolCalls> toolCalls = result.getResponseModel().getMessage().getToolCalls();
1162 int toolCallTries = 0;
1163 while (toolCalls != null && !toolCalls.isEmpty() && toolCallTries < maxChatToolCallRetries) {
1164 for (OllamaChatToolCalls toolCall : toolCalls) {
1165 String toolName = toolCall.getFunction().getName();
1166 ToolFunction toolFunction = toolRegistry.getToolFunction(toolName);
// A tool name the registry does not know aborts the whole exchange.
1167 if (toolFunction == null) {
1168 throw new ToolInvocationException("Tool function not found: " + toolName);
1169 }
// NOTE(review): if getArguments() can return null, arguments.keySet() below throws NPE — confirm.
1170 Map<String, Object> arguments = toolCall.getFunction().getArguments();
1171 Object res = toolFunction.apply(arguments);
// NOTE(review): source line 1172 is missing from this listing; judging by the continuation it
// presumably appends a "[TOOL_RESULTS]..." message to the request's messages — confirm.
1173 "[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() + ") : " + res + "[/TOOL_RESULTS]"));
1174 }
1175
// Re-send the conversation, now including the tool results, using the same mode as the first call.
1176 if (tokenHandler != null) {
1177 result = requestCaller.call(request, tokenHandler);
1178 } else {
1179 result = requestCaller.callSync(request);
1180 }
1181 toolCalls = result.getResponseModel().getMessage().getToolCalls();
1182 toolCallTries++;
1183 }
1184
1185 return result;
1186 }
1187
1196 public void registerTool(Tools.ToolSpecification toolSpecification) {
1197 toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
1198 if (this.verbose) {
1199 logger.debug("Registered tool: {}", toolSpecification.getFunctionName());
1200 }
1201 }
1202
1213 public void registerTools(List<Tools.ToolSpecification> toolSpecifications) {
1214 for (Tools.ToolSpecification toolSpecification : toolSpecifications) {
1215 toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
1216 }
1217 }
1218
1232 try {
1233 Class<?> callerClass = null;
1234 try {
1235 callerClass = Class.forName(Thread.currentThread().getStackTrace()[2].getClassName());
1236 } catch (ClassNotFoundException e) {
1237 throw new RuntimeException(e);
1238 }
1239
1240 OllamaToolService ollamaToolServiceAnnotation = callerClass.getDeclaredAnnotation(OllamaToolService.class);
1241 if (ollamaToolServiceAnnotation == null) {
1242 throw new IllegalStateException(callerClass + " is not annotated as " + OllamaToolService.class);
1243 }
1244
1245 Class<?>[] providers = ollamaToolServiceAnnotation.providers();
1246 for (Class<?> provider : providers) {
1247 registerAnnotatedTools(provider.getDeclaredConstructor().newInstance());
1248 }
1249 } catch (InstantiationException | NoSuchMethodException | IllegalAccessException
1250 | InvocationTargetException e) {
1251 throw new RuntimeException(e);
1252 }
1253 }
1254
1268 public void registerAnnotatedTools(Object object) {
1269 Class<?> objectClass = object.getClass();
1270 Method[] methods = objectClass.getMethods();
1271 for (Method m : methods) {
1272 ToolSpec toolSpec = m.getDeclaredAnnotation(ToolSpec.class);
1273 if (toolSpec == null) {
1274 continue;
1275 }
1276 String operationName = !toolSpec.name().isBlank() ? toolSpec.name() : m.getName();
1277 String operationDesc = !toolSpec.desc().isBlank() ? toolSpec.desc() : operationName;
1278
1279 final Tools.PropsBuilder propsBuilder = new Tools.PropsBuilder();
1280 LinkedHashMap<String, String> methodParams = new LinkedHashMap<>();
1281 for (Parameter parameter : m.getParameters()) {
1282 final ToolProperty toolPropertyAnn = parameter.getDeclaredAnnotation(ToolProperty.class);
1283 String propType = parameter.getType().getTypeName();
1284 if (toolPropertyAnn == null) {
1285 methodParams.put(parameter.getName(), null);
1286 continue;
1287 }
1288 String propName = !toolPropertyAnn.name().isBlank() ? toolPropertyAnn.name() : parameter.getName();
1289 methodParams.put(propName, propType);
1290 propsBuilder.withProperty(propName, Tools.PromptFuncDefinition.Property.builder().type(propType)
1291 .description(toolPropertyAnn.desc()).required(toolPropertyAnn.required()).build());
1292 }
1293 final Map<String, Tools.PromptFuncDefinition.Property> params = propsBuilder.build();
1294 List<String> reqProps = params.entrySet().stream().filter(e -> e.getValue().isRequired())
1295 .map(Map.Entry::getKey).collect(Collectors.toList());
1296
1297 Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder().functionName(operationName)
1298 .functionDescription(operationDesc)
1299 .toolPrompt(Tools.PromptFuncDefinition.builder().type("function")
1300 .function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name(operationName)
1301 .description(operationDesc).parameters(Tools.PromptFuncDefinition.Parameters
1302 .builder().type("object").properties(params).required(reqProps).build())
1303 .build())
1304 .build())
1305 .build();
1306
1307 ReflectionalToolFunction reflectionalToolFunction = new ReflectionalToolFunction(object, m, methodParams);
1308 toolSpecification.setToolFunction(reflectionalToolFunction);
1309 toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
1310 }
1311
1312 }
1313
1320 public OllamaChatMessageRole addCustomRole(String roleName) {
1321 return OllamaChatMessageRole.newCustomRole(roleName);
1322 }
1323
1329 public List<OllamaChatMessageRole> listRoles() {
1331 }
1332
1341 public OllamaChatMessageRole getRole(String roleName) throws RoleNotFoundException {
1342 return OllamaChatMessageRole.getRole(roleName);
1343 }
1344
1345 // technical private methods //
1346
1354 private static String encodeFileToBase64(File file) throws IOException {
1355 return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath()));
1356 }
1357
1364 private static String encodeByteArrayToBase64(byte[] bytes) {
1365 return Base64.getEncoder().encodeToString(bytes);
1366 }
1367
1386 private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel,
1387 OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
1388 OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds,
1389 verbose);
1390 OllamaResult result;
1391 if (streamHandler != null) {
1392 ollamaRequestModel.setStream(true);
1393 result = requestCaller.call(ollamaRequestModel, streamHandler);
1394 } else {
1395 result = requestCaller.callSync(ollamaRequestModel);
1396 }
1397 return result;
1398 }
1399
1406 private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
1407 HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).header("Content-Type", "application/json")
1408 .timeout(Duration.ofSeconds(requestTimeoutSeconds));
1409 if (isBasicAuthCredentialsSet()) {
1410 requestBuilder.header("Authorization", auth.getAuthHeaderValue());
1411 }
1412 return requestBuilder;
1413 }
1414
1420 private boolean isBasicAuthCredentialsSet() {
1421 return auth != null;
1422 }
1423
1424 private Object invokeTool(ToolFunctionCallSpec toolFunctionCallSpec) throws ToolInvocationException {
1425 try {
1426 String methodName = toolFunctionCallSpec.getName();
1427 Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
1428 ToolFunction function = toolRegistry.getToolFunction(methodName);
1429 if (verbose) {
1430 logger.debug("Invoking function {} with arguments {}", methodName, arguments);
1431 }
1432 if (function == null) {
1433 throw new ToolNotFoundException(
1434 "No such tool: " + methodName + ". Please register the tool before invoking it.");
1435 }
1436 return function.apply(arguments);
1437 } catch (Exception e) {
1438 throw new ToolInvocationException("Failed to invoke tool: " + toolFunctionCallSpec.getName(), e);
1439 }
1440 }
1441}
OllamaResult generateWithImageFiles(String model, String prompt, List< File > imageFiles, Options options, OllamaStreamHandler streamHandler)
OllamaChatResult chat(OllamaChatRequest request)
ModelDetail getModelDetails(String modelName)
List< Double > generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest)
void setBearerAuth(String bearerToken)
OllamaResult generateWithImages(String model, String prompt, List< byte[]> images, Options options, OllamaStreamHandler streamHandler)
OllamaAsyncResultStreamer generateAsync(String model, String prompt, boolean raw)
OllamaResult generate(String model, String prompt, Map< String, Object > format)
OllamaChatResult chat(String model, List< OllamaChatMessage > messages)
List< Double > generateEmbeddings(String model, String prompt)
void pullModel(String modelName)
void setBasicAuth(String username, String password)
OllamaResult generate(String model, String prompt, boolean raw, Options options, OllamaStreamHandler streamHandler)
OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler)
void deleteModel(String modelName, boolean ignoreIfNotPresent)
void createModelWithFilePath(String modelName, String modelFilePath)
OllamaResult generate(String model, String prompt, boolean raw, Options options)
ModelsProcessResponse ps()
LibraryModelTag findModelTagFromLibrary(String modelName, String tag)
OllamaChatMessageRole getRole(String roleName)
OllamaChatMessageRole addCustomRole(String roleName)
void pullModel(LibraryModelTag libraryModelTag)
OllamaResult generateWithImages(String model, String prompt, List< byte[]> images, Options options)
List< OllamaChatMessageRole > listRoles()
LibraryModelDetail getLibraryModelDetails(LibraryModel libraryModel)
OllamaResult generateWithImageURLs(String model, String prompt, List< String > imageURLs, Options options, OllamaStreamHandler streamHandler)
OllamaEmbedResponseModel embed(OllamaEmbedRequestModel modelRequest)
void registerTool(Tools.ToolSpecification toolSpecification)
void setNumberOfRetriesForModelPull(int numberOfRetriesForModelPull)
void registerAnnotatedTools(Object object)
List< LibraryModel > listModelsFromLibrary()
void registerTools(List< Tools.ToolSpecification > toolSpecifications)
OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler)
OllamaToolsResult generateWithTools(String model, String prompt, Options options)
void createModel(CustomModelRequest customModelRequest)
OllamaResult generateWithImageURLs(String model, String prompt, List< String > imageURLs, Options options)
OllamaEmbedResponseModel embed(String model, List< String > inputs)
void createModelWithModelFileContents(String modelName, String modelFileContents)
OllamaResult generateWithImageFiles(String model, String prompt, List< File > imageFiles, Options options)
static OllamaChatMessageRole newCustomRole(String roleName)
static OllamaChatMessageRole getRole(String roleName)
OllamaChatRequestBuilder withMessages(List< OllamaChatMessage > messages)
static OllamaChatRequestBuilder getInstance(String model)
OllamaChatResult call(OllamaChatRequest body, OllamaTokenHandler tokenHandler)
OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
ToolFunction getToolFunction(String name)
static byte[] loadImageBytesFromUrl(String imageUrl)
Definition Utils.java:25
static ObjectMapper getObjectMapper()
Definition Utils.java:17
Object apply(Map< String, Object > arguments)