1package io.github.ollama4j;
3import io.github.ollama4j.exceptions.OllamaBaseException;
4import io.github.ollama4j.exceptions.ToolInvocationException;
5import io.github.ollama4j.exceptions.ToolNotFoundException;
6import io.github.ollama4j.models.chat.OllamaChatMessage;
7import io.github.ollama4j.models.chat.OllamaChatRequest;
8import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
9import io.github.ollama4j.models.chat.OllamaChatResult;
10import io.github.ollama4j.models.embeddings.OllamaEmbeddingResponseModel;
11import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
12import io.github.ollama4j.models.generate.OllamaGenerateRequest;
13import io.github.ollama4j.models.generate.OllamaStreamHandler;
14import io.github.ollama4j.models.ps.ModelsProcessResponse;
15import io.github.ollama4j.models.request.*;
16import io.github.ollama4j.models.response.*;
17import io.github.ollama4j.tools.*;
18import io.github.ollama4j.utils.Options;
19import io.github.ollama4j.utils.Utils;
21import org.slf4j.Logger;
22import org.slf4j.LoggerFactory;
26import java.net.URISyntaxException;
27import java.net.http.HttpClient;
28import java.net.http.HttpConnectTimeoutException;
29import java.net.http.HttpRequest;
30import java.net.http.HttpResponse;
31import java.nio.charset.StandardCharsets;
32import java.nio.file.Files;
33import java.time.Duration;
39@SuppressWarnings(
"DuplicatedCode")
42 private static final Logger logger = LoggerFactory.getLogger(
OllamaAPI.class);
43 private final String host;
49 private long requestTimeoutSeconds = 10;
55 private boolean verbose =
true;
64 this.host =
"http://localhost:11434";
73 if (host.endsWith(
"/")) {
74 this.host = host.substring(0, host.length() - 1);
87 this.basicAuth =
new BasicAuth(username, password);
96 String url = this.host +
"/api/tags";
97 HttpClient httpClient = HttpClient.newHttpClient();
98 HttpRequest httpRequest =
null;
101 getRequestBuilderDefault(
new URI(url))
102 .header(
"Accept",
"application/json")
103 .header(
"Content-type",
"application/json")
106 }
catch (URISyntaxException e) {
107 throw new RuntimeException(e);
109 HttpResponse<String> response =
null;
111 response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
112 }
catch (HttpConnectTimeoutException e) {
114 }
catch (IOException | InterruptedException e) {
115 throw new RuntimeException(e);
117 int statusCode = response.statusCode();
118 return statusCode == 200;
127 String url = this.host +
"/api/ps";
128 HttpClient httpClient = HttpClient.newHttpClient();
129 HttpRequest httpRequest =
null;
132 getRequestBuilderDefault(
new URI(url))
133 .header(
"Accept",
"application/json")
134 .header(
"Content-type",
"application/json")
137 }
catch (URISyntaxException e) {
138 throw new RuntimeException(e);
140 HttpResponse<String> response =
null;
141 response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
142 int statusCode = response.statusCode();
143 String responseString = response.body();
144 if (statusCode == 200) {
159 String url = this.host +
"/api/tags";
160 HttpClient httpClient = HttpClient.newHttpClient();
161 HttpRequest httpRequest =
162 getRequestBuilderDefault(
new URI(url))
163 .header(
"Accept",
"application/json")
164 .header(
"Content-type",
"application/json")
167 HttpResponse<String> response =
168 httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
169 int statusCode = response.statusCode();
170 String responseString = response.body();
171 if (statusCode == 200) {
188 String url = this.host +
"/api/pull";
190 HttpRequest request =
191 getRequestBuilderDefault(
new URI(url))
192 .POST(HttpRequest.BodyPublishers.ofString(jsonData))
193 .header(
"Accept",
"application/json")
194 .header(
"Content-type",
"application/json")
196 HttpClient client = HttpClient.newHttpClient();
197 HttpResponse<InputStream> response =
198 client.send(request, HttpResponse.BodyHandlers.ofInputStream());
199 int statusCode = response.statusCode();
200 InputStream responseBodyStream = response.body();
201 String responseString =
"";
202 try (BufferedReader reader =
203 new BufferedReader(
new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
205 while ((line = reader.readLine()) !=
null) {
209 logger.info(modelPullResponse.getStatus());
213 if (statusCode != 200) {
226 String url = this.host +
"/api/show";
228 HttpRequest request =
229 getRequestBuilderDefault(
new URI(url))
230 .header(
"Accept",
"application/json")
231 .header(
"Content-type",
"application/json")
232 .POST(HttpRequest.BodyPublishers.ofString(jsonData))
234 HttpClient client = HttpClient.newHttpClient();
235 HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
236 int statusCode = response.statusCode();
237 String responseBody = response.body();
238 if (statusCode == 200) {
254 String url = this.host +
"/api/create";
256 HttpRequest request =
257 getRequestBuilderDefault(
new URI(url))
258 .header(
"Accept",
"application/json")
259 .header(
"Content-Type",
"application/json")
260 .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
262 HttpClient client = HttpClient.newHttpClient();
263 HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
264 int statusCode = response.statusCode();
265 String responseString = response.body();
266 if (statusCode != 200) {
271 if (responseString.contains(
"error")) {
275 logger.info(responseString);
288 String url = this.host +
"/api/create";
290 HttpRequest request =
291 getRequestBuilderDefault(
new URI(url))
292 .header(
"Accept",
"application/json")
293 .header(
"Content-Type",
"application/json")
294 .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
296 HttpClient client = HttpClient.newHttpClient();
297 HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
298 int statusCode = response.statusCode();
299 String responseString = response.body();
300 if (statusCode != 200) {
303 if (responseString.contains(
"error")) {
307 logger.info(responseString);
317 public void deleteModel(String modelName,
boolean ignoreIfNotPresent)
319 String url = this.host +
"/api/delete";
321 HttpRequest request =
322 getRequestBuilderDefault(
new URI(url))
323 .method(
"DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
324 .header(
"Accept",
"application/json")
325 .header(
"Content-type",
"application/json")
327 HttpClient client = HttpClient.newHttpClient();
328 HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
329 int statusCode = response.statusCode();
330 String responseBody = response.body();
331 if (statusCode == 404 && responseBody.contains(
"model") && responseBody.contains(
"not found")) {
334 if (statusCode != 200) {
358 URI uri = URI.create(this.host +
"/api/embeddings");
359 String jsonData = modelRequest.toString();
360 HttpClient httpClient = HttpClient.newHttpClient();
361 HttpRequest.Builder requestBuilder =
362 getRequestBuilderDefault(uri)
363 .header(
"Accept",
"application/json")
364 .POST(HttpRequest.BodyPublishers.ofString(jsonData));
365 HttpRequest request = requestBuilder.build();
366 HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
367 int statusCode = response.statusCode();
368 String responseBody = response.body();
369 if (statusCode == 200) {
372 return embeddingResponse.getEmbedding();
394 ollamaRequestModel.setRaw(raw);
395 ollamaRequestModel.setOptions(options.getOptionsMap());
396 return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
412 return generate(model, prompt, raw, options,
null);
433 Map<ToolFunctionCallSpec, Object> toolResults =
new HashMap<>();
435 OllamaResult result = generate(model, prompt, raw, options,
null);
436 toolResult.setModelResult(result);
438 String toolsResponse = result.getResponse();
439 if (toolsResponse.contains(
"[TOOL_CALLS]")) {
440 toolsResponse = toolsResponse.replace(
"[TOOL_CALLS]",
"");
445 toolResults.put(toolFunctionCallSpec, invokeTool(toolFunctionCallSpec));
447 toolResult.setToolResults(toolResults);
463 ollamaRequestModel.setRaw(raw);
464 URI uri = URI.create(this.host +
"/api/generate");
467 getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds);
468 ollamaAsyncResultStreamer.start();
469 return ollamaAsyncResultStreamer;
488 List<String> images =
new ArrayList<>();
489 for (File imageFile : imageFiles) {
490 images.add(encodeFileToBase64(imageFile));
493 ollamaRequestModel.setOptions(options.getOptionsMap());
494 return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
503 String model, String prompt, List<File> imageFiles,
Options options)
505 return generateWithImageFiles(model, prompt, imageFiles, options,
null);
524 List<String> images =
new ArrayList<>();
525 for (String imageURL : imageURLs) {
529 ollamaRequestModel.setOptions(options.getOptionsMap());
530 return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
541 return generateWithImageURLs(model, prompt, imageURLs, options,
null);
573 return chat(request,
null);
591 if (streamHandler !=
null) {
592 request.setStream(
true);
593 result = requestCaller.
call(request, streamHandler);
595 result = requestCaller.
callSync(request);
597 return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
601 toolRegistry.
addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
606 private static String encodeFileToBase64(File file)
throws IOException {
607 return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath()));
610 private static String encodeByteArrayToBase64(
byte[] bytes) {
611 return Base64.getEncoder().encodeToString(bytes);
620 if (streamHandler !=
null) {
621 ollamaRequestModel.setStream(
true);
622 result = requestCaller.
call(ollamaRequestModel, streamHandler);
624 result = requestCaller.
callSync(ollamaRequestModel);
635 private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
636 HttpRequest.Builder requestBuilder =
637 HttpRequest.newBuilder(uri)
638 .header(
"Content-Type",
"application/json")
639 .timeout(Duration.ofSeconds(requestTimeoutSeconds));
640 if (isBasicAuthCredentialsSet()) {
641 requestBuilder.header(
"Authorization", getBasicAuthHeaderValue());
643 return requestBuilder;
651 private String getBasicAuthHeaderValue() {
652 String credentialsToEncode = basicAuth.getUsername() +
":" + basicAuth.getPassword();
653 return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
661 private boolean isBasicAuthCredentialsSet() {
662 return basicAuth !=
null;
666 private Object invokeTool(
ToolFunctionCallSpec toolFunctionCallSpec)
throws ToolInvocationException {
668 String methodName = toolFunctionCallSpec.getName();
669 Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
672 logger.debug(
"Invoking function {} with arguments {}", methodName, arguments);
674 if (
function ==
null) {
675 throw new ToolNotFoundException(
"No such tool: " + methodName);
677 return function.
apply(arguments);
678 }
catch (Exception e) {
679 throw new ToolInvocationException(
"Failed to invoke tool: " + toolFunctionCallSpec.getName(), e);
OllamaResult generateWithImageFiles(String model, String prompt, List< File > imageFiles, Options options, OllamaStreamHandler streamHandler)
OllamaChatResult chat(OllamaChatRequest request)
ModelDetail getModelDetails(String modelName)
List< Double > generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest)
List< Model > listModels()
OllamaAsyncResultStreamer generateAsync(String model, String prompt, boolean raw)
OllamaChatResult chat(String model, List< OllamaChatMessage > messages)
List< Double > generateEmbeddings(String model, String prompt)
void pullModel(String modelName)
void setBasicAuth(String username, String password)
OllamaResult generate(String model, String prompt, boolean raw, Options options, OllamaStreamHandler streamHandler)
OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler)
void deleteModel(String modelName, boolean ignoreIfNotPresent)
void createModelWithFilePath(String modelName, String modelFilePath)
OllamaResult generate(String model, String prompt, boolean raw, Options options)
ModelsProcessResponse ps()
OllamaResult generateWithImageURLs(String model, String prompt, List< String > imageURLs, Options options, OllamaStreamHandler streamHandler)
void registerTool(Tools.ToolSpecification toolSpecification)
OllamaToolsResult generateWithTools(String model, String prompt, Options options)
OllamaResult generateWithImageURLs(String model, String prompt, List< String > imageURLs, Options options)
void createModelWithModelFileContents(String modelName, String modelFileContents)
OllamaResult generateWithImageFiles(String model, String prompt, List< File > imageFiles, Options options)
OllamaChatRequestBuilder withMessages(List< OllamaChatMessage > messages)
OllamaChatRequest build()
static OllamaChatRequestBuilder getInstance(String model)
OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
OllamaResult callSync(OllamaRequestBody body)
OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
static byte[] loadImageBytesFromUrl(String imageUrl)
static ObjectMapper getObjectMapper()