58 private static final Logger LOG = LoggerFactory.getLogger(
Ollama.class);
60 private final String host;
71 @Setter
private long requestTimeoutSeconds = 10;
76 @Setter
private int imageURLReadTimeoutSeconds = 10;
81 @Setter
private int imageURLConnectTimeoutSeconds = 10;
89 @Setter
private int maxChatToolCallRetries = 3;
100 @SuppressWarnings({
"FieldMayBeFinal",
"FieldCanBeLocal"})
101 private int numberOfRetriesForModelPull = 0;
109 @Setter
private boolean metricsEnabled =
false;
115 this.host =
"http://localhost:11434";
124 if (host.endsWith(
"/")) {
125 this.host = host.substring(0, host.length() - 1);
129 LOG.info(
"Ollama4j client initialized. Connected to Ollama server at: {}", this.host);
139 this.auth =
new BasicAuth(username, password);
158 long startTime = System.currentTimeMillis();
159 String url =
"/api/tags";
163 HttpClient httpClient = HttpClient.newHttpClient();
164 HttpRequest httpRequest;
165 HttpResponse<String> response;
167 getRequestBuilderDefault(
new URI(this.host + url))
176 response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
177 statusCode = response.statusCode();
178 return statusCode == 200;
179 }
catch (InterruptedException ie) {
180 Thread.currentThread().interrupt();
182 }
catch (Exception e) {
206 long startTime = System.currentTimeMillis();
207 String url =
"/api/ps";
211 HttpClient httpClient = HttpClient.newHttpClient();
212 HttpRequest httpRequest =
null;
215 getRequestBuilderDefault(
new URI(this.host + url))
224 }
catch (URISyntaxException e) {
227 HttpResponse<String> response =
null;
228 response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
229 statusCode = response.statusCode();
230 String responseString = response.body();
231 if (statusCode == 200) {
237 }
catch (InterruptedException ie) {
238 Thread.currentThread().interrupt();
240 }
catch (Exception e) {
264 long startTime = System.currentTimeMillis();
265 String url =
"/api/tags";
269 HttpClient httpClient = HttpClient.newHttpClient();
270 HttpRequest httpRequest =
271 getRequestBuilderDefault(
new URI(this.host + url))
280 HttpResponse<String> response =
281 httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
282 statusCode = response.statusCode();
283 String responseString = response.body();
284 if (statusCode == 200) {
291 }
catch (InterruptedException ie) {
292 Thread.currentThread().interrupt();
294 }
catch (Exception e) {
320 private void handlePullRetry(
321 String modelName,
int currentRetry,
int maxRetries,
long baseDelayMillis)
322 throws InterruptedException {
323 int attempt = currentRetry + 1;
324 if (attempt < maxRetries) {
325 long backoffMillis = baseDelayMillis * (1L << currentRetry);
327 "Failed to pull model {}, retrying in {}s... (attempt {}/{})",
329 backoffMillis / 1000,
333 Thread.sleep(backoffMillis);
334 }
catch (InterruptedException ie) {
335 Thread.currentThread().interrupt();
340 "Failed to pull model {} after {} attempts, no more retries.",
353 long startTime = System.currentTimeMillis();
354 String url =
"/api/pull";
359 HttpRequest request =
360 getRequestBuilderDefault(
new URI(this.host + url))
361 .POST(HttpRequest.BodyPublishers.ofString(jsonData))
369 HttpClient client = HttpClient.newHttpClient();
370 HttpResponse<InputStream> response =
371 client.send(request, HttpResponse.BodyHandlers.ofInputStream());
372 statusCode = response.statusCode();
373 InputStream responseBodyStream = response.body();
374 String responseString =
"";
375 boolean success =
false;
377 try (BufferedReader reader =
379 new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
381 while ((line = reader.readLine()) !=
null) {
384 success = processModelPullResponse(modelPullResponse, modelName) || success;
388 LOG.error(
"Model pull failed or returned invalid status.");
389 throw new OllamaException(
"Model pull failed or returned invalid status.");
391 if (statusCode != 200) {
394 }
catch (InterruptedException ie) {
395 Thread.currentThread().interrupt();
396 throw new OllamaException(
"Thread was interrupted during model pull.", ie);
397 }
catch (Exception e) {
423 @SuppressWarnings(
"RedundantIfStatement")
424 private
boolean processModelPullResponse(
ModelPullResponse modelPullResponse, String modelName)
426 if (modelPullResponse ==
null) {
427 LOG.error(
"Received null response for model pull.");
430 String error = modelPullResponse.getError();
431 if (error !=
null && !error.trim().isEmpty()) {
432 throw new OllamaException(
"Model pull failed: " + error);
434 String status = modelPullResponse.getStatus();
435 if (status !=
null) {
436 LOG.debug(
"{}: {}", modelName, status);
437 if (
"success".equalsIgnoreCase(status)) {
451 String url =
"/api/version";
452 long startTime = System.currentTimeMillis();
456 HttpClient httpClient = HttpClient.newHttpClient();
457 HttpRequest httpRequest =
458 getRequestBuilderDefault(
new URI(this.host + url))
467 HttpResponse<String> response =
468 httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
469 statusCode = response.statusCode();
470 String responseString = response.body();
471 if (statusCode == 200) {
478 }
catch (InterruptedException ie) {
479 Thread.currentThread().interrupt();
481 }
catch (Exception e) {
508 if (numberOfRetriesForModelPull == 0) {
509 this.doPullModel(modelName);
512 int numberOfRetries = 0;
513 long baseDelayMillis = 3000L;
514 while (numberOfRetries < numberOfRetriesForModelPull) {
516 this.doPullModel(modelName);
522 numberOfRetriesForModelPull,
528 "Failed to pull model "
531 + numberOfRetriesForModelPull
533 }
catch (InterruptedException ie) {
534 Thread.currentThread().interrupt();
536 }
catch (Exception e) {
549 long startTime = System.currentTimeMillis();
550 String url =
"/api/show";
555 HttpRequest request =
556 getRequestBuilderDefault(
new URI(this.host + url))
563 .POST(HttpRequest.BodyPublishers.ofString(jsonData))
565 HttpClient client = HttpClient.newHttpClient();
566 HttpResponse<String> response =
567 client.send(request, HttpResponse.BodyHandlers.ofString());
568 statusCode = response.statusCode();
569 String responseBody = response.body();
570 if (statusCode == 200) {
575 }
catch (InterruptedException ie) {
576 Thread.currentThread().interrupt();
578 }
catch (Exception e) {
603 long startTime = System.currentTimeMillis();
604 String url =
"/api/create";
608 String jsonData = customModelRequest.toString();
609 HttpRequest request =
610 getRequestBuilderDefault(
new URI(this.host + url))
618 HttpRequest.BodyPublishers.ofString(
619 jsonData, StandardCharsets.UTF_8))
621 HttpClient client = HttpClient.newHttpClient();
622 HttpResponse<InputStream> response =
623 client.send(request, HttpResponse.BodyHandlers.ofInputStream());
624 statusCode = response.statusCode();
625 if (statusCode != 200) {
627 new String(response.body().readAllBytes(), StandardCharsets.UTF_8);
631 try (BufferedReader reader =
633 new InputStreamReader(response.body(), StandardCharsets.UTF_8))) {
635 StringBuilder lines =
new StringBuilder();
636 while ((line = reader.readLine()) !=
null) {
640 LOG.debug(res.getStatus());
641 if (res.getError() !=
null) {
642 out = res.getError();
648 }
catch (InterruptedException e) {
649 Thread.currentThread().interrupt();
651 }
catch (Exception e) {
676 long startTime = System.currentTimeMillis();
677 String url =
"/api/delete";
682 HttpRequest request =
683 getRequestBuilderDefault(
new URI(this.host + url))
686 HttpRequest.BodyPublishers.ofString(
687 jsonData, StandardCharsets.UTF_8))
695 HttpClient client = HttpClient.newHttpClient();
696 HttpResponse<String> response =
697 client.send(request, HttpResponse.BodyHandlers.ofString());
698 statusCode = response.statusCode();
699 String responseBody = response.body();
701 if (statusCode == 404
702 && responseBody.contains(
"model")
703 && responseBody.contains(
"not found")) {
706 if (statusCode != 200) {
709 }
catch (InterruptedException e) {
710 Thread.currentThread().interrupt();
712 }
catch (Exception e) {
739 long startTime = System.currentTimeMillis();
740 String url =
"/api/generate";
744 ObjectMapper objectMapper =
new ObjectMapper();
745 Map<String, Object> jsonMap =
new java.util.HashMap<>();
746 jsonMap.put(
"model", modelName);
747 jsonMap.put(
"keep_alive", 0);
748 String jsonData = objectMapper.writeValueAsString(jsonMap);
749 HttpRequest request =
750 getRequestBuilderDefault(
new URI(this.host + url))
753 HttpRequest.BodyPublishers.ofString(
754 jsonData, StandardCharsets.UTF_8))
762 LOG.debug(
"Unloading model with request: {}", jsonData);
763 HttpClient client = HttpClient.newHttpClient();
764 HttpResponse<String> response =
765 client.send(request, HttpResponse.BodyHandlers.ofString());
766 statusCode = response.statusCode();
767 String responseBody = response.body();
768 if (statusCode == 404
769 && responseBody.contains(
"model")
770 && responseBody.contains(
"not found")) {
771 LOG.debug(
"Unload response: {} - {}", statusCode, responseBody);
774 if (statusCode != 200) {
775 LOG.debug(
"Unload response: {} - {}", statusCode, responseBody);
778 }
catch (InterruptedException e) {
779 Thread.currentThread().interrupt();
780 LOG.debug(
"Unload interrupted: {} - {}", statusCode, out);
782 }
catch (Exception e) {
783 LOG.debug(
"Unload failed: {} - {}", statusCode, out);
808 long startTime = System.currentTimeMillis();
809 String url =
"/api/embed";
814 HttpClient httpClient = HttpClient.newHttpClient();
815 HttpRequest request =
816 HttpRequest.newBuilder(
new URI(this.host + url))
820 .POST(HttpRequest.BodyPublishers.ofString(jsonData))
822 HttpResponse<String> response =
823 httpClient.send(request, HttpResponse.BodyHandlers.ofString());
824 statusCode = response.statusCode();
825 String responseBody = response.body();
826 if (statusCode == 200) {
831 }
catch (InterruptedException e) {
832 Thread.currentThread().interrupt();
834 }
catch (Exception e) {
864 if (request.isUseTools()) {
865 return generateWithToolsInternal(request, streamObserver);
868 if (streamObserver !=
null) {
870 return generateSyncForOllamaRequestModel(
872 streamObserver.getThinkingStreamHandler(),
873 streamObserver.getResponseStreamHandler());
875 return generateSyncForOllamaRequestModel(
876 request,
null, streamObserver.getResponseStreamHandler());
879 return generateSyncForOllamaRequestModel(request,
null,
null);
880 }
catch (Exception e) {
889 ArrayList<OllamaChatMessage> msgs =
new ArrayList<>();
891 chatRequest.setModel(request.getModel());
894 ocm.setResponse(request.getPrompt());
895 chatRequest.setMessages(msgs);
900 List<
Tools.
Tool> allTools =
new ArrayList<>();
901 if (request.getTools() !=
null) {
902 allTools.addAll(request.getTools());
904 List<
Tools.
Tool> registeredTools = this.getRegisteredTools();
905 if (registeredTools !=
null) {
906 allTools.addAll(registeredTools);
910 chatRequest.setUseTools(
true);
911 chatRequest.setTools(allTools);
912 if (streamObserver !=
null) {
913 chatRequest.setStream(
true);
914 if (streamObserver.getResponseStreamHandler() !=
null) {
918 .getResponseStreamHandler()
919 .accept(chatResponseModel.getMessage().getResponse());
924 res.getResponseModel().getMessage().getResponse(),
925 res.getResponseModel().getMessage().getThinking(),
926 res.getResponseModel().getTotalDuration(),
942 long startTime = System.currentTimeMillis();
943 String url =
"/api/generate";
947 ollamaRequestModel.setRaw(raw);
948 ollamaRequestModel.setThink(think);
951 getRequestBuilderDefault(
new URI(this.host + url)),
953 requestTimeoutSeconds);
954 ollamaAsyncResultStreamer.start();
955 statusCode = ollamaAsyncResultStreamer.getHttpStatusCode();
956 return ollamaAsyncResultStreamer;
957 }
catch (Exception e) {
961 url, model, raw, think,
true,
null,
null, startTime, statusCode,
null);
985 if (request.isUseTools()) {
990 if (tokenHandler !=
null) {
991 request.setStream(
true);
992 result = requestCaller.
call(request, tokenHandler);
994 result = requestCaller.
callSync(request);
998 List<OllamaChatToolCalls> toolCalls =
999 result.getResponseModel().getMessage().getToolCalls();
1000 int toolCallTries = 0;
1001 while (toolCalls !=
null
1002 && !toolCalls.isEmpty()
1003 && toolCallTries < maxChatToolCallRetries) {
1005 String toolName = toolCall.getFunction().getName();
1006 for (
Tools.
Tool t : request.getTools()) {
1007 if (t.getToolSpec().getName().equals(toolName)) {
1009 if (toolFunction ==
null) {
1011 "Tool function not found: " + toolName);
1014 "Invoking tool {} with arguments: {}",
1015 toolCall.getFunction().getName(),
1016 toolCall.getFunction().getArguments());
1017 Map<String, Object> arguments = toolCall.getFunction().getArguments();
1018 Object res = toolFunction.
apply(arguments);
1019 String argumentKeys =
1020 arguments.keySet().stream()
1021 .map(Object::toString)
1022 .collect(Collectors.joining(
", "));
1023 request.getMessages()
1033 +
" [/TOOL_RESULTS]"));
1037 if (tokenHandler !=
null) {
1038 result = requestCaller.
call(request, tokenHandler);
1040 result = requestCaller.
callSync(request);
1042 toolCalls = result.getResponseModel().getMessage().getToolCalls();
1046 }
catch (InterruptedException e) {
1047 Thread.currentThread().interrupt();
1049 }
catch (Exception e) {
1061 LOG.debug(
"Registered tool: {}", tool.getToolSpec().getName());
1083 toolRegistry.
clear();
1084 LOG.debug(
"All tools have been deregistered.");
1097 Class<?> callerClass =
null;
1100 Class.forName(Thread.currentThread().getStackTrace()[2].getClassName());
1101 }
catch (ClassNotFoundException e) {
1107 if (ollamaToolServiceAnnotation ==
null) {
1108 throw new IllegalStateException(
1112 Class<?>[] providers = ollamaToolServiceAnnotation.
providers();
1113 for (Class<?> provider : providers) {
1114 registerAnnotatedTools(provider.getDeclaredConstructor().newInstance());
1116 }
catch (InstantiationException
1117 | NoSuchMethodException
1118 | IllegalAccessException
1119 | InvocationTargetException e) {
1134 Class<?> objectClass =
object.getClass();
1135 Method[] methods = objectClass.getMethods();
1136 for (Method m : methods) {
1138 if (toolSpec ==
null) {
1141 String operationName = !toolSpec.
name().isBlank() ? toolSpec.
name() : m.getName();
1142 String operationDesc = !toolSpec.
desc().isBlank() ? toolSpec.
desc() : operationName;
1145 LinkedHashMap<String, String> methodParams =
new LinkedHashMap<>();
1146 for (Parameter parameter : m.getParameters()) {
1149 String propType = parameter.getType().getTypeName();
1150 if (toolPropertyAnn ==
null) {
1151 methodParams.put(parameter.getName(),
null);
1155 !toolPropertyAnn.
name().isBlank()
1156 ? toolPropertyAnn.
name()
1157 : parameter.getName();
1158 methodParams.put(propName, propType);
1163 .description(toolPropertyAnn.
desc())
1164 .required(toolPropertyAnn.
required())
1167 Tools.ToolSpec toolSpecification =
1169 .name(operationName)
1170 .description(operationDesc)
1177 .toolFunction(reflectionalToolFunction)
1178 .toolSpec(toolSpecification)
1222 private static String encodeFileToBase64(File file)
throws IOException {
1223 return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath()));
1232 private static String encodeByteArrayToBase64(
byte[] bytes) {
1233 return Base64.getEncoder().encodeToString(bytes);
1247 private OllamaResult generateSyncForOllamaRequestModel(
1252 long startTime = System.currentTimeMillis();
1253 int statusCode = -1;
1259 if (responseStreamHandler !=
null) {
1260 ollamaRequestModel.setStream(
true);
1263 ollamaRequestModel, thinkingStreamHandler, responseStreamHandler);
1265 result = requestCaller.
callSync(ollamaRequestModel);
1267 statusCode = result.getHttpStatusCode();
1270 }
catch (InterruptedException e) {
1271 Thread.currentThread().interrupt();
1273 }
catch (Exception e) {
1278 ollamaRequestModel.getModel(),
1279 ollamaRequestModel.isRaw(),
1280 ollamaRequestModel.getThink(),
1281 ollamaRequestModel.isStream(),
1282 ollamaRequestModel.getOptions(),
1283 ollamaRequestModel.getFormat(),
1296 private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
1297 HttpRequest.Builder requestBuilder =
1298 HttpRequest.newBuilder(uri)
1302 .timeout(Duration.ofSeconds(requestTimeoutSeconds));
1306 return requestBuilder;
1314 private boolean isAuthSet() {
1315 return auth !=
null;