Ollama4j
A Java library (wrapper/binding) for the Ollama server.
Loading...
Searching...
No Matches
OllamaAPI.java
Go to the documentation of this file.
1package io.github.ollama4j;
2
3import com.fasterxml.jackson.core.JsonParseException;
4import com.fasterxml.jackson.databind.JsonNode;
5import com.fasterxml.jackson.databind.ObjectMapper;
6import io.github.ollama4j.exceptions.OllamaBaseException;
7import io.github.ollama4j.exceptions.RoleNotFoundException;
8import io.github.ollama4j.exceptions.ToolInvocationException;
9import io.github.ollama4j.exceptions.ToolNotFoundException;
10import io.github.ollama4j.models.chat.*;
11import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
12import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
13import io.github.ollama4j.models.embeddings.OllamaEmbeddingResponseModel;
14import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
15import io.github.ollama4j.models.generate.OllamaGenerateRequest;
16import io.github.ollama4j.models.generate.OllamaStreamHandler;
17import io.github.ollama4j.models.generate.OllamaTokenHandler;
18import io.github.ollama4j.models.ps.ModelsProcessResponse;
19import io.github.ollama4j.models.request.*;
20import io.github.ollama4j.models.response.*;
21import io.github.ollama4j.tools.*;
22import io.github.ollama4j.tools.annotations.OllamaToolService;
23import io.github.ollama4j.tools.annotations.ToolProperty;
24import io.github.ollama4j.tools.annotations.ToolSpec;
25import io.github.ollama4j.utils.Options;
26import io.github.ollama4j.utils.Utils;
27import lombok.Setter;
28import org.jsoup.Jsoup;
29import org.jsoup.nodes.Document;
30import org.jsoup.nodes.Element;
31import org.jsoup.select.Elements;
32import org.slf4j.Logger;
33import org.slf4j.LoggerFactory;
34
35import java.io.*;
36import java.lang.reflect.InvocationTargetException;
37import java.lang.reflect.Method;
38import java.lang.reflect.Parameter;
39import java.net.URI;
40import java.net.URISyntaxException;
41import java.net.http.HttpClient;
42import java.net.http.HttpConnectTimeoutException;
43import java.net.http.HttpRequest;
44import java.net.http.HttpResponse;
45import java.nio.charset.StandardCharsets;
46import java.nio.file.Files;
47import java.time.Duration;
48import java.util.*;
49import java.util.stream.Collectors;
50
54@SuppressWarnings({"DuplicatedCode", "resource"})
55public class OllamaAPI {
56
57 private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class);
58 private final String host;
63 @Setter
64 private long requestTimeoutSeconds = 10;
69 @Setter
70 private boolean verbose = true;
71
72 @Setter
73 private int maxChatToolCallRetries = 3;
74
75 private BasicAuth basicAuth;
76
77 private final ToolRegistry toolRegistry = new ToolRegistry();
78
82 public OllamaAPI() {
83 this.host = "http://localhost:11434";
84 }
85
91 public OllamaAPI(String host) {
92 if (host.endsWith("/")) {
93 this.host = host.substring(0, host.length() - 1);
94 } else {
95 this.host = host;
96 }
97 if (this.verbose) {
98 logger.info("Ollama API initialized with host: " + this.host);
99 }
100 }
101
108 public void setBasicAuth(String username, String password) {
109 this.basicAuth = new BasicAuth(username, password);
110 }
111
117 public boolean ping() {
118 String url = this.host + "/api/tags";
119 HttpClient httpClient = HttpClient.newHttpClient();
120 HttpRequest httpRequest = null;
121 try {
122 httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
123 } catch (URISyntaxException e) {
124 throw new RuntimeException(e);
125 }
126 HttpResponse<String> response = null;
127 try {
128 response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
129 } catch (HttpConnectTimeoutException e) {
130 return false;
131 } catch (IOException | InterruptedException e) {
132 throw new RuntimeException(e);
133 }
134 int statusCode = response.statusCode();
135 return statusCode == 200;
136 }
137
146 public ModelsProcessResponse ps() throws IOException, InterruptedException, OllamaBaseException {
147 String url = this.host + "/api/ps";
148 HttpClient httpClient = HttpClient.newHttpClient();
149 HttpRequest httpRequest = null;
150 try {
151 httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
152 } catch (URISyntaxException e) {
153 throw new RuntimeException(e);
154 }
155 HttpResponse<String> response = null;
156 response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
157 int statusCode = response.statusCode();
158 String responseString = response.body();
159 if (statusCode == 200) {
160 return Utils.getObjectMapper().readValue(responseString, ModelsProcessResponse.class);
161 } else {
162 throw new OllamaBaseException(statusCode + " - " + responseString);
163 }
164 }
165
175 public List<Model> listModels() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
176 String url = this.host + "/api/tags";
177 HttpClient httpClient = HttpClient.newHttpClient();
178 HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
179 HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
180 int statusCode = response.statusCode();
181 String responseString = response.body();
182 if (statusCode == 200) {
183 return Utils.getObjectMapper().readValue(responseString, ListModelsResponse.class).getModels();
184 } else {
185 throw new OllamaBaseException(statusCode + " - " + responseString);
186 }
187 }
188
199 public List<LibraryModel> listModelsFromLibrary() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
200 String url = "https://ollama.com/library";
201 HttpClient httpClient = HttpClient.newHttpClient();
202 HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
203 HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
204 int statusCode = response.statusCode();
205 String responseString = response.body();
206 List<LibraryModel> models = new ArrayList<>();
207 if (statusCode == 200) {
208 Document doc = Jsoup.parse(responseString);
209 Elements modelSections = doc.selectXpath("//*[@id='repo']/ul/li/a");
210 for (Element e : modelSections) {
211 LibraryModel model = new LibraryModel();
212 Elements names = e.select("div > h2 > div > span");
213 Elements desc = e.select("div > p");
214 Elements pullCounts = e.select("div:nth-of-type(2) > p > span:first-of-type > span:first-of-type");
215 Elements popularTags = e.select("div > div > span");
216 Elements totalTags = e.select("div:nth-of-type(2) > p > span:nth-of-type(2) > span:first-of-type");
217 Elements lastUpdatedTime = e.select("div:nth-of-type(2) > p > span:nth-of-type(3) > span:nth-of-type(2)");
218
219 if (names.first() == null || names.isEmpty()) {
220 // if name cannot be extracted, skip.
221 continue;
222 }
223 Optional.ofNullable(names.first()).map(Element::text).ifPresent(model::setName);
224 model.setDescription(Optional.ofNullable(desc.first()).map(Element::text).orElse(""));
225 model.setPopularTags(Optional.of(popularTags).map(tags -> tags.stream().map(Element::text).collect(Collectors.toList())).orElse(new ArrayList<>()));
226 model.setPullCount(Optional.ofNullable(pullCounts.first()).map(Element::text).orElse(""));
227 model.setTotalTags(Optional.ofNullable(totalTags.first()).map(Element::text).map(Integer::parseInt).orElse(0));
228 model.setLastUpdated(Optional.ofNullable(lastUpdatedTime.first()).map(Element::text).orElse(""));
229
230 models.add(model);
231 }
232 return models;
233 } else {
234 throw new OllamaBaseException(statusCode + " - " + responseString);
235 }
236 }
237
252 public LibraryModelDetail getLibraryModelDetails(LibraryModel libraryModel) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
253 String url = String.format("https://ollama.com/library/%s/tags", libraryModel.getName());
254 HttpClient httpClient = HttpClient.newHttpClient();
255 HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
256 HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
257 int statusCode = response.statusCode();
258 String responseString = response.body();
259
260 List<LibraryModelTag> libraryModelTags = new ArrayList<>();
261 if (statusCode == 200) {
262 Document doc = Jsoup.parse(responseString);
263 Elements tagSections = doc.select("html > body > main > div > section > div > div > div:nth-child(n+2) > div");
264 for (Element e : tagSections) {
265 Elements tags = e.select("div > a > div");
266 Elements tagsMetas = e.select("div > span");
267
268 LibraryModelTag libraryModelTag = new LibraryModelTag();
269
270 if (tags.first() == null || tags.isEmpty()) {
271 // if tag cannot be extracted, skip.
272 continue;
273 }
274 libraryModelTag.setName(libraryModel.getName());
275 Optional.ofNullable(tags.first()).map(Element::text).ifPresent(libraryModelTag::setTag);
276 libraryModelTag.setSize(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split("•")).filter(parts -> parts.length > 1).map(parts -> parts[1].trim()).orElse(""));
277 libraryModelTag.setLastUpdated(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split("•")).filter(parts -> parts.length > 1).map(parts -> parts[2].trim()).orElse(""));
278 libraryModelTags.add(libraryModelTag);
279 }
280 LibraryModelDetail libraryModelDetail = new LibraryModelDetail();
281 libraryModelDetail.setModel(libraryModel);
282 libraryModelDetail.setTags(libraryModelTags);
283 return libraryModelDetail;
284 } else {
285 throw new OllamaBaseException(statusCode + " - " + responseString);
286 }
287 }
288
305 public LibraryModelTag findModelTagFromLibrary(String modelName, String tag) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
306 List<LibraryModel> libraryModels = this.listModelsFromLibrary();
307 LibraryModel libraryModel = libraryModels.stream().filter(model -> model.getName().equals(modelName)).findFirst().orElseThrow(() -> new NoSuchElementException(String.format("Model by name '%s' not found", modelName)));
308 LibraryModelDetail libraryModelDetail = this.getLibraryModelDetails(libraryModel);
309 LibraryModelTag libraryModelTag = libraryModelDetail.getTags().stream().filter(tagName -> tagName.getTag().equals(tag)).findFirst().orElseThrow(() -> new NoSuchElementException(String.format("Tag '%s' for model '%s' not found", tag, modelName)));
310 return libraryModelTag;
311 }
312
323 public void pullModel(String modelName) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
324 String url = this.host + "/api/pull";
325 String jsonData = new ModelRequest(modelName).toString();
326 HttpRequest request = getRequestBuilderDefault(new URI(url))
327 .POST(HttpRequest.BodyPublishers.ofString(jsonData))
328 .header("Accept", "application/json")
329 .header("Content-type", "application/json")
330 .build();
331 HttpClient client = HttpClient.newHttpClient();
332 HttpResponse<InputStream> response = client.send(request, HttpResponse.BodyHandlers.ofInputStream());
333 int statusCode = response.statusCode();
334 InputStream responseBodyStream = response.body();
335 String responseString = "";
336 boolean success = false; // Flag to check the pull success.
337 try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
338 String line;
339 while ((line = reader.readLine()) != null) {
340 ModelPullResponse modelPullResponse = Utils.getObjectMapper().readValue(line, ModelPullResponse.class);
341 if (modelPullResponse != null && modelPullResponse.getStatus() != null) {
342 if (verbose) {
343 logger.info(modelName + ": " + modelPullResponse.getStatus());
344 }
345 // Check if status is "success" and set success flag to true.
346 if ("success".equalsIgnoreCase(modelPullResponse.getStatus())) {
347 success = true;
348 }
349 } else {
350 logger.error("Received null or invalid status for model pull.");
351 }
352 }
353 }
354 if (!success) {
355 logger.error("Model pull failed or returned invalid status.");
356 throw new OllamaBaseException("Model pull failed or returned invalid status.");
357 }
358 if (statusCode != 200) {
359 throw new OllamaBaseException(statusCode + " - " + responseString);
360 }
361 }
362
363
364 public String getVersion() throws URISyntaxException, IOException, InterruptedException, OllamaBaseException {
365 String url = this.host + "/api/version";
366 HttpClient httpClient = HttpClient.newHttpClient();
367 HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
368 HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
369 int statusCode = response.statusCode();
370 String responseString = response.body();
371 if (statusCode == 200) {
372 return Utils.getObjectMapper().readValue(responseString, OllamaVersion.class).getVersion();
373 } else {
374 throw new OllamaBaseException(statusCode + " - " + responseString);
375 }
376 }
377
390 public void pullModel(LibraryModelTag libraryModelTag) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
391 String tagToPull = String.format("%s:%s", libraryModelTag.getName(), libraryModelTag.getTag());
392 pullModel(tagToPull);
393 }
394
405 public ModelDetail getModelDetails(String modelName) throws IOException, OllamaBaseException, InterruptedException, URISyntaxException {
406 String url = this.host + "/api/show";
407 String jsonData = new ModelRequest(modelName).toString();
408 HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
409 HttpClient client = HttpClient.newHttpClient();
410 HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
411 int statusCode = response.statusCode();
412 String responseBody = response.body();
413 if (statusCode == 200) {
414 return Utils.getObjectMapper().readValue(responseBody, ModelDetail.class);
415 } else {
416 throw new OllamaBaseException(statusCode + " - " + responseBody);
417 }
418 }
419
431 @Deprecated
432 public void createModelWithFilePath(String modelName, String modelFilePath) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
433 String url = this.host + "/api/create";
434 String jsonData = new CustomModelFilePathRequest(modelName, modelFilePath).toString();
435 HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
436 HttpClient client = HttpClient.newHttpClient();
437 HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
438 int statusCode = response.statusCode();
439 String responseString = response.body();
440 if (statusCode != 200) {
441 throw new OllamaBaseException(statusCode + " - " + responseString);
442 }
443 // FIXME: Ollama API returns HTTP status code 200 for model creation failure cases. Correct this
444 // if the issue is fixed in the Ollama API server.
445 if (responseString.contains("error")) {
446 throw new OllamaBaseException(responseString);
447 }
448 if (verbose) {
449 logger.info(responseString);
450 }
451 }
452
464 @Deprecated
465 public void createModelWithModelFileContents(String modelName, String modelFileContents) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
466 String url = this.host + "/api/create";
467 String jsonData = new CustomModelFileContentsRequest(modelName, modelFileContents).toString();
468 HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
469 HttpClient client = HttpClient.newHttpClient();
470 HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
471 int statusCode = response.statusCode();
472 String responseString = response.body();
473 if (statusCode != 200) {
474 throw new OllamaBaseException(statusCode + " - " + responseString);
475 }
476 if (responseString.contains("error")) {
477 throw new OllamaBaseException(responseString);
478 }
479 if (verbose) {
480 logger.info(responseString);
481 }
482 }
483
494 public void createModel(CustomModelRequest customModelRequest) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
495 String url = this.host + "/api/create";
496 String jsonData = customModelRequest.toString();
497 HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
498 HttpClient client = HttpClient.newHttpClient();
499 HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
500 int statusCode = response.statusCode();
501 String responseString = response.body();
502 if (statusCode != 200) {
503 throw new OllamaBaseException(statusCode + " - " + responseString);
504 }
505 if (responseString.contains("error")) {
506 throw new OllamaBaseException(responseString);
507 }
508 if (verbose) {
509 logger.info(responseString);
510 }
511 }
512
523 public void deleteModel(String modelName, boolean ignoreIfNotPresent) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
524 String url = this.host + "/api/delete";
525 String jsonData = new ModelRequest(modelName).toString();
526 HttpRequest request = getRequestBuilderDefault(new URI(url)).method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).header("Accept", "application/json").header("Content-type", "application/json").build();
527 HttpClient client = HttpClient.newHttpClient();
528 HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
529 int statusCode = response.statusCode();
530 String responseBody = response.body();
531 if (statusCode == 404 && responseBody.contains("model") && responseBody.contains("not found")) {
532 return;
533 }
534 if (statusCode != 200) {
535 throw new OllamaBaseException(statusCode + " - " + responseBody);
536 }
537 }
538
550 @Deprecated
551 public List<Double> generateEmbeddings(String model, String prompt) throws IOException, InterruptedException, OllamaBaseException {
552 return generateEmbeddings(new OllamaEmbeddingsRequestModel(model, prompt));
553 }
554
565 @Deprecated
566 public List<Double> generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest) throws IOException, InterruptedException, OllamaBaseException {
567 URI uri = URI.create(this.host + "/api/embeddings");
568 String jsonData = modelRequest.toString();
569 HttpClient httpClient = HttpClient.newHttpClient();
570 HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).header("Accept", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData));
571 HttpRequest request = requestBuilder.build();
572 HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
573 int statusCode = response.statusCode();
574 String responseBody = response.body();
575 if (statusCode == 200) {
576 OllamaEmbeddingResponseModel embeddingResponse = Utils.getObjectMapper().readValue(responseBody, OllamaEmbeddingResponseModel.class);
577 return embeddingResponse.getEmbedding();
578 } else {
579 throw new OllamaBaseException(statusCode + " - " + responseBody);
580 }
581 }
582
593 public OllamaEmbedResponseModel embed(String model, List<String> inputs) throws IOException, InterruptedException, OllamaBaseException {
594 return embed(new OllamaEmbedRequestModel(model, inputs));
595 }
596
606 public OllamaEmbedResponseModel embed(OllamaEmbedRequestModel modelRequest) throws IOException, InterruptedException, OllamaBaseException {
607 URI uri = URI.create(this.host + "/api/embed");
608 String jsonData = Utils.getObjectMapper().writeValueAsString(modelRequest);
609 HttpClient httpClient = HttpClient.newHttpClient();
610
611 HttpRequest request = HttpRequest.newBuilder(uri).header("Accept", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
612
613 HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
614 int statusCode = response.statusCode();
615 String responseBody = response.body();
616
617 if (statusCode == 200) {
618 return Utils.getObjectMapper().readValue(responseBody, OllamaEmbedResponseModel.class);
619 } else {
620 throw new OllamaBaseException(statusCode + " - " + responseBody);
621 }
622 }
623
639 public OllamaResult generate(String model, String prompt, boolean raw, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
640 OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
641 ollamaRequestModel.setRaw(raw);
642 ollamaRequestModel.setOptions(options.getOptionsMap());
643 return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
644 }
645
660 public OllamaResult generate(String model, String prompt, boolean raw, Options options) throws OllamaBaseException, IOException, InterruptedException {
661 return generate(model, prompt, raw, options, null);
662 }
663
676 public OllamaToolsResult generateWithTools(String model, String prompt, Options options) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
677 boolean raw = true;
678 OllamaToolsResult toolResult = new OllamaToolsResult();
679 Map<ToolFunctionCallSpec, Object> toolResults = new HashMap<>();
680
681 if (!prompt.startsWith("[AVAILABLE_TOOLS]")) {
682 final Tools.PromptBuilder promptBuilder = new Tools.PromptBuilder();
683 for (Tools.ToolSpecification spec : toolRegistry.getRegisteredSpecs()) {
684 promptBuilder.withToolSpecification(spec);
685 }
686 promptBuilder.withPrompt(prompt);
687 prompt = promptBuilder.build();
688 }
689
690 OllamaResult result = generate(model, prompt, raw, options, null);
691 toolResult.setModelResult(result);
692
693 String toolsResponse = result.getResponse();
694 if (toolsResponse.contains("[TOOL_CALLS]")) {
695 toolsResponse = toolsResponse.replace("[TOOL_CALLS]", "");
696 }
697
698 List<ToolFunctionCallSpec> toolFunctionCallSpecs = new ArrayList<>();
699 ObjectMapper objectMapper = Utils.getObjectMapper();
700
701 if (!toolsResponse.isEmpty()) {
702 try {
703 // Try to parse the string to see if it's a valid JSON
704 JsonNode jsonNode = objectMapper.readTree(toolsResponse);
705 } catch (JsonParseException e) {
706 logger.warn("Response from model does not contain any tool calls. Returning the response as is.");
707 return toolResult;
708 }
709 toolFunctionCallSpecs = objectMapper.readValue(
710 toolsResponse,
711 objectMapper.getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class)
712 );
713 }
714 for (ToolFunctionCallSpec toolFunctionCallSpec : toolFunctionCallSpecs) {
715 toolResults.put(toolFunctionCallSpec, invokeTool(toolFunctionCallSpec));
716 }
717 toolResult.setToolResults(toolResults);
718 return toolResult;
719 }
720
730 public OllamaAsyncResultStreamer generateAsync(String model, String prompt, boolean raw) {
731 OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
732 ollamaRequestModel.setRaw(raw);
733 URI uri = URI.create(this.host + "/api/generate");
734 OllamaAsyncResultStreamer ollamaAsyncResultStreamer = new OllamaAsyncResultStreamer(getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds);
735 ollamaAsyncResultStreamer.start();
736 return ollamaAsyncResultStreamer;
737 }
738
755 public OllamaResult generateWithImageFiles(String model, String prompt, List<File> imageFiles, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
756 List<String> images = new ArrayList<>();
757 for (File imageFile : imageFiles) {
758 images.add(encodeFileToBase64(imageFile));
759 }
760 OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt, images);
761 ollamaRequestModel.setOptions(options.getOptionsMap());
762 return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
763 }
764
774 public OllamaResult generateWithImageFiles(String model, String prompt, List<File> imageFiles, Options options) throws OllamaBaseException, IOException, InterruptedException {
775 return generateWithImageFiles(model, prompt, imageFiles, options, null);
776 }
777
795 public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
796 List<String> images = new ArrayList<>();
797 for (String imageURL : imageURLs) {
798 images.add(encodeByteArrayToBase64(Utils.loadImageBytesFromUrl(imageURL)));
799 }
800 OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt, images);
801 ollamaRequestModel.setOptions(options.getOptionsMap());
802 return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
803 }
804
815 public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs, Options options) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
816 return generateWithImageURLs(model, prompt, imageURLs, options, null);
817 }
818
833 public OllamaChatResult chat(String model, List<OllamaChatMessage> messages) throws OllamaBaseException, IOException, InterruptedException {
835 return chat(builder.withMessages(messages).build());
836 }
837
852 public OllamaChatResult chat(OllamaChatRequest request) throws OllamaBaseException, IOException, InterruptedException {
853 return chat(request, null);
854 }
855
871 public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
872 return chatStreaming(request, new OllamaChatStreamObserver(streamHandler));
873 }
874
890 public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler) throws OllamaBaseException, IOException, InterruptedException {
891 OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
892 OllamaChatResult result;
893
894 // add all registered tools to Request
895 request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
896
897 if (tokenHandler != null) {
898 request.setStream(true);
899 result = requestCaller.call(request, tokenHandler);
900 } else {
901 result = requestCaller.callSync(request);
902 }
903
904 // check if toolCallIsWanted
905 List<OllamaChatToolCalls> toolCalls = result.getResponseModel().getMessage().getToolCalls();
906 int toolCallTries = 0;
907 while (toolCalls != null && !toolCalls.isEmpty() && toolCallTries < maxChatToolCallRetries) {
908 for (OllamaChatToolCalls toolCall : toolCalls) {
909 String toolName = toolCall.getFunction().getName();
910 ToolFunction toolFunction = toolRegistry.getToolFunction(toolName);
911 Map<String, Object> arguments = toolCall.getFunction().getArguments();
912 Object res = toolFunction.apply(arguments);
913 request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL, "[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() + ") : " + res + "[/TOOL_RESULTS]"));
914 }
915
916 if (tokenHandler != null) {
917 result = requestCaller.call(request, tokenHandler);
918 } else {
919 result = requestCaller.callSync(request);
920 }
921 toolCalls = result.getResponseModel().getMessage().getToolCalls();
922 toolCallTries++;
923 }
924
925 return result;
926 }
927
934 public void registerTool(Tools.ToolSpecification toolSpecification) {
935 toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
936 if (this.verbose) {
937 logger.debug("Registered tool: {}", toolSpecification.getFunctionName());
938 }
939 }
940
948 public void registerTools(List<Tools.ToolSpecification> toolSpecifications) {
949 for (Tools.ToolSpecification toolSpecification : toolSpecifications) {
950 toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
951 }
952 }
953
963 try {
964 Class<?> callerClass = null;
965 try {
966 callerClass = Class.forName(Thread.currentThread().getStackTrace()[2].getClassName());
967 } catch (ClassNotFoundException e) {
968 throw new RuntimeException(e);
969 }
970
971 OllamaToolService ollamaToolServiceAnnotation = callerClass.getDeclaredAnnotation(OllamaToolService.class);
972 if (ollamaToolServiceAnnotation == null) {
973 throw new IllegalStateException(callerClass + " is not annotated as " + OllamaToolService.class);
974 }
975
976 Class<?>[] providers = ollamaToolServiceAnnotation.providers();
977 for (Class<?> provider : providers) {
978 registerAnnotatedTools(provider.getDeclaredConstructor().newInstance());
979 }
980 } catch (InstantiationException | NoSuchMethodException | IllegalAccessException |
981 InvocationTargetException e) {
982 throw new RuntimeException(e);
983 }
984 }
985
994 public void registerAnnotatedTools(Object object) {
995 Class<?> objectClass = object.getClass();
996 Method[] methods = objectClass.getMethods();
997 for (Method m : methods) {
998 ToolSpec toolSpec = m.getDeclaredAnnotation(ToolSpec.class);
999 if (toolSpec == null) {
1000 continue;
1001 }
1002 String operationName = !toolSpec.name().isBlank() ? toolSpec.name() : m.getName();
1003 String operationDesc = !toolSpec.desc().isBlank() ? toolSpec.desc() : operationName;
1004
1005 final Tools.PropsBuilder propsBuilder = new Tools.PropsBuilder();
1006 LinkedHashMap<String, String> methodParams = new LinkedHashMap<>();
1007 for (Parameter parameter : m.getParameters()) {
1008 final ToolProperty toolPropertyAnn = parameter.getDeclaredAnnotation(ToolProperty.class);
1009 String propType = parameter.getType().getTypeName();
1010 if (toolPropertyAnn == null) {
1011 methodParams.put(parameter.getName(), null);
1012 continue;
1013 }
1014 String propName = !toolPropertyAnn.name().isBlank() ? toolPropertyAnn.name() : parameter.getName();
1015 methodParams.put(propName, propType);
1016 propsBuilder.withProperty(propName, Tools.PromptFuncDefinition.Property.builder().type(propType).description(toolPropertyAnn.desc()).required(toolPropertyAnn.required()).build());
1017 }
1018 final Map<String, Tools.PromptFuncDefinition.Property> params = propsBuilder.build();
1019 List<String> reqProps = params.entrySet().stream().filter(e -> e.getValue().isRequired()).map(Map.Entry::getKey).collect(Collectors.toList());
1020
1021 Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder().functionName(operationName).functionDescription(operationDesc).toolPrompt(Tools.PromptFuncDefinition.builder().type("function").function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name(operationName).description(operationDesc).parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object").properties(params).required(reqProps).build()).build()).build()).build();
1022
1023 ReflectionalToolFunction reflectionalToolFunction = new ReflectionalToolFunction(object, m, methodParams);
1024 toolSpecification.setToolFunction(reflectionalToolFunction);
1025 toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
1026 }
1027
1028 }
1029
1036 public OllamaChatMessageRole addCustomRole(String roleName) {
1037 return OllamaChatMessageRole.newCustomRole(roleName);
1038 }
1039
1045 public List<OllamaChatMessageRole> listRoles() {
1047 }
1048
1056 public OllamaChatMessageRole getRole(String roleName) throws RoleNotFoundException {
1057 return OllamaChatMessageRole.getRole(roleName);
1058 }
1059
1060
1061 // technical private methods //
1062
1070 private static String encodeFileToBase64(File file) throws IOException {
1071 return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath()));
1072 }
1073
1080 private static String encodeByteArrayToBase64(byte[] bytes) {
1081 return Base64.getEncoder().encodeToString(bytes);
1082 }
1083
1096 private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
1097 OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
1098 OllamaResult result;
1099 if (streamHandler != null) {
1100 ollamaRequestModel.setStream(true);
1101 result = requestCaller.call(ollamaRequestModel, streamHandler);
1102 } else {
1103 result = requestCaller.callSync(ollamaRequestModel);
1104 }
1105 return result;
1106 }
1107
1108
1115 private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
1116 HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).header("Content-Type", "application/json").timeout(Duration.ofSeconds(requestTimeoutSeconds));
1117 if (isBasicAuthCredentialsSet()) {
1118 requestBuilder.header("Authorization", getBasicAuthHeaderValue());
1119 }
1120 return requestBuilder;
1121 }
1122
1128 private String getBasicAuthHeaderValue() {
1129 String credentialsToEncode = basicAuth.getUsername() + ":" + basicAuth.getPassword();
1130 return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
1131 }
1132
1138 private boolean isBasicAuthCredentialsSet() {
1139 return basicAuth != null;
1140 }
1141
1142 private Object invokeTool(ToolFunctionCallSpec toolFunctionCallSpec) throws ToolInvocationException {
1143 try {
1144 String methodName = toolFunctionCallSpec.getName();
1145 Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
1146 ToolFunction function = toolRegistry.getToolFunction(methodName);
1147 if (verbose) {
1148 logger.debug("Invoking function {} with arguments {}", methodName, arguments);
1149 }
1150 if (function == null) {
1151 throw new ToolNotFoundException("No such tool: " + methodName + ". Please register the tool before invoking it.");
1152 }
1153 return function.apply(arguments);
1154 } catch (Exception e) {
1155 throw new ToolInvocationException("Failed to invoke tool: " + toolFunctionCallSpec.getName(), e);
1156 }
1157 }
1158}
OllamaResult generateWithImageFiles(String model, String prompt, List< File > imageFiles, Options options, OllamaStreamHandler streamHandler)
OllamaChatResult chat(OllamaChatRequest request)
ModelDetail getModelDetails(String modelName)
List< Double > generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest)
OllamaAsyncResultStreamer generateAsync(String model, String prompt, boolean raw)
OllamaChatResult chat(String model, List< OllamaChatMessage > messages)
List< Double > generateEmbeddings(String model, String prompt)
void pullModel(String modelName)
void setBasicAuth(String username, String password)
OllamaResult generate(String model, String prompt, boolean raw, Options options, OllamaStreamHandler streamHandler)
OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler)
void deleteModel(String modelName, boolean ignoreIfNotPresent)
void createModelWithFilePath(String modelName, String modelFilePath)
OllamaResult generate(String model, String prompt, boolean raw, Options options)
ModelsProcessResponse ps()
LibraryModelTag findModelTagFromLibrary(String modelName, String tag)
OllamaChatMessageRole getRole(String roleName)
OllamaChatMessageRole addCustomRole(String roleName)
void pullModel(LibraryModelTag libraryModelTag)
List< OllamaChatMessageRole > listRoles()
LibraryModelDetail getLibraryModelDetails(LibraryModel libraryModel)
OllamaResult generateWithImageURLs(String model, String prompt, List< String > imageURLs, Options options, OllamaStreamHandler streamHandler)
OllamaEmbedResponseModel embed(OllamaEmbedRequestModel modelRequest)
void registerTool(Tools.ToolSpecification toolSpecification)
void registerAnnotatedTools(Object object)
List< LibraryModel > listModelsFromLibrary()
void registerTools(List< Tools.ToolSpecification > toolSpecifications)
OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler)
OllamaToolsResult generateWithTools(String model, String prompt, Options options)
void createModel(CustomModelRequest customModelRequest)
OllamaResult generateWithImageURLs(String model, String prompt, List< String > imageURLs, Options options)
OllamaEmbedResponseModel embed(String model, List< String > inputs)
void createModelWithModelFileContents(String modelName, String modelFileContents)
OllamaResult generateWithImageFiles(String model, String prompt, List< File > imageFiles, Options options)
static OllamaChatMessageRole newCustomRole(String roleName)
static OllamaChatMessageRole getRole(String roleName)
OllamaChatRequestBuilder withMessages(List< OllamaChatMessage > messages)
static OllamaChatRequestBuilder getInstance(String model)
OllamaChatResult call(OllamaChatRequest body, OllamaTokenHandler tokenHandler)
OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
ToolFunction getToolFunction(String name)
static byte[] loadImageBytesFromUrl(String imageUrl)
Definition Utils.java:25
static ObjectMapper getObjectMapper()
Definition Utils.java:17
Object apply(Map< String, Object > arguments)