1package io.github.ollama4j;
3import io.github.ollama4j.exceptions.OllamaBaseException;
4import io.github.ollama4j.exceptions.RoleNotFoundException;
5import io.github.ollama4j.exceptions.ToolInvocationException;
6import io.github.ollama4j.exceptions.ToolNotFoundException;
7import io.github.ollama4j.models.chat.*;
8import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
9import io.github.ollama4j.models.embeddings.OllamaEmbeddingResponseModel;
10import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
11import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
12import io.github.ollama4j.models.generate.OllamaGenerateRequest;
13import io.github.ollama4j.models.generate.OllamaStreamHandler;
14import io.github.ollama4j.models.ps.ModelsProcessResponse;
15import io.github.ollama4j.models.request.*;
16import io.github.ollama4j.models.response.*;
17import io.github.ollama4j.tools.*;
18import io.github.ollama4j.utils.Options;
19import io.github.ollama4j.utils.Utils;
24import java.net.URISyntaxException;
25import java.net.http.HttpClient;
26import java.net.http.HttpConnectTimeoutException;
27import java.net.http.HttpRequest;
28import java.net.http.HttpResponse;
29import java.nio.charset.StandardCharsets;
30import java.nio.file.Files;
31import java.time.Duration;
33import java.util.stream.Collectors;
35import org.slf4j.Logger;
36import org.slf4j.LoggerFactory;
37import org.jsoup.Jsoup;
38import org.jsoup.nodes.Document;
39import org.jsoup.nodes.Element;
40import org.jsoup.select.Elements;
45@SuppressWarnings({
"DuplicatedCode",
"resource"})
48 private static final Logger logger = LoggerFactory.getLogger(
OllamaAPI.class);
49 private final String host;
55 private long requestTimeoutSeconds = 10;
61 private boolean verbose =
true;
70 this.host =
"http://localhost:11434";
79 if (host.endsWith(
"/")) {
80 this.host = host.substring(0, host.length() - 1);
93 this.basicAuth =
new BasicAuth(username, password);
102 String url = this.host +
"/api/tags";
103 HttpClient httpClient = HttpClient.newHttpClient();
104 HttpRequest httpRequest =
null;
106 httpRequest = getRequestBuilderDefault(
new URI(url)).header(
"Accept",
"application/json").header(
"Content-type",
"application/json").GET().build();
107 }
catch (URISyntaxException e) {
108 throw new RuntimeException(e);
110 HttpResponse<String> response =
null;
112 response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
113 }
catch (HttpConnectTimeoutException e) {
115 }
catch (IOException | InterruptedException e) {
116 throw new RuntimeException(e);
118 int statusCode = response.statusCode();
119 return statusCode == 200;
131 String url = this.host +
"/api/ps";
132 HttpClient httpClient = HttpClient.newHttpClient();
133 HttpRequest httpRequest =
null;
135 httpRequest = getRequestBuilderDefault(
new URI(url)).header(
"Accept",
"application/json").header(
"Content-type",
"application/json").GET().build();
136 }
catch (URISyntaxException e) {
137 throw new RuntimeException(e);
139 HttpResponse<String> response =
null;
140 response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
141 int statusCode = response.statusCode();
142 String responseString = response.body();
143 if (statusCode == 200) {
160 String url = this.host +
"/api/tags";
161 HttpClient httpClient = HttpClient.newHttpClient();
162 HttpRequest httpRequest = getRequestBuilderDefault(
new URI(url)).header(
"Accept",
"application/json").header(
"Content-type",
"application/json").GET().build();
163 HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
164 int statusCode = response.statusCode();
165 String responseString = response.body();
166 if (statusCode == 200) {
184 String url =
"https://ollama.com/library";
185 HttpClient httpClient = HttpClient.newHttpClient();
186 HttpRequest httpRequest = getRequestBuilderDefault(
new URI(url)).header(
"Accept",
"application/json").header(
"Content-type",
"application/json").GET().build();
187 HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
188 int statusCode = response.statusCode();
189 String responseString = response.body();
190 List<LibraryModel> models =
new ArrayList<>();
191 if (statusCode == 200) {
192 Document doc = Jsoup.parse(responseString);
193 Elements modelSections = doc.selectXpath(
"//*[@id='repo']/ul/li/a");
194 for (Element e : modelSections) {
196 Elements names = e.select(
"div > h2 > div > span");
197 Elements desc = e.select(
"div > p");
198 Elements pullCounts = e.select(
"div:nth-of-type(2) > p > span:first-of-type > span:first-of-type");
199 Elements popularTags = e.select(
"div > div > span");
200 Elements totalTags = e.select(
"div:nth-of-type(2) > p > span:nth-of-type(2) > span:first-of-type");
201 Elements lastUpdatedTime = e.select(
"div:nth-of-type(2) > p > span:nth-of-type(3) > span:nth-of-type(2)");
203 if (names.first() ==
null || names.isEmpty()) {
207 Optional.ofNullable(names.first()).map(Element::text).ifPresent(model::setName);
208 model.setDescription(Optional.ofNullable(desc.first()).map(Element::text).orElse(
""));
209 model.setPopularTags(Optional.of(popularTags).map(tags -> tags.stream().map(Element::text).collect(Collectors.toList())).orElse(
new ArrayList<>()));
210 model.setPullCount(Optional.ofNullable(pullCounts.first()).map(Element::text).orElse(
""));
211 model.setTotalTags(Optional.ofNullable(totalTags.first()).map(Element::text).map(Integer::parseInt).orElse(0));
212 model.setLastUpdated(Optional.ofNullable(lastUpdatedTime.first()).map(Element::text).orElse(
""));
237 String url = String.format(
"https://ollama.com/library/%s/tags", libraryModel.getName());
238 HttpClient httpClient = HttpClient.newHttpClient();
239 HttpRequest httpRequest = getRequestBuilderDefault(
new URI(url)).header(
"Accept",
"application/json").header(
"Content-type",
"application/json").GET().build();
240 HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
241 int statusCode = response.statusCode();
242 String responseString = response.body();
244 List<LibraryModelTag> libraryModelTags =
new ArrayList<>();
245 if (statusCode == 200) {
246 Document doc = Jsoup.parse(responseString);
247 Elements tagSections = doc.select(
"html > body > main > div > section > div > div > div:nth-child(n+2) > div");
248 for (Element e : tagSections) {
249 Elements tags = e.select(
"div > a > div");
250 Elements tagsMetas = e.select(
"div > span");
254 if (tags.first() ==
null || tags.isEmpty()) {
258 libraryModelTag.setName(libraryModel.getName());
259 Optional.ofNullable(tags.first()).map(Element::text).ifPresent(libraryModelTag::setTag);
260 libraryModelTag.setSize(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split(
"•")).filter(parts -> parts.length > 1).map(parts -> parts[1].trim()).orElse(
""));
261 libraryModelTag.setLastUpdated(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split(
"•")).filter(parts -> parts.length > 1).map(parts -> parts[2].trim()).orElse(
""));
262 libraryModelTags.add(libraryModelTag);
265 libraryModelDetail.setModel(libraryModel);
266 libraryModelDetail.setTags(libraryModelTags);
267 return libraryModelDetail;
290 List<LibraryModel> libraryModels = this.listModelsFromLibrary();
291 LibraryModel libraryModel = libraryModels.stream().filter(model -> model.getName().equals(modelName)).findFirst().orElseThrow(() ->
new NoSuchElementException(String.format(
"Model by name '%s' not found", modelName)));
293 LibraryModelTag libraryModelTag = libraryModelDetail.getTags().stream().filter(tagName -> tagName.getTag().equals(tag)).findFirst().orElseThrow(() ->
new NoSuchElementException(String.format(
"Tag '%s' for model '%s' not found", tag, modelName)));
294 return libraryModelTag;
308 String url = this.host +
"/api/pull";
310 HttpRequest request = getRequestBuilderDefault(
new URI(url)).POST(HttpRequest.BodyPublishers.ofString(jsonData)).header(
"Accept",
"application/json").header(
"Content-type",
"application/json").build();
311 HttpClient client = HttpClient.newHttpClient();
312 HttpResponse<InputStream> response = client.send(request, HttpResponse.BodyHandlers.ofInputStream());
313 int statusCode = response.statusCode();
314 InputStream responseBodyStream = response.body();
315 String responseString =
"";
316 try (BufferedReader reader =
new BufferedReader(
new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
318 while ((line = reader.readLine()) !=
null) {
321 logger.info(modelPullResponse.getStatus());
325 if (statusCode != 200) {
343 String tagToPull = String.format(
"%s:%s", libraryModelTag.getName(), libraryModelTag.getTag());
344 pullModel(tagToPull);
358 String url = this.host +
"/api/show";
360 HttpRequest request = getRequestBuilderDefault(
new URI(url)).header(
"Accept",
"application/json").header(
"Content-type",
"application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
361 HttpClient client = HttpClient.newHttpClient();
362 HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
363 int statusCode = response.statusCode();
364 String responseBody = response.body();
365 if (statusCode == 200) {
384 String url = this.host +
"/api/create";
386 HttpRequest request = getRequestBuilderDefault(
new URI(url)).header(
"Accept",
"application/json").header(
"Content-Type",
"application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
387 HttpClient client = HttpClient.newHttpClient();
388 HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
389 int statusCode = response.statusCode();
390 String responseString = response.body();
391 if (statusCode != 200) {
396 if (responseString.contains(
"error")) {
400 logger.info(responseString);
416 String url = this.host +
"/api/create";
418 HttpRequest request = getRequestBuilderDefault(
new URI(url)).header(
"Accept",
"application/json").header(
"Content-Type",
"application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
419 HttpClient client = HttpClient.newHttpClient();
420 HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
421 int statusCode = response.statusCode();
422 String responseString = response.body();
423 if (statusCode != 200) {
426 if (responseString.contains(
"error")) {
430 logger.info(responseString);
445 String url = this.host +
"/api/delete";
447 HttpRequest request = getRequestBuilderDefault(
new URI(url)).method(
"DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).header(
"Accept",
"application/json").header(
"Content-type",
"application/json").build();
448 HttpClient client = HttpClient.newHttpClient();
449 HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
450 int statusCode = response.statusCode();
451 String responseBody = response.body();
452 if (statusCode == 404 && responseBody.contains(
"model") && responseBody.contains(
"not found")) {
455 if (statusCode != 200) {
488 URI uri = URI.create(this.host +
"/api/embeddings");
489 String jsonData = modelRequest.toString();
490 HttpClient httpClient = HttpClient.newHttpClient();
491 HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).header(
"Accept",
"application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData));
492 HttpRequest request = requestBuilder.build();
493 HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
494 int statusCode = response.statusCode();
495 String responseBody = response.body();
496 if (statusCode == 200) {
498 return embeddingResponse.getEmbedding();
528 URI uri = URI.create(this.host +
"/api/embed");
530 HttpClient httpClient = HttpClient.newHttpClient();
532 HttpRequest request = HttpRequest.newBuilder(uri).header(
"Accept",
"application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
534 HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
535 int statusCode = response.statusCode();
536 String responseBody = response.body();
538 if (statusCode == 200) {
562 ollamaRequestModel.setRaw(raw);
563 ollamaRequestModel.setOptions(options.getOptionsMap());
564 return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
582 return generate(model, prompt, raw, options,
null);
600 Map<ToolFunctionCallSpec, Object> toolResults =
new HashMap<>();
602 OllamaResult result = generate(model, prompt, raw, options,
null);
603 toolResult.setModelResult(result);
605 String toolsResponse = result.getResponse();
606 if (toolsResponse.contains(
"[TOOL_CALLS]")) {
607 toolsResponse = toolsResponse.replace(
"[TOOL_CALLS]",
"");
612 toolResults.put(toolFunctionCallSpec, invokeTool(toolFunctionCallSpec));
614 toolResult.setToolResults(toolResults);
629 ollamaRequestModel.setRaw(raw);
630 URI uri = URI.create(this.host +
"/api/generate");
632 ollamaAsyncResultStreamer.start();
633 return ollamaAsyncResultStreamer;
653 List<String> images =
new ArrayList<>();
654 for (File imageFile : imageFiles) {
655 images.add(encodeFileToBase64(imageFile));
658 ollamaRequestModel.setOptions(options.getOptionsMap());
659 return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
672 return generateWithImageFiles(model, prompt, imageFiles, options,
null);
693 List<String> images =
new ArrayList<>();
694 for (String imageURL : imageURLs) {
698 ollamaRequestModel.setOptions(options.getOptionsMap());
699 return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
713 return generateWithImageURLs(model, prompt, imageURLs, options,
null);
750 return chat(request,
null);
771 if (streamHandler !=
null) {
772 request.setStream(
true);
773 result = requestCaller.
call(request, streamHandler);
775 result = requestCaller.
callSync(request);
777 return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
781 toolRegistry.
addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
817 private static String encodeFileToBase64(File file)
throws IOException {
818 return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath()));
821 private static String encodeByteArrayToBase64(
byte[] bytes) {
822 return Base64.getEncoder().encodeToString(bytes);
828 if (streamHandler !=
null) {
829 ollamaRequestModel.setStream(
true);
830 result = requestCaller.
call(ollamaRequestModel, streamHandler);
832 result = requestCaller.
callSync(ollamaRequestModel);
843 private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
844 HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).header(
"Content-Type",
"application/json").timeout(Duration.ofSeconds(requestTimeoutSeconds));
845 if (isBasicAuthCredentialsSet()) {
846 requestBuilder.header(
"Authorization", getBasicAuthHeaderValue());
848 return requestBuilder;
856 private String getBasicAuthHeaderValue() {
857 String credentialsToEncode = basicAuth.getUsername() +
":" + basicAuth.getPassword();
858 return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
866 private boolean isBasicAuthCredentialsSet() {
867 return basicAuth !=
null;
870 private Object invokeTool(
ToolFunctionCallSpec toolFunctionCallSpec)
throws ToolInvocationException {
872 String methodName = toolFunctionCallSpec.getName();
873 Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
876 logger.debug(
"Invoking function {} with arguments {}", methodName, arguments);
878 if (
function ==
null) {
879 throw new ToolNotFoundException(
"No such tool: " + methodName);
881 return function.
apply(arguments);
882 }
catch (Exception e) {
883 throw new ToolInvocationException(
"Failed to invoke tool: " + toolFunctionCallSpec.getName(), e);
OllamaResult generateWithImageFiles(String model, String prompt, List< File > imageFiles, Options options, OllamaStreamHandler streamHandler)
OllamaChatResult chat(OllamaChatRequest request)
ModelDetail getModelDetails(String modelName)
List< Double > generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest)
List< Model > listModels()
OllamaAsyncResultStreamer generateAsync(String model, String prompt, boolean raw)
OllamaChatResult chat(String model, List< OllamaChatMessage > messages)
List< Double > generateEmbeddings(String model, String prompt)
void pullModel(String modelName)
void setBasicAuth(String username, String password)
OllamaResult generate(String model, String prompt, boolean raw, Options options, OllamaStreamHandler streamHandler)
OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler)
void deleteModel(String modelName, boolean ignoreIfNotPresent)
void createModelWithFilePath(String modelName, String modelFilePath)
OllamaResult generate(String model, String prompt, boolean raw, Options options)
ModelsProcessResponse ps()
LibraryModelTag findModelTagFromLibrary(String modelName, String tag)
OllamaChatMessageRole getRole(String roleName)
OllamaChatMessageRole addCustomRole(String roleName)
void pullModel(LibraryModelTag libraryModelTag)
List< OllamaChatMessageRole > listRoles()
LibraryModelDetail getLibraryModelDetails(LibraryModel libraryModel)
OllamaResult generateWithImageURLs(String model, String prompt, List< String > imageURLs, Options options, OllamaStreamHandler streamHandler)
OllamaEmbedResponseModel embed(OllamaEmbedRequestModel modelRequest)
void registerTool(Tools.ToolSpecification toolSpecification)
List< LibraryModel > listModelsFromLibrary()
OllamaToolsResult generateWithTools(String model, String prompt, Options options)
OllamaResult generateWithImageURLs(String model, String prompt, List< String > imageURLs, Options options)
OllamaEmbedResponseModel embed(String model, List< String > inputs)
void createModelWithModelFileContents(String modelName, String modelFileContents)
OllamaResult generateWithImageFiles(String model, String prompt, List< File > imageFiles, Options options)
static List< OllamaChatMessageRole > getRoles()
static OllamaChatMessageRole newCustomRole(String roleName)
static OllamaChatMessageRole getRole(String roleName)
OllamaChatRequestBuilder withMessages(List< OllamaChatMessage > messages)
OllamaChatRequest build()
static OllamaChatRequestBuilder getInstance(String model)
OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
OllamaResult callSync(OllamaRequestBody body)
OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
static byte[] loadImageBytesFromUrl(String imageUrl)
static ObjectMapper getObjectMapper()