// Class-wide logger obtained from LoggerFactory for Ollama.class.
58 private static final Logger LOG = LoggerFactory.getLogger(
Ollama.class);
// Base URL of the Ollama server; request URIs in this class are built as host + endpoint path.
60 private final String host;
// Request timeout in seconds; getRequestBuilderDefault passes it to Duration.ofSeconds(...). Defaults to 10.
71 @Setter
private long requestTimeoutSeconds = 10;
// Defaults to 10. NOTE(review): no use is visible in this excerpt; presumably the read timeout
// (seconds) when fetching an image by URL - confirm against the code that loads images.
74 @Setter
private int imageURLReadTimeoutSeconds = 10;
// Defaults to 10. NOTE(review): no use is visible in this excerpt; presumably the connect timeout
// (seconds) when fetching an image by URL - confirm against the code that loads images.
77 @Setter
private int imageURLConnectTimeoutSeconds = 10;
// Upper bound on iterations of the chat tool-call loop (the loop runs while
// toolCallTries < maxChatToolCallRetries). Defaults to 3.
85 @Setter
private int maxChatToolCallRetries = 3;
// Retry budget for model pulls. When 0, the pull path calls doPullModel once with no retry loop;
// otherwise it bounds the retry loop. No @Setter is attached on the visible lines.
// NOTE(review): the suppressed inspections suggest the field is modified by other means - confirm.
96 @SuppressWarnings({
"FieldMayBeFinal",
"FieldCanBeLocal"})
97 private int numberOfRetriesForModelPull = 0;
// Defaults to false. NOTE(review): the consumer of this flag is not visible in this excerpt -
// confirm what it gates before relying on it.
105 @Setter
private boolean metricsEnabled =
false;
111 this.host =
"http://localhost:11434";
120 if (host.endsWith(
"/")) {
121 this.host = host.substring(0, host.length() - 1);
125 LOG.info(
"Ollama4j client initialized. Connected to Ollama server at: {}", this.host);
135 this.auth =
new BasicAuth(username, password);
154 long startTime = System.currentTimeMillis();
155 String url =
"/api/tags";
159 HttpClient httpClient = HttpClient.newHttpClient();
160 HttpRequest httpRequest;
161 HttpResponse<String> response;
163 getRequestBuilderDefault(
new URI(this.host + url))
172 response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
173 statusCode = response.statusCode();
174 return statusCode == 200;
175 }
catch (InterruptedException ie) {
176 Thread.currentThread().interrupt();
178 }
catch (Exception e) {
182 url,
"",
false,
false,
false,
null,
null, startTime, statusCode, out);
193 long startTime = System.currentTimeMillis();
194 String url =
"/api/ps";
198 HttpClient httpClient = HttpClient.newHttpClient();
199 HttpRequest httpRequest =
null;
202 getRequestBuilderDefault(
new URI(this.host + url))
211 }
catch (URISyntaxException e) {
214 HttpResponse<String> response =
null;
215 response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
216 statusCode = response.statusCode();
217 String responseString = response.body();
218 if (statusCode == 200) {
224 }
catch (InterruptedException ie) {
225 Thread.currentThread().interrupt();
227 }
catch (Exception e) {
231 url,
"",
false,
false,
false,
null,
null, startTime, statusCode, out);
242 long startTime = System.currentTimeMillis();
243 String url =
"/api/tags";
247 HttpClient httpClient = HttpClient.newHttpClient();
248 HttpRequest httpRequest =
249 getRequestBuilderDefault(
new URI(this.host + url))
258 HttpResponse<String> response =
259 httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
260 statusCode = response.statusCode();
261 String responseString = response.body();
262 if (statusCode == 200) {
269 }
catch (InterruptedException ie) {
270 Thread.currentThread().interrupt();
272 }
catch (Exception e) {
276 url,
"",
false,
false,
false,
null,
null, startTime, statusCode, out);
/**
 * Handles one failed model-pull attempt: either waits before the next attempt or reports that
 * the retry budget is exhausted.
 *
 * <p>The wait grows exponentially: {@code baseDelayMillis * 2^currentRetry} milliseconds.
 *
 * <p>NOTE(review): several lines of this method are not visible in this excerpt (the logging
 * calls that own the two message strings and the closing braces); the description covers only
 * the visible statements.
 *
 * @param modelName name of the model being pulled (used in the log messages)
 * @param currentRetry zero-based index of the attempt that just failed
 * @param maxRetries total number of attempts allowed
 * @param baseDelayMillis delay used for the first retry, in milliseconds
 * @throws InterruptedException declared by the signature; note that the visible catch block
 *     restores the interrupt flag rather than propagating directly
 */
289 private void handlePullRetry(
290 String modelName,
int currentRetry,
int maxRetries,
long baseDelayMillis)
291 throws InterruptedException {
// Convert the zero-based retry index into a one-based attempt number.
292 int attempt = currentRetry + 1;
// Only wait when at least one more attempt remains.
293 if (attempt < maxRetries) {
// Exponential backoff: the delay doubles with every retry (1L << currentRetry == 2^currentRetry).
294 long backoffMillis = baseDelayMillis * (1L << currentRetry);
296 "Failed to pull model {}, retrying in {}s... (attempt {}/{})",
// Logged in whole seconds (integer division truncates sub-second remainders).
298 backoffMillis / 1000,
302 Thread.sleep(backoffMillis);
303 }
catch (InterruptedException ie) {
// Restore the interrupt flag so callers further up the stack can observe the interruption.
304 Thread.currentThread().interrupt();
309 "Failed to pull model {} after {} attempts, no more retries.",
322 long startTime = System.currentTimeMillis();
323 String url =
"/api/pull";
328 HttpRequest request =
329 getRequestBuilderDefault(
new URI(this.host + url))
330 .POST(HttpRequest.BodyPublishers.ofString(jsonData))
338 HttpClient client = HttpClient.newHttpClient();
339 HttpResponse<InputStream> response =
340 client.send(request, HttpResponse.BodyHandlers.ofInputStream());
341 statusCode = response.statusCode();
342 InputStream responseBodyStream = response.body();
343 String responseString =
"";
344 boolean success =
false;
346 try (BufferedReader reader =
348 new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
350 while ((line = reader.readLine()) !=
null) {
353 success = processModelPullResponse(modelPullResponse, modelName) || success;
357 LOG.error(
"Model pull failed or returned invalid status.");
358 throw new OllamaException(
"Model pull failed or returned invalid status.");
360 if (statusCode != 200) {
363 }
catch (InterruptedException ie) {
364 Thread.currentThread().interrupt();
365 throw new OllamaException(
"Thread was interrupted during model pull.", ie);
366 }
catch (Exception e) {
370 url,
"",
false,
false,
false,
null,
null, startTime, statusCode, out);
/**
 * Inspects a single parsed response object from the model-pull stream.
 *
 * <p>Visible behaviour: a {@code null} response is logged as an error; a non-blank error field
 * raises an {@link OllamaException} carrying that error text; a non-null status is logged at
 * debug level together with the model name and then compared, ignoring case, with
 * {@code "success"}.
 *
 * <p>The caller ORs the returned flag into a running {@code success} value while reading the
 * stream. NOTE(review): the return statements are not visible in this excerpt; presumably
 * {@code true} is returned only for the "success" status - confirm.
 *
 * @param modelPullResponse the parsed response element; may be {@code null}
 * @param modelName name of the model being pulled (used for logging only in the visible lines)
 */
383 @SuppressWarnings(
"RedundantIfStatement")
384 private
boolean processModelPullResponse(
ModelPullResponse modelPullResponse, String modelName)
386 if (modelPullResponse ==
null) {
387 LOG.error(
"Received null response for model pull.");
390 String error = modelPullResponse.getError();
// Treat whitespace-only error strings as "no error".
391 if (error !=
null && !error.trim().isEmpty()) {
392 throw new OllamaException(
"Model pull failed: " + error);
394 String status = modelPullResponse.getStatus();
395 if (status !=
null) {
396 LOG.debug(
"{}: {}", modelName, status);
// Case-insensitive match; calling equalsIgnoreCase on the literal is null-safe.
397 if (
"success".equalsIgnoreCase(status)) {
411 String url =
"/api/version";
412 long startTime = System.currentTimeMillis();
416 HttpClient httpClient = HttpClient.newHttpClient();
417 HttpRequest httpRequest =
418 getRequestBuilderDefault(
new URI(this.host + url))
427 HttpResponse<String> response =
428 httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
429 statusCode = response.statusCode();
430 String responseString = response.body();
431 if (statusCode == 200) {
438 }
catch (InterruptedException ie) {
439 Thread.currentThread().interrupt();
441 }
catch (Exception e) {
445 url,
"",
false,
false,
false,
null,
null, startTime, statusCode, out);
459 if (numberOfRetriesForModelPull == 0) {
460 this.doPullModel(modelName);
463 int numberOfRetries = 0;
464 long baseDelayMillis = 3000L;
465 while (numberOfRetries < numberOfRetriesForModelPull) {
467 this.doPullModel(modelName);
473 numberOfRetriesForModelPull,
479 "Failed to pull model "
482 + numberOfRetriesForModelPull
484 }
catch (InterruptedException ie) {
485 Thread.currentThread().interrupt();
487 }
catch (Exception e) {
500 long startTime = System.currentTimeMillis();
501 String url =
"/api/show";
506 HttpRequest request =
507 getRequestBuilderDefault(
new URI(this.host + url))
514 .POST(HttpRequest.BodyPublishers.ofString(jsonData))
516 HttpClient client = HttpClient.newHttpClient();
517 HttpResponse<String> response =
518 client.send(request, HttpResponse.BodyHandlers.ofString());
519 statusCode = response.statusCode();
520 String responseBody = response.body();
521 if (statusCode == 200) {
526 }
catch (InterruptedException ie) {
527 Thread.currentThread().interrupt();
529 }
catch (Exception e) {
533 url,
"",
false,
false,
false,
null,
null, startTime, statusCode, out);
545 long startTime = System.currentTimeMillis();
546 String url =
"/api/create";
550 String jsonData = customModelRequest.toString();
551 HttpRequest request =
552 getRequestBuilderDefault(
new URI(this.host + url))
560 HttpRequest.BodyPublishers.ofString(
561 jsonData, StandardCharsets.UTF_8))
563 HttpClient client = HttpClient.newHttpClient();
564 HttpResponse<InputStream> response =
565 client.send(request, HttpResponse.BodyHandlers.ofInputStream());
566 statusCode = response.statusCode();
567 if (statusCode != 200) {
569 new String(response.body().readAllBytes(), StandardCharsets.UTF_8);
573 try (BufferedReader reader =
575 new InputStreamReader(response.body(), StandardCharsets.UTF_8))) {
577 StringBuilder lines =
new StringBuilder();
578 while ((line = reader.readLine()) !=
null) {
582 LOG.debug(res.getStatus());
583 if (res.getError() !=
null) {
584 out = res.getError();
590 }
catch (InterruptedException e) {
591 Thread.currentThread().interrupt();
593 }
catch (Exception e) {
597 url,
"",
false,
false,
false,
null,
null, startTime, statusCode, out);
609 long startTime = System.currentTimeMillis();
610 String url =
"/api/delete";
615 HttpRequest request =
616 getRequestBuilderDefault(
new URI(this.host + url))
619 HttpRequest.BodyPublishers.ofString(
620 jsonData, StandardCharsets.UTF_8))
628 HttpClient client = HttpClient.newHttpClient();
629 HttpResponse<String> response =
630 client.send(request, HttpResponse.BodyHandlers.ofString());
631 statusCode = response.statusCode();
632 String responseBody = response.body();
634 if (statusCode == 404
635 && responseBody.contains(
"model")
636 && responseBody.contains(
"not found")) {
639 if (statusCode != 200) {
642 }
catch (InterruptedException e) {
643 Thread.currentThread().interrupt();
645 }
catch (Exception e) {
649 url,
"",
false,
false,
false,
null,
null, startTime, statusCode, out);
663 long startTime = System.currentTimeMillis();
664 String url =
"/api/generate";
668 ObjectMapper objectMapper =
new ObjectMapper();
669 Map<String, Object> jsonMap =
new java.util.HashMap<>();
670 jsonMap.put(
"model", modelName);
671 jsonMap.put(
"keep_alive", 0);
672 String jsonData = objectMapper.writeValueAsString(jsonMap);
673 HttpRequest request =
674 getRequestBuilderDefault(
new URI(this.host + url))
677 HttpRequest.BodyPublishers.ofString(
678 jsonData, StandardCharsets.UTF_8))
686 LOG.debug(
"Unloading model with request: {}", jsonData);
687 HttpClient client = HttpClient.newHttpClient();
688 HttpResponse<String> response =
689 client.send(request, HttpResponse.BodyHandlers.ofString());
690 statusCode = response.statusCode();
691 String responseBody = response.body();
692 if (statusCode == 404
693 && responseBody.contains(
"model")
694 && responseBody.contains(
"not found")) {
695 LOG.debug(
"Unload response: {} - {}", statusCode, responseBody);
698 if (statusCode != 200) {
699 LOG.debug(
"Unload response: {} - {}", statusCode, responseBody);
702 }
catch (InterruptedException e) {
703 Thread.currentThread().interrupt();
704 LOG.debug(
"Unload interrupted: {} - {}", statusCode, out);
706 }
catch (Exception e) {
707 LOG.debug(
"Unload failed: {} - {}", statusCode, out);
711 url,
"",
false,
false,
false,
null,
null, startTime, statusCode, out);
723 long startTime = System.currentTimeMillis();
724 String url =
"/api/embed";
729 HttpClient httpClient = HttpClient.newHttpClient();
730 HttpRequest request =
731 HttpRequest.newBuilder(
new URI(this.host + url))
735 .POST(HttpRequest.BodyPublishers.ofString(jsonData))
737 HttpResponse<String> response =
738 httpClient.send(request, HttpResponse.BodyHandlers.ofString());
739 statusCode = response.statusCode();
740 String responseBody = response.body();
741 if (statusCode == 200) {
746 }
catch (InterruptedException e) {
747 Thread.currentThread().interrupt();
749 }
catch (Exception e) {
753 url,
"",
false,
false,
false,
null,
null, startTime, statusCode, out);
770 if (request.isUseTools()) {
771 return generateWithToolsInternal(request, streamObserver);
774 if (streamObserver !=
null) {
775 if (request.isThink()) {
776 return generateSyncForOllamaRequestModel(
778 streamObserver.getThinkingStreamHandler(),
779 streamObserver.getResponseStreamHandler());
781 return generateSyncForOllamaRequestModel(
782 request,
null, streamObserver.getResponseStreamHandler());
785 return generateSyncForOllamaRequestModel(request,
null,
null);
786 }
catch (Exception e) {
795 ArrayList<OllamaChatMessage> msgs =
new ArrayList<>();
797 chatRequest.setModel(request.getModel());
800 ocm.setResponse(request.getPrompt());
801 chatRequest.setMessages(msgs);
804 chatRequest.setTools(request.getTools());
805 if (streamObserver !=
null) {
806 chatRequest.setStream(
true);
807 if (streamObserver.getResponseStreamHandler() !=
null) {
811 .getResponseStreamHandler()
812 .accept(chatResponseModel.getMessage().getResponse());
817 res.getResponseModel().getMessage().getResponse(),
818 res.getResponseModel().getMessage().getThinking(),
819 res.getResponseModel().getTotalDuration(),
834 String model, String prompt,
boolean raw,
boolean think)
throws OllamaException {
835 long startTime = System.currentTimeMillis();
836 String url =
"/api/generate";
840 ollamaRequestModel.setRaw(raw);
841 ollamaRequestModel.setThink(think);
844 getRequestBuilderDefault(
new URI(this.host + url)),
846 requestTimeoutSeconds);
847 ollamaAsyncResultStreamer.start();
848 statusCode = ollamaAsyncResultStreamer.getHttpStatusCode();
849 return ollamaAsyncResultStreamer;
850 }
catch (Exception e) {
854 url, model, raw, think,
true,
null,
null, startTime, statusCode,
null);
878 if (request.isUseTools()) {
883 if (tokenHandler !=
null) {
884 request.setStream(
true);
885 result = requestCaller.
call(request, tokenHandler);
887 result = requestCaller.
callSync(request);
891 List<OllamaChatToolCalls> toolCalls =
892 result.getResponseModel().getMessage().getToolCalls();
893 int toolCallTries = 0;
894 while (toolCalls !=
null
895 && !toolCalls.isEmpty()
896 && toolCallTries < maxChatToolCallRetries) {
898 String toolName = toolCall.getFunction().getName();
899 for (
Tools.
Tool t : request.getTools()) {
900 if (t.getToolSpec().getName().equals(toolName)) {
902 if (toolFunction ==
null) {
904 "Tool function not found: " + toolName);
907 "Invoking tool {} with arguments: {}",
908 toolCall.getFunction().getName(),
909 toolCall.getFunction().getArguments());
910 Map<String, Object> arguments = toolCall.getFunction().getArguments();
911 Object res = toolFunction.
apply(arguments);
912 String argumentKeys =
913 arguments.keySet().stream()
914 .map(Object::toString)
915 .collect(Collectors.joining(
", "));
916 request.getMessages()
926 +
" [/TOOL_RESULTS]"));
930 if (tokenHandler !=
null) {
931 result = requestCaller.
call(request, tokenHandler);
933 result = requestCaller.
callSync(request);
935 toolCalls = result.getResponseModel().getMessage().getToolCalls();
939 }
catch (InterruptedException e) {
940 Thread.currentThread().interrupt();
942 }
catch (Exception e) {
954 LOG.debug(
"Registered tool: {}", tool.getToolSpec().getName());
972 toolRegistry.
clear();
973 LOG.debug(
"All tools have been deregistered.");
986 Class<?> callerClass =
null;
989 Class.forName(Thread.currentThread().getStackTrace()[2].getClassName());
990 }
catch (ClassNotFoundException e) {
996 if (ollamaToolServiceAnnotation ==
null) {
997 throw new IllegalStateException(
1001 Class<?>[] providers = ollamaToolServiceAnnotation.
providers();
1002 for (Class<?> provider : providers) {
1003 registerAnnotatedTools(provider.getDeclaredConstructor().newInstance());
1005 }
catch (InstantiationException
1006 | NoSuchMethodException
1007 | IllegalAccessException
1008 | InvocationTargetException e) {
1023 Class<?> objectClass =
object.getClass();
1024 Method[] methods = objectClass.getMethods();
1025 for (Method m : methods) {
1027 if (toolSpec ==
null) {
1030 String operationName = !toolSpec.
name().isBlank() ? toolSpec.
name() : m.getName();
1031 String operationDesc = !toolSpec.
desc().isBlank() ? toolSpec.
desc() : operationName;
1034 LinkedHashMap<String, String> methodParams =
new LinkedHashMap<>();
1035 for (Parameter parameter : m.getParameters()) {
1038 String propType = parameter.getType().getTypeName();
1039 if (toolPropertyAnn ==
null) {
1040 methodParams.put(parameter.getName(),
null);
1044 !toolPropertyAnn.
name().isBlank()
1045 ? toolPropertyAnn.
name()
1046 : parameter.getName();
1047 methodParams.put(propName, propType);
1052 .description(toolPropertyAnn.
desc())
1053 .required(toolPropertyAnn.
required())
1056 Tools.ToolSpec toolSpecification =
1058 .name(operationName)
1059 .description(operationDesc)
1066 .toolFunction(reflectionalToolFunction)
1067 .toolSpec(toolSpecification)
/**
 * Reads the entire file into memory and returns its content encoded with the basic Base64
 * encoder ({@code Base64.getEncoder()}).
 *
 * @param file the file whose bytes are encoded
 * @return the Base64 text of the file's bytes
 * @throws IOException if reading the file fails
 */
1111 private static String encodeFileToBase64(File file)
throws IOException {
1112 return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath()));
/**
 * Encodes the given bytes with the basic Base64 encoder ({@code Base64.getEncoder()}).
 *
 * @param bytes the raw bytes to encode
 * @return the Base64 text of {@code bytes}
 */
1121 private static String encodeByteArrayToBase64(
byte[] bytes) {
1122 return Base64.getEncoder().encodeToString(bytes);
/**
 * Sends a generate request through {@code requestCaller} and returns the result.
 *
 * <p>Visible behaviour: when a response stream handler is supplied, streaming is switched on for
 * the request and the call is made with both the thinking and response handlers; a
 * {@code callSync} path is also present. The HTTP status code starts at -1 and is replaced by
 * the value reported by the result.
 *
 * <p>NOTE(review): the parameter list, the construction of {@code requestCaller}, the branch
 * structure between the streaming call and {@code callSync}, and the bodies of the catch blocks
 * are not visible in this excerpt. Callers pass (request, thinking handler, response handler) in
 * that order - confirm against the full signature.
 */
1136 private OllamaResult generateSyncForOllamaRequestModel(
1141 long startTime = System.currentTimeMillis();
// -1 marks "no HTTP response received yet".
1142 int statusCode = -1;
// A non-null response handler is what turns streaming on for this request.
1148 if (responseStreamHandler !=
null) {
1149 ollamaRequestModel.setStream(
true);
1152 ollamaRequestModel, thinkingStreamHandler, responseStreamHandler);
1154 result = requestCaller.
callSync(ollamaRequestModel);
1156 statusCode = result.getHttpStatusCode();
1159 }
catch (InterruptedException e) {
// Restore the interrupt flag so the interruption is not lost.
1160 Thread.currentThread().interrupt();
1162 }
catch (Exception e) {
// The following values are arguments of a call whose head is not visible in this excerpt.
1167 ollamaRequestModel.getModel(),
1168 ollamaRequestModel.isRaw(),
1169 ollamaRequestModel.isThink(),
1170 ollamaRequestModel.isStream(),
1171 ollamaRequestModel.getOptions(),
1172 ollamaRequestModel.getFormat(),
/**
 * Creates the request builder shared by the endpoint methods in this class.
 *
 * <p>The builder targets {@code uri} and carries a timeout of {@code requestTimeoutSeconds}
 * seconds. NOTE(review): lines between {@code newBuilder} and {@code .timeout}, and between the
 * builder statement and the return, are not visible in this excerpt; presumably they set default
 * headers and, when configured, authentication - confirm.
 *
 * @param uri the full target URI of the request
 * @return the pre-configured builder; callers add the HTTP method/body and build the request
 */
1185 private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
1186 HttpRequest.Builder requestBuilder =
1187 HttpRequest.newBuilder(uri)
1191 .timeout(Duration.ofSeconds(requestTimeoutSeconds));
1195 return requestBuilder;
/**
 * Reports whether an authentication object has been configured on this client.
 *
 * @return {@code true} when the {@code auth} field is non-null
 */
1203 private boolean isAuthSet() {
1204 return auth !=
null;