1package io.github.ollama4j;
3import io.github.ollama4j.exceptions.OllamaBaseException;
4import io.github.ollama4j.exceptions.RoleNotFoundException;
5import io.github.ollama4j.exceptions.ToolInvocationException;
6import io.github.ollama4j.exceptions.ToolNotFoundException;
7import io.github.ollama4j.models.chat.*;
8import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
9import io.github.ollama4j.models.embeddings.OllamaEmbeddingResponseModel;
10import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
11import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
12import io.github.ollama4j.models.generate.OllamaGenerateRequest;
13import io.github.ollama4j.models.generate.OllamaStreamHandler;
14import io.github.ollama4j.models.generate.OllamaTokenHandler;
15import io.github.ollama4j.models.ps.ModelsProcessResponse;
16import io.github.ollama4j.models.request.*;
17import io.github.ollama4j.models.response.*;
18import io.github.ollama4j.tools.*;
19import io.github.ollama4j.tools.annotations.OllamaToolService;
20import io.github.ollama4j.tools.annotations.ToolProperty;
21import io.github.ollama4j.tools.annotations.ToolSpec;
22import io.github.ollama4j.utils.Options;
23import io.github.ollama4j.utils.Utils;
27import java.lang.reflect.InvocationTargetException;
28import java.lang.reflect.Method;
29import java.lang.reflect.Parameter;
31import java.net.URISyntaxException;
32import java.net.http.HttpClient;
33import java.net.http.HttpConnectTimeoutException;
34import java.net.http.HttpRequest;
35import java.net.http.HttpResponse;
36import java.nio.charset.StandardCharsets;
37import java.nio.file.Files;
38import java.time.Duration;
40import java.util.stream.Collectors;
42import org.slf4j.Logger;
43import org.slf4j.LoggerFactory;
44import org.jsoup.Jsoup;
45import org.jsoup.nodes.Document;
46import org.jsoup.nodes.Element;
47import org.jsoup.select.Elements;
52@SuppressWarnings({
"DuplicatedCode",
"resource"})
55 private static final Logger logger = LoggerFactory.getLogger(
OllamaAPI.class);
56 private final String host;
62 private long requestTimeoutSeconds = 10;
68 private boolean verbose =
true;
71 private int maxChatToolCallRetries = 3;
81 this.host =
"http://localhost:11434";
90 if (host.endsWith(
"/")) {
91 this.host = host.substring(0, host.length() - 1);
104 this.basicAuth =
new BasicAuth(username, password);
113 String url = this.host +
"/api/tags";
114 HttpClient httpClient = HttpClient.newHttpClient();
115 HttpRequest httpRequest =
null;
117 httpRequest = getRequestBuilderDefault(
new URI(url)).header(
"Accept",
"application/json").header(
"Content-type",
"application/json").GET().build();
118 }
catch (URISyntaxException e) {
119 throw new RuntimeException(e);
121 HttpResponse<String>
response =
null;
123 response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
124 }
catch (HttpConnectTimeoutException e) {
126 }
catch (IOException | InterruptedException e) {
127 throw new RuntimeException(e);
129 int statusCode =
response.statusCode();
130 return statusCode == 200;
142 String url = this.host +
"/api/ps";
143 HttpClient httpClient = HttpClient.newHttpClient();
144 HttpRequest httpRequest =
null;
146 httpRequest = getRequestBuilderDefault(
new URI(url)).header(
"Accept",
"application/json").header(
"Content-type",
"application/json").GET().build();
147 }
catch (URISyntaxException e) {
148 throw new RuntimeException(e);
150 HttpResponse<String>
response =
null;
151 response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
152 int statusCode =
response.statusCode();
153 String responseString =
response.body();
154 if (statusCode == 200) {
171 String url = this.host +
"/api/tags";
172 HttpClient httpClient = HttpClient.newHttpClient();
173 HttpRequest httpRequest = getRequestBuilderDefault(
new URI(url)).header(
"Accept",
"application/json").header(
"Content-type",
"application/json").GET().build();
174 HttpResponse<String>
response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
175 int statusCode =
response.statusCode();
176 String responseString =
response.body();
177 if (statusCode == 200) {
195 String url =
"https://ollama.com/library";
196 HttpClient httpClient = HttpClient.newHttpClient();
197 HttpRequest httpRequest = getRequestBuilderDefault(
new URI(url)).header(
"Accept",
"application/json").header(
"Content-type",
"application/json").GET().build();
198 HttpResponse<String>
response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
199 int statusCode =
response.statusCode();
200 String responseString =
response.body();
201 List<LibraryModel> models =
new ArrayList<>();
202 if (statusCode == 200) {
203 Document doc = Jsoup.parse(responseString);
204 Elements modelSections = doc.selectXpath(
"//*[@id='repo']/ul/li/a");
205 for (Element e : modelSections) {
207 Elements names = e.select(
"div > h2 > div > span");
208 Elements desc = e.select(
"div > p");
209 Elements pullCounts = e.select(
"div:nth-of-type(2) > p > span:first-of-type > span:first-of-type");
210 Elements popularTags = e.select(
"div > div > span");
211 Elements totalTags = e.select(
"div:nth-of-type(2) > p > span:nth-of-type(2) > span:first-of-type");
212 Elements lastUpdatedTime = e.select(
"div:nth-of-type(2) > p > span:nth-of-type(3) > span:nth-of-type(2)");
214 if (names.first() ==
null || names.isEmpty()) {
218 Optional.ofNullable(names.first()).map(Element::text).ifPresent(model::setName);
219 model.setDescription(Optional.ofNullable(desc.first()).map(Element::text).orElse(
""));
220 model.setPopularTags(Optional.of(popularTags).map(tags -> tags.stream().map(Element::text).collect(Collectors.toList())).orElse(
new ArrayList<>()));
221 model.setPullCount(Optional.ofNullable(pullCounts.first()).map(Element::text).orElse(
""));
222 model.setTotalTags(Optional.ofNullable(totalTags.first()).map(Element::text).map(Integer::parseInt).orElse(0));
223 model.setLastUpdated(Optional.ofNullable(lastUpdatedTime.first()).map(Element::text).orElse(
""));
// Fragment of what is presumably getLibraryModelDetails(LibraryModel): downloads the
// "https://ollama.com/library/<name>/tags" page for the given model and scrapes one
// LibraryModelTag per tag row with jsoup. The method header and several interior lines
// (object creation, skip branch, closing braces) are absent from this listing.
248 String url = String.format(
"https://ollama.com/library/%s/tags", libraryModel.getName());
249 HttpClient httpClient = HttpClient.newHttpClient();
250 HttpRequest httpRequest = getRequestBuilderDefault(
new URI(url)).header(
"Accept",
"application/json").header(
"Content-type",
"application/json").GET().build();
251 HttpResponse<String>
response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
252 int statusCode =
response.statusCode();
253 String responseString =
response.body();
// Collected tags; stays empty when the page does not answer with HTTP 200.
255 List<LibraryModelTag> libraryModelTags =
new ArrayList<>();
256 if (statusCode == 200) {
257 Document doc = Jsoup.parse(responseString);
// Tag rows are located through a fixed CSS path into the page markup.
258 Elements tagSections = doc.select(
"html > body > main > div > section > div > div > div:nth-child(n+2) > div");
259 for (Element e : tagSections) {
260 Elements tags = e.select(
"div > a > div");
261 Elements tagsMetas = e.select(
"div > span");
// Rows without a tag element are filtered here (the branch body is not visible).
// NOTE(review): the two conditions look redundant with each other - confirm.
265 if (tags.first() ==
null || tags.isEmpty()) {
269 libraryModelTag.setName(libraryModel.getName());
270 Optional.ofNullable(tags.first()).map(Element::text).ifPresent(libraryModelTag::setTag);
// The first meta span is split on the bullet character; index 1 is stored as the size.
271 libraryModelTag.setSize(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split(
"•")).filter(parts -> parts.length > 1).map(parts -> parts[1].trim()).orElse(
""));
// NOTE(review): the guard only guarantees parts.length > 1 but index 2 is read below;
// a text with exactly two bullet-separated parts would throw
// ArrayIndexOutOfBoundsException. The guard probably needs to be parts.length > 2.
272 libraryModelTag.setLastUpdated(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split(
"•")).filter(parts -> parts.length > 1).map(parts -> parts[2].trim()).orElse(
""));
273 libraryModelTags.add(libraryModelTag);
// The result bundles the input model with whatever tags were scraped.
276 libraryModelDetail.setModel(libraryModel);
277 libraryModelDetail.setTags(libraryModelTags);
278 return libraryModelDetail;
302 LibraryModel libraryModel = libraryModels.stream().filter(model -> model.getName().equals(modelName)).findFirst().orElseThrow(() ->
new NoSuchElementException(String.format(
"Model by name '%s' not found", modelName)));
304 LibraryModelTag libraryModelTag = libraryModelDetail.getTags().stream().filter(tagName -> tagName.getTag().equals(tag)).findFirst().orElseThrow(() ->
new NoSuchElementException(String.format(
"Tag '%s' for model '%s' not found", tag, modelName)));
305 return libraryModelTag;
319 String url = this.host +
"/api/pull";
321 HttpRequest
request = getRequestBuilderDefault(
new URI(url)).POST(HttpRequest.BodyPublishers.ofString(jsonData)).header(
"Accept",
"application/json").header(
"Content-type",
"application/json").build();
322 HttpClient client = HttpClient.newHttpClient();
323 HttpResponse<InputStream>
response = client.send(
request, HttpResponse.BodyHandlers.ofInputStream());
324 int statusCode =
response.statusCode();
325 InputStream responseBodyStream =
response.body();
326 String responseString =
"";
327 try (BufferedReader reader =
new BufferedReader(
new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
329 while ((line = reader.readLine()) !=
null) {
332 logger.info(modelPullResponse.getStatus());
336 if (statusCode != 200) {
354 String tagToPull = String.format(
"%s:%s", libraryModelTag.getName(), libraryModelTag.getTag());
369 String url = this.host +
"/api/show";
371 HttpRequest
request = getRequestBuilderDefault(
new URI(url)).header(
"Accept",
"application/json").header(
"Content-type",
"application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
372 HttpClient client = HttpClient.newHttpClient();
373 HttpResponse<String>
response = client.send(
request, HttpResponse.BodyHandlers.ofString());
374 int statusCode =
response.statusCode();
375 String responseBody =
response.body();
376 if (statusCode == 200) {
395 String url = this.host +
"/api/create";
397 HttpRequest
request = getRequestBuilderDefault(
new URI(url)).header(
"Accept",
"application/json").header(
"Content-Type",
"application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
398 HttpClient client = HttpClient.newHttpClient();
399 HttpResponse<String>
response = client.send(
request, HttpResponse.BodyHandlers.ofString());
400 int statusCode =
response.statusCode();
401 String responseString =
response.body();
402 if (statusCode != 200) {
407 if (responseString.contains(
"error")) {
411 logger.info(responseString);
427 String url = this.host +
"/api/create";
429 HttpRequest
request = getRequestBuilderDefault(
new URI(url)).header(
"Accept",
"application/json").header(
"Content-Type",
"application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
430 HttpClient client = HttpClient.newHttpClient();
431 HttpResponse<String>
response = client.send(
request, HttpResponse.BodyHandlers.ofString());
432 int statusCode =
response.statusCode();
433 String responseString =
response.body();
434 if (statusCode != 200) {
437 if (responseString.contains(
"error")) {
441 logger.info(responseString);
456 String url = this.host +
"/api/delete";
458 HttpRequest
request = getRequestBuilderDefault(
new URI(url)).method(
"DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).header(
"Accept",
"application/json").header(
"Content-type",
"application/json").build();
459 HttpClient client = HttpClient.newHttpClient();
460 HttpResponse<String>
response = client.send(
request, HttpResponse.BodyHandlers.ofString());
461 int statusCode =
response.statusCode();
462 String responseBody =
response.body();
463 if (statusCode == 404 && responseBody.contains(
"model") && responseBody.contains(
"not found")) {
466 if (statusCode != 200) {
499 URI uri = URI.create(this.host +
"/api/embeddings");
500 String jsonData = modelRequest.toString();
501 HttpClient httpClient = HttpClient.newHttpClient();
502 HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).header(
"Accept",
"application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData));
503 HttpRequest
request = requestBuilder.build();
504 HttpResponse<String>
response = httpClient.send(
request, HttpResponse.BodyHandlers.ofString());
505 int statusCode =
response.statusCode();
506 String responseBody =
response.body();
507 if (statusCode == 200) {
509 return embeddingResponse.getEmbedding();
// Fragment of what is presumably embed(OllamaEmbedRequestModel): POSTs jsonData to
// "<host>/api/embed" and branches on HTTP 200. The method header, the line defining
// jsonData and the branch bodies are absent from this listing.
539 URI uri = URI.create(this.host +
"/api/embed");
541 HttpClient httpClient = HttpClient.newHttpClient();
// NOTE(review): this request is built with HttpRequest.newBuilder(uri) directly, whereas
// the other request fragments in this file go through getRequestBuilderDefault(...).
// The Content-Type header, request timeout and optional Authorization header that the
// helper adds are therefore not applied here - confirm whether that is intentional.
543 HttpRequest
request = HttpRequest.newBuilder(uri).header(
"Accept",
"application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
545 HttpResponse<String>
response = httpClient.send(
request, HttpResponse.BodyHandlers.ofString());
546 int statusCode =
response.statusCode();
547 String responseBody =
response.body();
549 if (statusCode == 200) {
573 ollamaRequestModel.setRaw(raw);
574 ollamaRequestModel.setOptions(options.getOptionsMap());
575 return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
593 return generate(model, prompt, raw, options,
null);
611 Map<ToolFunctionCallSpec, Object> toolResults =
new HashMap<>();
613 if(!prompt.startsWith(
"[AVAILABLE_TOOLS]")){
616 promptBuilder.withToolSpecification(spec);
618 promptBuilder.withPrompt(prompt);
619 prompt = promptBuilder.build();
623 toolResult.setModelResult(result);
625 String toolsResponse = result.getResponse();
626 if (toolsResponse.contains(
"[TOOL_CALLS]")) {
627 toolsResponse = toolsResponse.replace(
"[TOOL_CALLS]",
"");
632 toolResults.put(toolFunctionCallSpec, invokeTool(toolFunctionCallSpec));
634 toolResult.setToolResults(toolResults);
649 ollamaRequestModel.setRaw(raw);
650 URI uri = URI.create(this.host +
"/api/generate");
652 ollamaAsyncResultStreamer.start();
653 return ollamaAsyncResultStreamer;
673 List<String> images =
new ArrayList<>();
674 for (File imageFile : imageFiles) {
675 images.add(encodeFileToBase64(imageFile));
678 ollamaRequestModel.setOptions(options.getOptionsMap());
679 return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
713 List<String> images =
new ArrayList<>();
714 for (String imageURL : imageURLs) {
718 ollamaRequestModel.setOptions(options.getOptionsMap());
719 return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
812 request.setTools(toolRegistry.getRegisteredSpecs().stream().map(
Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
814 if (tokenHandler !=
null) {
816 result = requestCaller.
call(
request, tokenHandler);
822 List<OllamaChatToolCalls> toolCalls = result.getResponseModel().getMessage().getToolCalls();
823 int toolCallTries = 0;
824 while(toolCalls !=
null && !toolCalls.isEmpty() && toolCallTries < maxChatToolCallRetries){
826 String toolName = toolCall.getFunction().getName();
827 ToolFunction toolFunction = toolRegistry.getToolFunction(toolName);
828 Map<String, Object> arguments = toolCall.getFunction().getArguments();
829 Object res = toolFunction.
apply(arguments);
833 if (tokenHandler !=
null) {
834 result = requestCaller.
call(
request, tokenHandler);
838 toolCalls = result.getResponseModel().getMessage().getToolCalls();
846 toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
852 Class<?> callerClass =
null;
854 callerClass = Class.forName(Thread.currentThread().getStackTrace()[2].getClassName());
855 }
catch (ClassNotFoundException e) {
856 throw new RuntimeException(e);
860 if (ollamaToolServiceAnnotation ==
null) {
861 throw new IllegalStateException(callerClass +
" is not annotated as " +
OllamaToolService.class);
864 Class<?>[] providers = ollamaToolServiceAnnotation.
providers();
865 for (Class<?> provider : providers) {
868 }
catch (InstantiationException | NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
869 throw new RuntimeException(e);
874 Class<?> objectClass =
object.getClass();
875 Method[] methods = objectClass.getMethods();
876 for(Method m : methods) {
878 if(toolSpec ==
null){
881 String operationName = !toolSpec.
name().isBlank() ? toolSpec.
name() : m.getName();
882 String operationDesc = !toolSpec.
desc().isBlank() ? toolSpec.
desc() : operationName;
885 LinkedHashMap<String,String> methodParams =
new LinkedHashMap<>();
886 for (Parameter parameter : m.getParameters()) {
888 String propType = parameter.getType().getTypeName();
889 if(toolPropertyAnn ==
null) {
890 methodParams.put(parameter.getName(),
null);
893 String propName = !toolPropertyAnn.
name().isBlank() ? toolPropertyAnn.
name() : parameter.getName();
894 methodParams.put(propName,propType);
897 .description(toolPropertyAnn.
desc())
898 .required(toolPropertyAnn.
required())
902 List<String> reqProps = params.entrySet().stream()
903 .filter(e -> e.getValue().isRequired())
904 .map(Map.Entry::getKey)
905 .collect(Collectors.toList());
908 .functionName(operationName)
909 .functionDescription(operationDesc)
914 .description(operationDesc)
930 toolSpecification.setToolFunction(reflectionalToolFunction);
931 toolRegistry.addTool(toolSpecification.getFunctionName(),toolSpecification);
969 private static String encodeFileToBase64(File file)
throws IOException {
970 return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath()));
973 private static String encodeByteArrayToBase64(
byte[] bytes) {
974 return Base64.getEncoder().encodeToString(bytes);
980 if (streamHandler !=
null) {
981 ollamaRequestModel.setStream(
true);
982 result = requestCaller.
call(ollamaRequestModel, streamHandler);
984 result = requestCaller.
callSync(ollamaRequestModel);
995 private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
996 HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).header(
"Content-Type",
"application/json").timeout(Duration.ofSeconds(requestTimeoutSeconds));
997 if (isBasicAuthCredentialsSet()) {
998 requestBuilder.header(
"Authorization", getBasicAuthHeaderValue());
1000 return requestBuilder;
1008 private String getBasicAuthHeaderValue() {
1009 String credentialsToEncode = basicAuth.getUsername() +
":" + basicAuth.getPassword();
1010 return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
1018 private boolean isBasicAuthCredentialsSet() {
1019 return basicAuth !=
null;
1024 String methodName = toolFunctionCallSpec.getName();
1025 Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
1028 logger.debug(
"Invoking function {} with arguments {}", methodName, arguments);
1030 if (
function ==
null) {
1033 return function.
apply(arguments);
1034 }
catch (Exception e) {
OllamaResult generateWithImageFiles(String model, String prompt, List< File > imageFiles, Options options, OllamaStreamHandler streamHandler)
OllamaChatResult chat(OllamaChatRequest request)
ModelDetail getModelDetails(String modelName)
List< Double > generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest)
List< Model > listModels()
OllamaAsyncResultStreamer generateAsync(String model, String prompt, boolean raw)
OllamaChatResult chat(String model, List< OllamaChatMessage > messages)
List< Double > generateEmbeddings(String model, String prompt)
void pullModel(String modelName)
void setBasicAuth(String username, String password)
OllamaResult generate(String model, String prompt, boolean raw, Options options, OllamaStreamHandler streamHandler)
OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler)
void deleteModel(String modelName, boolean ignoreIfNotPresent)
void createModelWithFilePath(String modelName, String modelFilePath)
OllamaResult generate(String model, String prompt, boolean raw, Options options)
ModelsProcessResponse ps()
LibraryModelTag findModelTagFromLibrary(String modelName, String tag)
OllamaChatMessageRole getRole(String roleName)
OllamaChatMessageRole addCustomRole(String roleName)
void pullModel(LibraryModelTag libraryModelTag)
List< OllamaChatMessageRole > listRoles()
LibraryModelDetail getLibraryModelDetails(LibraryModel libraryModel)
OllamaResult generateWithImageURLs(String model, String prompt, List< String > imageURLs, Options options, OllamaStreamHandler streamHandler)
OllamaEmbedResponseModel embed(OllamaEmbedRequestModel modelRequest)
void registerTool(Tools.ToolSpecification toolSpecification)
void registerAnnotatedTools(Object object)
void registerAnnotatedTools()
List< LibraryModel > listModelsFromLibrary()
OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler)
OllamaToolsResult generateWithTools(String model, String prompt, Options options)
OllamaResult generateWithImageURLs(String model, String prompt, List< String > imageURLs, Options options)
OllamaEmbedResponseModel embed(String model, List< String > inputs)
void createModelWithModelFileContents(String modelName, String modelFileContents)
OllamaResult generateWithImageFiles(String model, String prompt, List< File > imageFiles, Options options)
static List< OllamaChatMessageRole > getRoles()
static OllamaChatMessageRole newCustomRole(String roleName)
static OllamaChatMessageRole getRole(String roleName)
static final OllamaChatMessageRole TOOL
OllamaChatRequestBuilder withMessages(List< OllamaChatMessage > messages)
OllamaChatRequest build()
static OllamaChatRequestBuilder getInstance(String model)
OllamaChatResult call(OllamaChatRequest body, OllamaTokenHandler tokenHandler)
OllamaChatResult callSync(OllamaChatRequest body)
OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
OllamaResult callSync(OllamaRequestBody body)
static byte[] loadImageBytesFromUrl(String imageUrl)
static ObjectMapper getObjectMapper()