// Class-wide SLF4J logger.
70 private static final Logger LOG = LoggerFactory.getLogger(
Ollama.class);
// Base URL of the Ollama server; endpoint paths such as "/api/tags" are appended to it when building request URIs.
72 private final String host;
// Per-request timeout in seconds (default 10); applied via Duration.ofSeconds(...) when building HTTP requests
// and also used as the request timeout of the MCP sync client.
83 @Setter
private long requestTimeoutSeconds = 10;
// Read timeout in seconds (default 10) — presumably for fetching images by URL; TODO confirm against its usage.
86 @Setter
private int imageURLReadTimeoutSeconds = 10;
// Connect timeout in seconds (default 10) — presumably for fetching images by URL; TODO confirm against its usage.
89 @Setter
private int imageURLConnectTimeoutSeconds = 10;
// Upper bound (default 3) on the iterations of the chat tool-call loop.
97 @Setter
private int maxChatToolCallRetries = 3;
// Retry budget for model pulls: 0 means a single pull with no retry loop.
108 @SuppressWarnings({
"FieldMayBeFinal",
"FieldCanBeLocal"})
109 private int numberOfRetriesForModelPull = 0;
// Metrics switch, off by default — NOTE(review): where it is consulted is not visible here; confirm.
117 @Setter
private boolean metricsEnabled =
false;
121 this.host =
"http://localhost:11434";
130 if (host.endsWith(
"/")) {
131 this.host = host.substring(0, host.length() - 1);
135 LOG.info(
"Ollama4j client initialized. Connected to Ollama server at: {}", this.host);
145 this.auth =
new BasicAuth(username, password);
164 long startTime = System.currentTimeMillis();
165 String url =
"/api/tags";
169 HttpClient httpClient = HttpClient.newHttpClient();
170 HttpRequest httpRequest;
171 HttpResponse<String> response;
173 getRequestBuilderDefault(
new URI(this.host + url))
182 response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
183 statusCode = response.statusCode();
184 return statusCode == 200;
185 }
catch (InterruptedException ie) {
186 Thread.currentThread().interrupt();
188 }
catch (Exception e) {
212 long startTime = System.currentTimeMillis();
213 String url =
"/api/ps";
217 HttpClient httpClient = HttpClient.newHttpClient();
218 HttpRequest httpRequest =
null;
221 getRequestBuilderDefault(
new URI(this.host + url))
230 }
catch (URISyntaxException e) {
233 HttpResponse<String> response =
null;
234 response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
235 statusCode = response.statusCode();
236 String responseString = response.body();
237 if (statusCode == 200) {
243 }
catch (InterruptedException ie) {
244 Thread.currentThread().interrupt();
246 }
catch (Exception e) {
270 long startTime = System.currentTimeMillis();
271 String url =
"/api/tags";
275 HttpClient httpClient = HttpClient.newHttpClient();
276 HttpRequest httpRequest =
277 getRequestBuilderDefault(
new URI(this.host + url))
286 HttpResponse<String> response =
287 httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
288 statusCode = response.statusCode();
289 String responseString = response.body();
290 if (statusCode == 200) {
297 }
catch (InterruptedException ie) {
298 Thread.currentThread().interrupt();
300 }
catch (Exception e) {
/**
 * Handles the pause after a failed model-pull attempt.
 *
 * <p>The 1-based attempt number is {@code currentRetry + 1}. While that number is below
 * {@code maxRetries}, the thread sleeps for {@code baseDelayMillis * 2^currentRetry}
 * milliseconds before the caller tries again; otherwise the "no more retries" message applies.
 *
 * @param modelName name of the model being pulled
 * @param currentRetry number of retries already performed
 * @param maxRetries attempt limit the current attempt number is compared against
 * @param baseDelayMillis delay used for the first retry, doubled on each further retry
 * @throws InterruptedException declared for the case that the backoff sleep is interrupted
 */
326 private void handlePullRetry(
327 String modelName,
int currentRetry,
int maxRetries,
long baseDelayMillis)
328 throws InterruptedException {
329 int attempt = currentRetry + 1;
330 if (attempt < maxRetries) {
// Exponential backoff: the shift doubles the base delay for every retry already made.
331 long backoffMillis = baseDelayMillis * (1L << currentRetry);
333 "Failed to pull model {}, retrying in {}s... (attempt {}/{})",
// Logged in whole seconds.
335 backoffMillis / 1000,
339 Thread.sleep(backoffMillis);
340 }
catch (InterruptedException ie) {
// Restore the interrupt flag so code further up the stack can still observe the interruption.
341 Thread.currentThread().interrupt();
346 "Failed to pull model {} after {} attempts, no more retries.",
359 long startTime = System.currentTimeMillis();
360 String url =
"/api/pull";
365 HttpRequest request =
366 getRequestBuilderDefault(
new URI(this.host + url))
367 .POST(HttpRequest.BodyPublishers.ofString(jsonData))
375 HttpClient client = HttpClient.newHttpClient();
376 HttpResponse<InputStream> response =
377 client.send(request, HttpResponse.BodyHandlers.ofInputStream());
378 statusCode = response.statusCode();
379 InputStream responseBodyStream = response.body();
380 String responseString =
"";
381 boolean success =
false;
383 try (BufferedReader reader =
385 new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
387 while ((line = reader.readLine()) !=
null) {
390 success = processModelPullResponse(modelPullResponse, modelName) || success;
394 LOG.error(
"Model pull failed or returned invalid status.");
395 throw new OllamaException(
"Model pull failed or returned invalid status.");
397 if (statusCode != 200) {
400 }
catch (InterruptedException ie) {
401 Thread.currentThread().interrupt();
402 throw new OllamaException(
"Thread was interrupted during model pull.", ie);
403 }
catch (Exception e) {
/**
 * Inspects one response object received while pulling a model.
 *
 * <p>A {@code null} response is logged as an error. A non-blank error text in the response
 * aborts the pull with an {@link OllamaException}. A non-null status is logged at debug level
 * and compared case-insensitively with {@code "success"}.
 *
 * @param modelPullResponse the parsed response to inspect; may be {@code null}
 * @param modelName model name used as prefix of the debug log line
 * @return NOTE(review): presumably {@code true} only when the status equals "success" — the
 *     return statements are not visible here; confirm
 */
429 @SuppressWarnings(
"RedundantIfStatement")
430 private
boolean processModelPullResponse(
ModelPullResponse modelPullResponse, String modelName)
432 if (modelPullResponse ==
null) {
433 LOG.error(
"Received null response for model pull.");
436 String error = modelPullResponse.getError();
// Whitespace-only error strings are treated like "no error".
437 if (error !=
null && !error.trim().isEmpty()) {
438 throw new OllamaException(
"Model pull failed: " + error);
440 String status = modelPullResponse.getStatus();
441 if (status !=
null) {
442 LOG.debug(
"{}: {}", modelName, status);
443 if (
"success".equalsIgnoreCase(status)) {
457 String url =
"/api/version";
458 long startTime = System.currentTimeMillis();
462 HttpClient httpClient = HttpClient.newHttpClient();
463 HttpRequest httpRequest =
464 getRequestBuilderDefault(
new URI(this.host + url))
473 HttpResponse<String> response =
474 httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
475 statusCode = response.statusCode();
476 String responseString = response.body();
477 if (statusCode == 200) {
484 }
catch (InterruptedException ie) {
485 Thread.currentThread().interrupt();
487 }
catch (Exception e) {
514 if (numberOfRetriesForModelPull == 0) {
515 this.doPullModel(modelName);
518 int numberOfRetries = 0;
519 long baseDelayMillis = 3000L;
520 while (numberOfRetries < numberOfRetriesForModelPull) {
522 this.doPullModel(modelName);
528 numberOfRetriesForModelPull,
534 "Failed to pull model "
537 + numberOfRetriesForModelPull
539 }
catch (InterruptedException ie) {
540 Thread.currentThread().interrupt();
542 }
catch (Exception e) {
555 long startTime = System.currentTimeMillis();
556 String url =
"/api/show";
561 HttpRequest request =
562 getRequestBuilderDefault(
new URI(this.host + url))
569 .POST(HttpRequest.BodyPublishers.ofString(jsonData))
571 HttpClient client = HttpClient.newHttpClient();
572 HttpResponse<String> response =
573 client.send(request, HttpResponse.BodyHandlers.ofString());
574 statusCode = response.statusCode();
575 String responseBody = response.body();
576 if (statusCode == 200) {
581 }
catch (InterruptedException ie) {
582 Thread.currentThread().interrupt();
584 }
catch (Exception e) {
609 long startTime = System.currentTimeMillis();
610 String url =
"/api/create";
614 String jsonData = customModelRequest.toString();
615 HttpRequest request =
616 getRequestBuilderDefault(
new URI(this.host + url))
624 HttpRequest.BodyPublishers.ofString(
625 jsonData, StandardCharsets.UTF_8))
627 HttpClient client = HttpClient.newHttpClient();
628 HttpResponse<InputStream> response =
629 client.send(request, HttpResponse.BodyHandlers.ofInputStream());
630 statusCode = response.statusCode();
631 if (statusCode != 200) {
633 new String(response.body().readAllBytes(), StandardCharsets.UTF_8);
637 try (BufferedReader reader =
639 new InputStreamReader(response.body(), StandardCharsets.UTF_8))) {
641 StringBuilder lines =
new StringBuilder();
642 while ((line = reader.readLine()) !=
null) {
646 LOG.debug(res.getStatus());
647 if (res.getError() !=
null) {
648 out = res.getError();
654 }
catch (InterruptedException e) {
655 Thread.currentThread().interrupt();
657 }
catch (Exception e) {
683 long startTime = System.currentTimeMillis();
684 String url =
"/api/delete";
689 HttpRequest request =
690 getRequestBuilderDefault(
new URI(this.host + url))
693 HttpRequest.BodyPublishers.ofString(
694 jsonData, StandardCharsets.UTF_8))
702 HttpClient client = HttpClient.newHttpClient();
703 HttpResponse<String> response =
704 client.send(request, HttpResponse.BodyHandlers.ofString());
705 statusCode = response.statusCode();
706 String responseBody = response.body();
708 if (statusCode == 404
709 && responseBody.contains(
"model")
710 && responseBody.contains(
"not found")) {
713 if (statusCode != 200) {
716 }
catch (InterruptedException e) {
717 Thread.currentThread().interrupt();
719 }
catch (Exception e) {
746 long startTime = System.currentTimeMillis();
747 String url =
"/api/generate";
751 ObjectMapper objectMapper =
new ObjectMapper();
752 Map<String, Object> jsonMap =
new java.util.HashMap<>();
753 jsonMap.put(
"model", modelName);
754 jsonMap.put(
"keep_alive", 0);
755 String jsonData = objectMapper.writeValueAsString(jsonMap);
756 HttpRequest request =
757 getRequestBuilderDefault(
new URI(this.host + url))
760 HttpRequest.BodyPublishers.ofString(
761 jsonData, StandardCharsets.UTF_8))
769 LOG.debug(
"Unloading model with request: {}", jsonData);
770 HttpClient client = HttpClient.newHttpClient();
771 HttpResponse<String> response =
772 client.send(request, HttpResponse.BodyHandlers.ofString());
773 statusCode = response.statusCode();
774 String responseBody = response.body();
775 if (statusCode == 404
776 && responseBody.contains(
"model")
777 && responseBody.contains(
"not found")) {
778 LOG.debug(
"Unload response: {} - {}", statusCode, responseBody);
781 if (statusCode != 200) {
782 LOG.debug(
"Unload response: {} - {}", statusCode, responseBody);
785 }
catch (InterruptedException e) {
786 Thread.currentThread().interrupt();
787 LOG.debug(
"Unload interrupted: {} - {}", statusCode, out);
789 }
catch (Exception e) {
790 LOG.debug(
"Unload failed: {} - {}", statusCode, out);
815 long startTime = System.currentTimeMillis();
816 String url =
"/api/embed";
821 HttpClient httpClient = HttpClient.newHttpClient();
822 HttpRequest request =
823 HttpRequest.newBuilder(
new URI(this.host + url))
827 .POST(HttpRequest.BodyPublishers.ofString(jsonData))
829 HttpResponse<String> response =
830 httpClient.send(request, HttpResponse.BodyHandlers.ofString());
831 statusCode = response.statusCode();
832 String responseBody = response.body();
833 if (statusCode == 200) {
838 }
catch (InterruptedException e) {
839 Thread.currentThread().interrupt();
841 }
catch (Exception e) {
871 if (request.isUseTools()) {
872 return generateWithToolsInternal(request, streamObserver);
875 if (streamObserver !=
null) {
877 return generateSyncForOllamaRequestModel(
879 streamObserver.getThinkingStreamHandler(),
880 streamObserver.getResponseStreamHandler());
882 return generateSyncForOllamaRequestModel(
883 request,
null, streamObserver.getResponseStreamHandler());
886 return generateSyncForOllamaRequestModel(request,
null,
null);
887 }
catch (Exception e) {
896 ArrayList<OllamaChatMessage> msgs =
new ArrayList<>();
898 chatRequest.setModel(request.getModel());
901 ocm.setResponse(request.getPrompt());
902 chatRequest.setMessages(msgs);
908 List<
Tools.
Tool> allTools =
new ArrayList<>();
909 if (request.getTools() !=
null) {
910 allTools.addAll(request.getTools());
912 List<
Tools.
Tool> registeredTools = this.getRegisteredTools();
913 if (registeredTools !=
null) {
914 allTools.addAll(registeredTools);
918 chatRequest.setUseTools(
true);
919 chatRequest.setTools(allTools);
920 if (streamObserver !=
null) {
921 chatRequest.setStream(
true);
922 if (streamObserver.getResponseStreamHandler() !=
null) {
926 .getResponseStreamHandler()
927 .accept(chatResponseModel.getMessage().getResponse());
932 res.getResponseModel().getMessage().getResponse(),
933 res.getResponseModel().getMessage().getThinking(),
934 res.getResponseModel().getTotalDuration(),
950 long startTime = System.currentTimeMillis();
951 String url =
"/api/generate";
955 ollamaRequestModel.setRaw(raw);
956 ollamaRequestModel.setThink(think);
959 getRequestBuilderDefault(
new URI(this.host + url)),
961 requestTimeoutSeconds);
962 ollamaAsyncResultStreamer.start();
963 statusCode = ollamaAsyncResultStreamer.getHttpStatusCode();
964 return ollamaAsyncResultStreamer;
965 }
catch (Exception e) {
969 url, model, raw, think,
true,
null,
null, startTime, statusCode,
null);
993 if (request.isUseTools()) {
998 if (tokenHandler !=
null) {
999 request.setStream(
true);
1000 result = requestCaller.
call(request, tokenHandler);
1002 result = requestCaller.
callSync(request);
1006 List<OllamaChatToolCalls> toolCalls =
1007 result.getResponseModel().getMessage().getToolCalls();
1009 int toolCallTries = 0;
1010 while (toolCalls !=
null
1011 && !toolCalls.isEmpty()
1012 && toolCallTries < maxChatToolCallRetries) {
1014 String toolName = toolCall.getFunction().getName();
1015 for (
Tools.
Tool t : request.getTools()) {
1016 if (t.getToolSpec().getName().equals(toolName)) {
1018 if (toolFunction ==
null) {
1020 "Tool function not found: " + toolName);
1023 "Invoking tool {} with arguments: {}",
1024 toolCall.getFunction().getName(),
1025 toolCall.getFunction().getArguments());
1026 Map<String, Object> arguments = toolCall.getFunction().getArguments();
1027 Object res = toolFunction.
apply(arguments);
1028 String argumentKeys =
1029 arguments.keySet().stream()
1030 .map(Object::toString)
1031 .collect(Collectors.joining(
", "));
1032 request.getMessages()
1042 +
" [/TOOL_RESULTS]"));
1046 if (tokenHandler !=
null) {
1047 result = requestCaller.
call(request, tokenHandler);
1049 result = requestCaller.
callSync(request);
1051 toolCalls = result.getResponseModel().getMessage().getToolCalls();
1055 }
catch (InterruptedException e) {
1056 Thread.currentThread().interrupt();
1058 }
catch (Exception e) {
1070 LOG.debug(
"Registered tool: {}", tool.getToolSpec().getName());
1092 toolRegistry.
clear();
1093 LOG.debug(
"All tools have been deregistered.");
1106 Class<?> callerClass =
null;
1109 Class.forName(Thread.currentThread().getStackTrace()[2].getClassName());
1110 }
catch (ClassNotFoundException e) {
1116 if (ollamaToolServiceAnnotation ==
null) {
1117 throw new IllegalStateException(
1121 Class<?>[] providers = ollamaToolServiceAnnotation.
providers();
1122 for (Class<?> provider : providers) {
1123 registerAnnotatedTools(provider.getDeclaredConstructor().newInstance());
1125 }
catch (InstantiationException
1126 | NoSuchMethodException
1127 | IllegalAccessException
1128 | InvocationTargetException e) {
1143 Class<?> objectClass =
object.getClass();
1144 Method[] methods = objectClass.getMethods();
1145 for (Method m : methods) {
1147 if (toolSpec ==
null) {
1150 String operationName = !toolSpec.
name().isBlank() ? toolSpec.
name() : m.getName();
1151 String operationDesc = !toolSpec.
desc().isBlank() ? toolSpec.
desc() : operationName;
1154 LinkedHashMap<String, String> methodParams =
new LinkedHashMap<>();
1155 for (Parameter parameter : m.getParameters()) {
1158 String propType = parameter.getType().getTypeName();
1159 if (toolPropertyAnn ==
null) {
1160 methodParams.put(parameter.getName(),
null);
1164 !toolPropertyAnn.
name().isBlank()
1165 ? toolPropertyAnn.
name()
1166 : parameter.getName();
1167 methodParams.put(propName, propType);
1172 .description(toolPropertyAnn.
desc())
1173 .required(toolPropertyAnn.
required())
1176 Tools.ToolSpec toolSpecification =
1178 .name(operationName)
1179 .description(operationDesc)
1186 .toolFunction(reflectionalToolFunction)
1187 .toolSpec(toolSpecification)
1223 String jsonContent =
1224 java.nio.file.Files.readString(java.nio.file.Paths.get(mcpConfigJsonFilePath));
1225 MCPToolsConfig config =
1226 McpJsonMapper.getDefault().readValue(jsonContent, MCPToolsConfig.class);
1228 if (config.mcpServers !=
null && !config.mcpServers.isEmpty()) {
1229 for (Map.Entry<String, MCPToolConfig> tool : config.mcpServers.entrySet()) {
1230 ServerParameters.Builder serverParamsBuilder =
1231 ServerParameters.builder(tool.getValue().command);
1232 if (tool.getValue().args !=
null && !tool.getValue().args.isEmpty()) {
1234 "Runnable MCP Tool command: \n\n\t{} {}\n\n",
1235 tool.getValue().command,
1236 String.join(
" ", tool.getValue().args));
1237 serverParamsBuilder.args(tool.getValue().args.toArray(
new String[0]));
1239 ServerParameters serverParameters = serverParamsBuilder.build();
1240 StdioClientTransport transport =
1241 new StdioClientTransport(serverParameters, McpJsonMapper.getDefault());
1243 int mcpToolRequestTimeoutSeconds = 30;
1245 McpSyncClient client =
1246 McpClient.sync(transport)
1248 Duration.ofSeconds(mcpToolRequestTimeoutSeconds))
1250 client.initialize();
1252 ListToolsResult result = client.listTools();
1253 for (io.modelcontextprotocol.spec.McpSchema.Tool mcpTool : result.tools()) {
1254 Tools.Tool mcpToolAsOllama4jTool =
1255 createOllamaToolFromMCPTool(
1256 tool.getKey(), mcpTool, serverParameters);
1257 toolRegistry.
addTool(mcpToolAsOllama4jTool);
/**
 * Invokes a registered MCP tool on its server.
 *
 * <p>Searches the registered tools for an MCP tool whose server name and tool name both match.
 * For the match, a fresh stdio transport and sync client are created from the tool's stored
 * server parameters, the client is initialized, and the tool is called with the given arguments.
 *
 * @param mcpServerName name of the MCP server the tool belongs to
 * @param toolName name of the tool to call
 * @param arguments arguments forwarded in the {@code CallToolRequest}
 * @return the result returned by {@code client.callTool(...)}
 * @throws IllegalArgumentException if no registered MCP tool matches the server and tool name
 */
1279 private CallToolResult callMCPTool(
1280 String mcpServerName, String toolName, Map<String, Object> arguments) {
1281 for (
Tools.
Tool tool : getRegisteredTools()) {
1282 if (tool.isMCPTool() && tool.getMcpServerName().equals(mcpServerName)) {
1283 if (tool.getToolSpec().getName().equals(toolName)) {
// A new transport is built for every call from the parameters stored on the tool.
1284 ServerParameters serverParameters = tool.getMcpServerParameters();
1285 StdioClientTransport stdioTransport =
1286 new StdioClientTransport(serverParameters, McpJsonMapper.getDefault());
1289 "Calling MCP Tool: '{}.{}' with arguments: {}",
// try-with-resources closes the client when the call returns or fails.
1293 try (McpSyncClient client =
1294 McpClient.sync(stdioTransport)
// The MCP request timeout reuses the client-wide requestTimeoutSeconds setting.
1295 .requestTimeout(Duration.ofSeconds(requestTimeoutSeconds))
1297 client.initialize();
1298 CallToolRequest request =
new CallToolRequest(toolName, arguments);
1299 return client.callTool(request);
1302 stdioTransport.close();
1307 throw new IllegalArgumentException(
1308 "No MCP tool found for server name: "
1310 +
" and tool name: "
1323 private static String encodeFileToBase64(File file)
throws IOException {
1324 return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath()));
1333 private static String encodeByteArrayToBase64(
byte[] bytes) {
1334 return Base64.getEncoder().encodeToString(bytes);
1350 private OllamaResult generateSyncForOllamaRequestModel(
1355 long startTime = System.currentTimeMillis();
1356 int statusCode = -1;
1362 if (responseStreamHandler !=
null) {
1363 ollamaRequestModel.setStream(
true);
1366 ollamaRequestModel, thinkingStreamHandler, responseStreamHandler);
1368 result = requestCaller.
callSync(ollamaRequestModel);
1370 statusCode = result.getHttpStatusCode();
1373 }
catch (InterruptedException e) {
1374 Thread.currentThread().interrupt();
1376 }
catch (Exception e) {
1381 ollamaRequestModel.getModel(),
1382 ollamaRequestModel.isRaw(),
1383 ollamaRequestModel.getThink(),
1384 ollamaRequestModel.isStream(),
1385 ollamaRequestModel.getOptions(),
1386 ollamaRequestModel.getFormat(),
/**
 * Creates the request builder shared by the HTTP calls of this client.
 *
 * <p>The builder targets the given URI and carries a timeout of
 * {@code requestTimeoutSeconds} seconds.
 *
 * @param uri the full endpoint URI of the request
 * @return the pre-configured builder, ready for the caller to add method and body
 */
1399 private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
1400 HttpRequest.Builder requestBuilder =
1401 HttpRequest.newBuilder(uri)
1405 .timeout(Duration.ofSeconds(requestTimeoutSeconds));
1409 return requestBuilder;
1417 private boolean isAuthSet() {
1418 return auth !=
null;
/**
 * Groups an MCP server name with the descriptors of its tools and a stdio transport.
 */
1424 public static class OllamaMCPTool {
// Name identifying the MCP server.
1425 private String mcpServerName;
// Name/description pairs of the tools belonging to that server.
1426 private List<MCPToolInfo> toolInfos;
// Stdio transport associated with the server.
1427 private StdioClientTransport transport;
/**
 * Name and description of a single MCP tool.
 */
1433 public static class MCPToolInfo {
// Tool name.
1434 private String toolName;
// Human-readable description of the tool.
1435 private String toolDescription;
/**
 * JSON-mapped launch configuration of one MCP server: the command to run and its arguments.
 */
1438 public static class MCPToolConfig {
// Command handed to ServerParameters.builder(...) when the MCP tools are loaded.
1439 @JsonProperty(
"command")
1440 public String command;
// Optional command-line arguments; only applied when non-null and non-empty.
1442 @JsonProperty("args")
1443 public List<String> args;
/**
 * Root object of the MCP configuration JSON file.
 */
1446 public static class MCPToolsConfig {
// Server name -> launch configuration, read from the "mcpServers" JSON property.
1447 @JsonProperty(
"mcpServers")
1448 public Map<String, MCPToolConfig> mcpServers;
/**
 * JSON-mapped triple naming an MCP server, one of its tools, and the arguments for a call.
 */
1454 public static class OllamaMCPToolMatchResponse {
// Name of the MCP server.
1455 @JsonProperty(
"mcpServerName")
1456 public String mcpServerName;
// Name of the tool on that server.
1458 @JsonProperty("toolName")
1459 public String toolName;
// Argument name -> value map for the tool call.
1461 @JsonProperty("arguments")
1462 public Map<String, Object> arguments;
1475 private
Tools.Tool createOllamaToolFromMCPTool(
1476 String mcpServerName,
1477 io.modelcontextprotocol.spec.McpSchema.Tool mcpTool,
1478 ServerParameters serverParameters) {
1479 Map<String,
Tools.
Property> properties =
new java.util.HashMap<>();
1480 java.util.List<String> requiredList =
new java.util.ArrayList<>();
1482 if (mcpTool.inputSchema() !=
null && mcpTool.inputSchema().properties() !=
null) {
1484 java.util.Set<String> requiredSet = new java.util.HashSet<>();
1485 if (mcpTool.inputSchema().required() != null) {
1486 requiredSet.addAll(mcpTool.inputSchema().required());
1488 for (Map.Entry<String, Object> entry : mcpTool.inputSchema().properties().entrySet()) {
1489 String propName = entry.getKey();
1490 Object propertyValue = entry.getValue();
1491 Map<String, Object> propertyMap = null;
1493 if (propertyValue instanceof Map) {
1494 propertyMap = (Map<String, Object>) propertyValue;
1502 propertyMap.get(
"type") !=
null ? propertyMap.get(
"type").toString() :
null;
1504 String description =
null;
1505 if (propertyMap.get(
"description") !=
null) {
1506 description = propertyMap.get(
"description").toString();
1507 }
else if (propertyMap.get(
"title") !=
null) {
1509 description = propertyMap.get(
"title").toString();
1513 boolean propRequired = requiredSet.contains(propName);
1518 .description(description)
1519 .required(propRequired)
1522 properties.put(propName, property);
1524 requiredList.add(propName);
1530 params.setProperties(properties);
1531 params.setRequired(requiredList);
1536 .name(mcpTool.name())
1537 .description(mcpTool.description())
1542 CallToolResult result =
1543 this.callMCPTool(mcpServerName, mcpTool.name(), arguments);
1544 return result.toString();
1547 .mcpServerName(mcpServerName)
1548 .mcpServerParameters(serverParameters)