// Class logger and client-wide configuration.
58 private static final Logger LOG = LoggerFactory.getLogger(
Ollama.class);
// Base URL of the Ollama server; requests are built as host + "/api/...".
60 private final String host;
// Per-request HTTP timeout in seconds (applied through Duration.ofSeconds when building requests).
71 @Setter
private long requestTimeoutSeconds = 10;
// Read timeout (seconds) for image URLs; where it is applied is not visible in this excerpt.
76 @Setter
private int imageURLReadTimeoutSeconds = 10;
// Connect timeout (seconds) for image URLs; where it is applied is not visible in this excerpt.
81 @Setter
private int imageURLConnectTimeoutSeconds = 10;
// Upper bound for the tool-call loop (compared against toolCallTries).
89 @Setter
private int maxChatToolCallRetries = 3;
// Retry budget for model pulls: 0 means a single doPullModel call with no retry loop.
// NOTE(review): no setter is visible here, so this presumably stays 0 unless it is changed elsewhere.
100 @SuppressWarnings({
"FieldMayBeFinal",
"FieldCanBeLocal"})
101 private int numberOfRetriesForModelPull = 0;
// Toggle for metrics recording; the place where it is read is not visible in this excerpt.
109 @Setter
private boolean metricsEnabled =
false;
// Default constructor target: a local Ollama server on port 11434.
115 this.host =
"http://localhost:11434";
// Strip one trailing slash so later "this.host + url" concatenations (url starts
// with "/") do not produce a double slash.
124 if (host.endsWith(
"/")) {
125 this.host = host.substring(0, host.length() - 1);
// NOTE(review): the message says "Connected", but no request to the server is visible here.
129 LOG.info(
"Ollama4j client initialized. Connected to Ollama server at: {}", this.host);
// Configure HTTP Basic authentication with the supplied credentials.
139 this.auth =
new BasicAuth(username, password);
// Health check: sends a request to /api/tags and reports whether the server
// answered with HTTP 200.
// NOTE(review): a new HttpClient is created on every call; java.net.http.HttpClient
// instances can be reused across requests, which would allow connection reuse.
158 long startTime = System.currentTimeMillis();
159 String url =
"/api/tags";
163 HttpClient httpClient = HttpClient.newHttpClient();
164 HttpRequest httpRequest;
165 HttpResponse<String> response;
167 getRequestBuilderDefault(
new URI(this.host + url))
176 response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
177 statusCode = response.statusCode();
178 return statusCode == 200;
179 }
catch (InterruptedException ie) {
// Restore the interrupt flag so callers can still observe the interruption.
180 Thread.currentThread().interrupt();
182 }
catch (Exception e) {
// NOTE(review): the arguments below appear to feed a metrics/recording call whose
// name is on lines outside this excerpt (endpoint, start time, final status code).
186 url,
"",
false,
false,
false,
null,
null, startTime, statusCode, out);
// Queries /api/ps and reads the response body. What is extracted from a 200
// response is on lines not shown in this excerpt. A URISyntaxException raised while
// building the request is caught separately from failures during send.
197 long startTime = System.currentTimeMillis();
198 String url =
"/api/ps";
202 HttpClient httpClient = HttpClient.newHttpClient();
203 HttpRequest httpRequest =
null;
206 getRequestBuilderDefault(
new URI(this.host + url))
215 }
catch (URISyntaxException e) {
// NOTE(review): the null initializer is redundant; send(...) overwrites it on the next line.
218 HttpResponse<String> response =
null;
219 response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
220 statusCode = response.statusCode();
221 String responseString = response.body();
222 if (statusCode == 200) {
228 }
catch (InterruptedException ie) {
// Restore the interrupt flag so callers can still observe the interruption.
229 Thread.currentThread().interrupt();
231 }
catch (Exception e) {
// NOTE(review): the trailing arguments appear to feed a metrics/recording call
// defined outside this excerpt.
235 url,
"",
false,
false,
false,
null,
null, startTime, statusCode, out);
// Requests /api/tags and reads the response body as a String. The handling of a
// 200 response is on lines not shown in this excerpt.
246 long startTime = System.currentTimeMillis();
247 String url =
"/api/tags";
251 HttpClient httpClient = HttpClient.newHttpClient();
252 HttpRequest httpRequest =
253 getRequestBuilderDefault(
new URI(this.host + url))
262 HttpResponse<String> response =
263 httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
264 statusCode = response.statusCode();
265 String responseString = response.body();
266 if (statusCode == 200) {
273 }
catch (InterruptedException ie) {
// Restore the interrupt flag so callers can still observe the interruption.
274 Thread.currentThread().interrupt();
276 }
catch (Exception e) {
// NOTE(review): the trailing arguments appear to feed a metrics/recording call
// defined outside this excerpt.
280 url,
"",
false,
false,
false,
null,
null, startTime, statusCode, out);
/**
 * Backoff step between model-pull attempts.
 *
 * <p>If another attempt is allowed ({@code currentRetry + 1 < maxRetries}), logs a retry
 * message and sleeps for {@code baseDelayMillis * 2^currentRetry} milliseconds. Otherwise
 * it logs that no retries remain; that branch is partly outside this excerpt.
 *
 * @param modelName model being pulled (used in log messages)
 * @param currentRetry index of the attempt that just failed; appears to be zero-based,
 *     since the logged attempt number is currentRetry + 1
 * @param maxRetries total number of attempts allowed
 * @param baseDelayMillis delay before the first retry, in milliseconds
 */
293 private void handlePullRetry(
294 String modelName,
int currentRetry,
int maxRetries,
long baseDelayMillis)
295 throws InterruptedException {
296 int attempt = currentRetry + 1;
297 if (attempt < maxRetries) {
// Exponential backoff: the delay doubles with each retry.
298 long backoffMillis = baseDelayMillis * (1L << currentRetry);
300 "Failed to pull model {}, retrying in {}s... (attempt {}/{})",
// Integer division, so the logged number of seconds is truncated.
302 backoffMillis / 1000,
306 Thread.sleep(backoffMillis);
307 }
catch (InterruptedException ie) {
// Restore the interrupt flag so callers can still observe the interruption.
308 Thread.currentThread().interrupt();
313 "Failed to pull model {} after {} attempts, no more retries.",
// Single model-pull attempt: POSTs jsonData to /api/pull and reads the streamed
// response line by line, decoding it as UTF-8. Each response object (parsed on
// lines not shown) is passed to processModelPullResponse. The pull counts as
// successful if any line reports success.
326 long startTime = System.currentTimeMillis();
327 String url =
"/api/pull";
332 HttpRequest request =
333 getRequestBuilderDefault(
new URI(this.host + url))
334 .POST(HttpRequest.BodyPublishers.ofString(jsonData))
342 HttpClient client = HttpClient.newHttpClient();
343 HttpResponse<InputStream> response =
344 client.send(request, HttpResponse.BodyHandlers.ofInputStream());
345 statusCode = response.statusCode();
346 InputStream responseBodyStream = response.body();
347 String responseString =
"";
348 boolean success =
false;
350 try (BufferedReader reader =
352 new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
354 while ((line = reader.readLine()) !=
null) {
// The call comes before "|| success" so every line is still processed (and can
// throw on an error) after a success status has been seen.
357 success = processModelPullResponse(modelPullResponse, modelName) || success;
361 LOG.error(
"Model pull failed or returned invalid status.");
362 throw new OllamaException(
"Model pull failed or returned invalid status.");
364 if (statusCode != 200) {
367 }
catch (InterruptedException ie) {
// Restore the interrupt flag, then surface the interruption as an OllamaException with the cause attached.
368 Thread.currentThread().interrupt();
369 throw new OllamaException(
"Thread was interrupted during model pull.", ie);
370 }
catch (Exception e) {
// NOTE(review): the trailing arguments appear to feed a metrics/recording call
// defined outside this excerpt.
374 url,
"",
false,
false,
false,
null,
null, startTime, statusCode, out);
/**
 * Interprets one streamed status object from /api/pull.
 *
 * @param modelPullResponse parsed response object; a null value is logged as an error
 *     (the value returned in that case is on lines not shown)
 * @param modelName model name, used as a prefix in debug logging
 * @return a boolean; in the visible code, a case-insensitive "success" status is what
 *     signals completion
 * @throws OllamaException if the response carries a non-blank error message
 */
387 @SuppressWarnings(
"RedundantIfStatement")
388 private
boolean processModelPullResponse(
ModelPullResponse modelPullResponse, String modelName)
390 if (modelPullResponse ==
null) {
391 LOG.error(
"Received null response for model pull.");
394 String error = modelPullResponse.getError();
// A non-blank error field from the server aborts the pull.
395 if (error !=
null && !error.trim().isEmpty()) {
396 throw new OllamaException(
"Model pull failed: " + error);
398 String status = modelPullResponse.getStatus();
399 if (status !=
null) {
400 LOG.debug(
"{}: {}", modelName, status);
401 if (
"success".equalsIgnoreCase(status)) {
// Requests /api/version and reads the response body. The handling of a 200
// response is on lines not shown in this excerpt.
415 String url =
"/api/version";
416 long startTime = System.currentTimeMillis();
420 HttpClient httpClient = HttpClient.newHttpClient();
421 HttpRequest httpRequest =
422 getRequestBuilderDefault(
new URI(this.host + url))
431 HttpResponse<String> response =
432 httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
433 statusCode = response.statusCode();
434 String responseString = response.body();
435 if (statusCode == 200) {
442 }
catch (InterruptedException ie) {
// Restore the interrupt flag so callers can still observe the interruption.
443 Thread.currentThread().interrupt();
445 }
catch (Exception e) {
// NOTE(review): the trailing arguments appear to feed a metrics/recording call
// defined outside this excerpt.
449 url,
"",
false,
false,
false,
null,
null, startTime, statusCode, out);
// Model pull with optional retries. When numberOfRetriesForModelPull is 0, the
// model is pulled exactly once. Otherwise attempts repeat while numberOfRetries is
// below numberOfRetriesForModelPull, using a 3-second base delay.
// NOTE(review): the counter increment and the backoff call sit on lines not shown.
463 if (numberOfRetriesForModelPull == 0) {
464 this.doPullModel(modelName);
467 int numberOfRetries = 0;
468 long baseDelayMillis = 3000L;
469 while (numberOfRetries < numberOfRetriesForModelPull) {
471 this.doPullModel(modelName);
477 numberOfRetriesForModelPull,
483 "Failed to pull model "
486 + numberOfRetriesForModelPull
488 }
catch (InterruptedException ie) {
// Restore the interrupt flag so callers can still observe the interruption.
489 Thread.currentThread().interrupt();
491 }
catch (Exception e) {
// Fetches model details by POSTing jsonData to /api/show. The handling of a 200
// response body is on lines not shown in this excerpt.
504 long startTime = System.currentTimeMillis();
505 String url =
"/api/show";
510 HttpRequest request =
511 getRequestBuilderDefault(
new URI(this.host + url))
518 .POST(HttpRequest.BodyPublishers.ofString(jsonData))
520 HttpClient client = HttpClient.newHttpClient();
521 HttpResponse<String> response =
522 client.send(request, HttpResponse.BodyHandlers.ofString());
523 statusCode = response.statusCode();
524 String responseBody = response.body();
525 if (statusCode == 200) {
530 }
catch (InterruptedException ie) {
// Restore the interrupt flag so callers can still observe the interruption.
531 Thread.currentThread().interrupt();
533 }
catch (Exception e) {
// NOTE(review): the trailing arguments appear to feed a metrics/recording call
// defined outside this excerpt.
537 url,
"",
false,
false,
false,
null,
null, startTime, statusCode, out);
// Creates a custom model via /api/create and reads the response as a UTF-8
// stream. On a non-200 status the whole body is read at once; otherwise it is
// read line by line, each status is logged, and any error field is captured into `out`.
549 long startTime = System.currentTimeMillis();
550 String url =
"/api/create";
// NOTE(review): presumably customModelRequest.toString() serializes to JSON — confirm.
554 String jsonData = customModelRequest.toString();
555 HttpRequest request =
556 getRequestBuilderDefault(
new URI(this.host + url))
564 HttpRequest.BodyPublishers.ofString(
565 jsonData, StandardCharsets.UTF_8))
567 HttpClient client = HttpClient.newHttpClient();
568 HttpResponse<InputStream> response =
569 client.send(request, HttpResponse.BodyHandlers.ofInputStream());
570 statusCode = response.statusCode();
571 if (statusCode != 200) {
573 new String(response.body().readAllBytes(), StandardCharsets.UTF_8);
577 try (BufferedReader reader =
579 new InputStreamReader(response.body(), StandardCharsets.UTF_8))) {
581 StringBuilder lines =
new StringBuilder();
582 while ((line = reader.readLine()) !=
null) {
// NOTE(review): the status text is used as the SLF4J format string itself;
// LOG.debug("{}", res.getStatus()) would avoid accidental "{}" interpretation.
586 LOG.debug(res.getStatus());
587 if (res.getError() !=
null) {
588 out = res.getError();
594 }
catch (InterruptedException e) {
// Restore the interrupt flag so callers can still observe the interruption.
595 Thread.currentThread().interrupt();
597 }
catch (Exception e) {
// NOTE(review): the trailing arguments appear to feed a metrics/recording call
// defined outside this excerpt.
601 url,
"",
false,
false,
false,
null,
null, startTime, statusCode, out);
// Deletes a model via /api/delete with a UTF-8 JSON body. A 404 whose body
// mentions "model" and "not found" gets its own branch, checked before the generic
// non-200 branch; both bodies are on lines not shown.
613 long startTime = System.currentTimeMillis();
614 String url =
"/api/delete";
619 HttpRequest request =
620 getRequestBuilderDefault(
new URI(this.host + url))
623 HttpRequest.BodyPublishers.ofString(
624 jsonData, StandardCharsets.UTF_8))
632 HttpClient client = HttpClient.newHttpClient();
633 HttpResponse<String> response =
634 client.send(request, HttpResponse.BodyHandlers.ofString());
635 statusCode = response.statusCode();
636 String responseBody = response.body();
638 if (statusCode == 404
639 && responseBody.contains(
"model")
640 && responseBody.contains(
"not found")) {
643 if (statusCode != 200) {
646 }
catch (InterruptedException e) {
// Restore the interrupt flag so callers can still observe the interruption.
647 Thread.currentThread().interrupt();
649 }
catch (Exception e) {
// NOTE(review): the trailing arguments appear to feed a metrics/recording call
// defined outside this excerpt.
653 url,
"",
false,
false,
false,
null,
null, startTime, statusCode, out);
// Unloads a model by sending {"model": modelName, "keep_alive": 0} to
// /api/generate. A "model not found" 404 and other non-200 responses are logged at
// debug level; what happens next is on lines not shown.
667 long startTime = System.currentTimeMillis();
668 String url =
"/api/generate";
// NOTE(review): a new ObjectMapper is created per call; a configured Jackson
// ObjectMapper is thread-safe and is usually shared.
672 ObjectMapper objectMapper =
new ObjectMapper();
673 Map<String, Object> jsonMap =
new java.util.HashMap<>();
674 jsonMap.put(
"model", modelName);
675 jsonMap.put(
"keep_alive", 0);
676 String jsonData = objectMapper.writeValueAsString(jsonMap);
677 HttpRequest request =
678 getRequestBuilderDefault(
new URI(this.host + url))
681 HttpRequest.BodyPublishers.ofString(
682 jsonData, StandardCharsets.UTF_8))
690 LOG.debug(
"Unloading model with request: {}", jsonData);
691 HttpClient client = HttpClient.newHttpClient();
692 HttpResponse<String> response =
693 client.send(request, HttpResponse.BodyHandlers.ofString());
694 statusCode = response.statusCode();
695 String responseBody = response.body();
696 if (statusCode == 404
697 && responseBody.contains(
"model")
698 && responseBody.contains(
"not found")) {
699 LOG.debug(
"Unload response: {} - {}", statusCode, responseBody);
702 if (statusCode != 200) {
703 LOG.debug(
"Unload response: {} - {}", statusCode, responseBody);
706 }
catch (InterruptedException e) {
// Restore the interrupt flag so callers can still observe the interruption.
707 Thread.currentThread().interrupt();
708 LOG.debug(
"Unload interrupted: {} - {}", statusCode, out);
710 }
catch (Exception e) {
711 LOG.debug(
"Unload failed: {} - {}", statusCode, out);
// NOTE(review): the trailing arguments appear to feed a metrics/recording call
// defined outside this excerpt.
715 url,
"",
false,
false,
false,
null,
null, startTime, statusCode, out);
// Requests embeddings by POSTing jsonData to /api/embed. The handling of a 200
// response body is on lines not shown in this excerpt.
727 long startTime = System.currentTimeMillis();
728 String url =
"/api/embed";
733 HttpClient httpClient = HttpClient.newHttpClient();
// NOTE(review): unlike the other endpoints, this builds the request with
// HttpRequest.newBuilder rather than getRequestBuilderDefault. The default
// requestTimeoutSeconds timeout, and any auth handling in that helper, may therefore
// be missing here — confirm against lines not shown.
734 HttpRequest request =
735 HttpRequest.newBuilder(
new URI(this.host + url))
739 .POST(HttpRequest.BodyPublishers.ofString(jsonData))
741 HttpResponse<String> response =
742 httpClient.send(request, HttpResponse.BodyHandlers.ofString());
743 statusCode = response.statusCode();
744 String responseBody = response.body();
745 if (statusCode == 200) {
750 }
catch (InterruptedException e) {
// Restore the interrupt flag so callers can still observe the interruption.
751 Thread.currentThread().interrupt();
753 }
catch (Exception e) {
// NOTE(review): the trailing arguments appear to feed a metrics/recording call
// defined outside this excerpt.
757 url,
"",
false,
false,
false,
null,
null, startTime, statusCode, out);
// Dispatch for generate requests. Tool-enabled requests go to
// generateWithToolsInternal. Otherwise generateSyncForOllamaRequestModel is used:
// the thinking handler is forwarded only when a stream observer is present and
// request.isThink() is true, and no handlers are passed when there is no observer.
774 if (request.isUseTools()) {
775 return generateWithToolsInternal(request, streamObserver);
778 if (streamObserver !=
null) {
779 if (request.isThink()) {
780 return generateSyncForOllamaRequestModel(
782 streamObserver.getThinkingStreamHandler(),
783 streamObserver.getResponseStreamHandler());
785 return generateSyncForOllamaRequestModel(
786 request,
null, streamObserver.getResponseStreamHandler());
789 return generateSyncForOllamaRequestModel(request,
null,
null);
790 }
catch (Exception e) {
// Tool-enabled generate, implemented through the chat API. The prompt becomes a
// chat message, the model and tools are copied from the generate request, and
// useTools is enabled. When a stream observer is present, streaming is turned on
// and message text from chat responses is forwarded to the observer's response
// handler. The result is then built from the final chat message's response text,
// its thinking text and the total duration.
799 ArrayList<OllamaChatMessage> msgs =
new ArrayList<>();
801 chatRequest.setModel(request.getModel());
804 ocm.setResponse(request.getPrompt());
805 chatRequest.setMessages(msgs);
808 chatRequest.setUseTools(
true);
809 chatRequest.setTools(request.getTools());
810 if (streamObserver !=
null) {
811 chatRequest.setStream(
true);
812 if (streamObserver.getResponseStreamHandler() !=
null) {
816 .getResponseStreamHandler()
817 .accept(chatResponseModel.getMessage().getResponse());
822 res.getResponseModel().getMessage().getResponse(),
823 res.getResponseModel().getMessage().getThinking(),
824 res.getResponseModel().getTotalDuration(),
String model, String prompt,
boolean raw,
boolean think)
throws OllamaException {
// Asynchronous generate against /api/generate. It sets raw/think on the request
// model, starts a result streamer configured with requestTimeoutSeconds, and
// returns the streamer to the caller.
// NOTE(review): the status code is read right after start(); if the streamer works
// on another thread, the status may not be known yet — confirm.
840 long startTime = System.currentTimeMillis();
841 String url =
"/api/generate";
845 ollamaRequestModel.setRaw(raw);
846 ollamaRequestModel.setThink(think);
849 getRequestBuilderDefault(
new URI(this.host + url)),
851 requestTimeoutSeconds);
852 ollamaAsyncResultStreamer.start();
853 statusCode = ollamaAsyncResultStreamer.getHttpStatusCode();
854 return ollamaAsyncResultStreamer;
855 }
catch (Exception e) {
// NOTE(review): the trailing arguments appear to feed a metrics/recording call
// defined outside this excerpt.
859 url, model, raw, think,
true,
null,
null, startTime, statusCode,
null);
// Chat with tool execution. The request is streamed when a token handler is given
// and sent synchronously otherwise. While the model returns tool calls (bounded by
// maxChatToolCallRetries; the toolCallTries increment is on a line not shown), each
// call's tool is looked up by name, invoked with the model-supplied arguments, and
// its result is appended to the conversation. The model is then called again.
883 if (request.isUseTools()) {
888 if (tokenHandler !=
null) {
889 request.setStream(
true);
890 result = requestCaller.
call(request, tokenHandler);
892 result = requestCaller.
callSync(request);
896 List<OllamaChatToolCalls> toolCalls =
897 result.getResponseModel().getMessage().getToolCalls();
898 int toolCallTries = 0;
899 while (toolCalls !=
null
900 && !toolCalls.isEmpty()
901 && toolCallTries < maxChatToolCallRetries) {
903 String toolName = toolCall.getFunction().getName();
// Linear search of the request's tools by spec name.
904 for (
Tools.
Tool t : request.getTools()) {
905 if (t.getToolSpec().getName().equals(toolName)) {
907 if (toolFunction ==
null) {
909 "Tool function not found: " + toolName);
912 "Invoking tool {} with arguments: {}",
913 toolCall.getFunction().getName(),
914 toolCall.getFunction().getArguments());
915 Map<String, Object> arguments = toolCall.getFunction().getArguments();
916 Object res = toolFunction.
apply(arguments);
917 String argumentKeys =
918 arguments.keySet().stream()
919 .map(Object::toString)
920 .collect(Collectors.joining(
", "));
// The tool result is appended to the conversation as a message ending in " [/TOOL_RESULTS]".
921 request.getMessages()
931 +
" [/TOOL_RESULTS]"));
935 if (tokenHandler !=
null) {
936 result = requestCaller.
call(request, tokenHandler);
938 result = requestCaller.
callSync(request);
940 toolCalls = result.getResponseModel().getMessage().getToolCalls();
944 }
catch (InterruptedException e) {
// Restore the interrupt flag so callers can still observe the interruption.
945 Thread.currentThread().interrupt();
947 }
catch (Exception e) {
// Logs the registration of a tool by its spec name.
959 LOG.debug(
"Registered tool: {}", tool.getToolSpec().getName());
// Removes every registered tool from the registry.
981 toolRegistry.
clear();
982 LOG.debug(
"All tools have been deregistered.");
// Registers tools declared by the calling class. It resolves the caller's class,
// requires its tool-service annotation (otherwise IllegalStateException), and
// registers the annotated tools of each listed provider, instantiated through its
// no-arg constructor.
995 Class<?> callerClass =
null;
// Frame [0] is getStackTrace itself, [1] this method, [2] the caller. This depends on
// call depth, and Thread.getStackTrace allows JVMs to omit frames, so it is fragile.
998 Class.forName(Thread.currentThread().getStackTrace()[2].getClassName());
999 }
catch (ClassNotFoundException e) {
1005 if (ollamaToolServiceAnnotation ==
null) {
1006 throw new IllegalStateException(
1010 Class<?>[] providers = ollamaToolServiceAnnotation.
providers();
1011 for (Class<?> provider : providers) {
1012 registerAnnotatedTools(provider.getDeclaredConstructor().newInstance());
1014 }
catch (InstantiationException
1015 | NoSuchMethodException
1016 | IllegalAccessException
1017 | InvocationTargetException e) {
// Registers a tool for each public method of `object` that carries a tool-spec
// annotation. The tool name defaults to the method name and the description to the
// tool name when the annotation leaves them blank. Parameter metadata is built from
// per-parameter property annotations.
1032 Class<?> objectClass =
object.getClass();
// getMethods() returns public methods only, including inherited ones.
1033 Method[] methods = objectClass.getMethods();
1034 for (Method m : methods) {
1036 if (toolSpec ==
null) {
1039 String operationName = !toolSpec.
name().isBlank() ? toolSpec.
name() : m.getName();
1040 String operationDesc = !toolSpec.
desc().isBlank() ? toolSpec.
desc() : operationName;
// LinkedHashMap keeps parameters in declaration order.
1043 LinkedHashMap<String, String> methodParams =
new LinkedHashMap<>();
1044 for (Parameter parameter : m.getParameters()) {
1047 String propType = parameter.getType().getTypeName();
1048 if (toolPropertyAnn ==
null) {
// Parameter.getName() yields synthetic names (arg0, ...) unless compiled with -parameters.
1049 methodParams.put(parameter.getName(),
null);
1053 !toolPropertyAnn.
name().isBlank()
1054 ? toolPropertyAnn.
name()
1055 : parameter.getName();
1056 methodParams.put(propName, propType);
1061 .description(toolPropertyAnn.
desc())
1062 .required(toolPropertyAnn.
required())
1065 Tools.ToolSpec toolSpecification =
1067 .name(operationName)
1068 .description(operationDesc)
1075 .toolFunction(reflectionalToolFunction)
1076 .toolSpec(toolSpecification)
/**
 * Reads the whole file into memory and returns it as standard Base64.
 *
 * @param file file to encode
 * @return Base64 encoding of the file's bytes
 * @throws IOException if the file cannot be read
 */
1120 private static String encodeFileToBase64(File file)
throws IOException {
1121 return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath()));
/**
 * Encodes a byte array as standard Base64.
 *
 * @param bytes bytes to encode
 * @return Base64 string
 */
1130 private static String encodeByteArrayToBase64(
byte[] bytes) {
1131 return Base64.getEncoder().encodeToString(bytes);
/**
 * Sends a generate request and returns its result. Streaming is enabled when a response
 * stream handler is supplied; otherwise the request is sent synchronously. The HTTP status
 * code starts at -1 and is taken from the result once one is obtained.
 */
1145 private OllamaResult generateSyncForOllamaRequestModel(
1150 long startTime = System.currentTimeMillis();
1151 int statusCode = -1;
1157 if (responseStreamHandler !=
null) {
1158 ollamaRequestModel.setStream(
true);
1161 ollamaRequestModel, thinkingStreamHandler, responseStreamHandler);
1163 result = requestCaller.
callSync(ollamaRequestModel);
1165 statusCode = result.getHttpStatusCode();
1168 }
catch (InterruptedException e) {
// Restore the interrupt flag so callers can still observe the interruption.
1169 Thread.currentThread().interrupt();
1171 }
catch (Exception e) {
// NOTE(review): the arguments below appear to feed a metrics/recording call
// defined outside this excerpt.
1176 ollamaRequestModel.getModel(),
1177 ollamaRequestModel.isRaw(),
1178 ollamaRequestModel.isThink(),
1179 ollamaRequestModel.isStream(),
1180 ollamaRequestModel.getOptions(),
1181 ollamaRequestModel.getFormat(),
/**
 * Creates a request builder for {@code uri} with the client-wide timeout
 * ({@code requestTimeoutSeconds}). Further configuration between these lines is not
 * visible in this excerpt (presumably headers and auth — confirm).
 *
 * @param uri absolute request URI
 * @return a builder the caller completes with a method/body and builds
 */
1194 private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
1195 HttpRequest.Builder requestBuilder =
1196 HttpRequest.newBuilder(uri)
1200 .timeout(Duration.ofSeconds(requestTimeoutSeconds));
1204 return requestBuilder;
/**
 * Reports whether an authentication scheme has been configured.
 *
 * @return true when {@code auth} is non-null
 */
1212 private boolean isAuthSet() {
1213 return auth !=
null;