GeminiResource.java
@@ -16,7 +16,7 @@
 public class GeminiResource {
 
     @POST
-    @Path("v1beta/models/gemini-1.5-flash:generateContent")
+    @Path("v1/models/gemini-1.5-flash:generateContent")
     @Produces("application/json")
     @Consumes("application/json")
     public String generateResponse(String generateRequest, @RestQuery String key) {
(chat model smoke test using a Bearer-token auth provider)
@@ -43,7 +43,7 @@ void test() {
 
         wiremock().register(
                 post(urlEqualTo(
-                        String.format("/v1beta/models/%s:generateContent", CHAT_MODEL_ID)))
+                        String.format("/v1/models/%s:generateContent", CHAT_MODEL_ID)))
                         .withHeader("Authorization", equalTo("Bearer " + API_KEY))
                         .willReturn(aResponse()
                                 .withHeader("Content-Type", "application/json")
(chat model smoke test, API key sent as a query parameter)
@@ -41,7 +41,7 @@ void test() {
 
         wiremock().register(
                 post(urlEqualTo(
-                        String.format("/v1beta/models/%s:generateContent?key=%s",
+                        String.format("/v1/models/%s:generateContent?key=%s",
                                 CHAT_MODEL_ID, API_KEY)))
                         .willReturn(aResponse()
                                 .withHeader("Content-Type", "application/json")
AiGeminiChatLanguageModelV1BetaSmokeTest.java (new file)
@@ -0,0 +1,112 @@
package io.quarkiverse.langchain4j.ai.gemini.deployment;

import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.post;
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
import static org.assertj.core.api.Assertions.assertThat;

import jakarta.inject.Inject;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import com.github.tomakehurst.wiremock.verification.LoggedRequest;

import dev.langchain4j.model.chat.ChatLanguageModel;
import io.quarkiverse.langchain4j.ai.runtime.gemini.AiGeminiChatLanguageModel;
import io.quarkiverse.langchain4j.testing.internal.WiremockAware;
import io.quarkus.arc.ClientProxy;
import io.quarkus.test.QuarkusUnitTest;

public class AiGeminiChatLanguageModelV1BetaSmokeTest extends WiremockAware {

    private static final String API_VERSION = "v1beta";
    private static final String API_KEY = "dummy";
    private static final String CHAT_MODEL_ID = "gemini-1.5-flash";

    @RegisterExtension
    static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
            .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class))
            .overrideRuntimeConfigKey("quarkus.langchain4j.ai.gemini.base-url", WiremockAware.wiremockUrlForConfig())
            .overrideRuntimeConfigKey("quarkus.langchain4j.ai.gemini.api-version", API_VERSION)
            .overrideRuntimeConfigKey("quarkus.langchain4j.ai.gemini.api-key", API_KEY)
            .overrideRuntimeConfigKey("quarkus.langchain4j.ai.gemini.log-requests", "true");

    @Inject
    ChatLanguageModel chatLanguageModel;

    @Test
    void test() {
        assertThat(ClientProxy.unwrap(chatLanguageModel)).isInstanceOf(AiGeminiChatLanguageModel.class);

        wiremock().register(
                post(urlEqualTo(
                        String.format("/%s/models/%s:generateContent?key=%s",
                                API_VERSION, CHAT_MODEL_ID, API_KEY)))
                        .willReturn(aResponse()
                                .withHeader("Content-Type", "application/json")
                                .withBody("""
{
"candidates": [
{
"content": {
"role": "model",
"parts": [
{
"text": "Nice to meet you"
}
]
},
"finishReason": "STOP",
"safetyRatings": [
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.044847902,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.05592617
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.18877223,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.027324531
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.15278918,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.045437217
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.15869519,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.036838707
}
]
}
],
"usageMetadata": {
"promptTokenCount": 11,
"candidatesTokenCount": 37,
"totalTokenCount": 48
}
}
""")));

        String response = chatLanguageModel.chat("hello");
        assertThat(response).isEqualTo("Nice to meet you");

        LoggedRequest loggedRequest = singleLoggedRequest();
        assertThat(loggedRequest.getHeader("User-Agent")).isEqualTo("Quarkus REST Client");
        String requestBody = new String(loggedRequest.getBody());
        assertThat(requestBody).contains("hello");
    }

}
AiGeminiEmbeddingModelAuthProviderSmokeTest.java
@@ -44,7 +44,7 @@ public class AiGeminiEmbeddingModelAuthProviderSmokeTest extends WiremockAware {
     void testBatch() {
         wiremock().register(
                 post(urlEqualTo(
-                        String.format("/v1beta/models/%s:batchEmbedContents", EMBED_MODEL_ID)))
+                        String.format("/v1/models/%s:batchEmbedContents", EMBED_MODEL_ID)))
                         .withHeader("Authorization", equalTo("Bearer " + API_KEY))
                         .willReturn(aResponse()
                                 .withHeader("Content-Type", "application/json")
@@ -89,7 +89,7 @@ void test() {
 
         wiremock().register(
                 post(urlEqualTo(
-                        String.format("/v1beta/models/%s:embedContent", EMBED_MODEL_ID)))
+                        String.format("/v1/models/%s:embedContent", EMBED_MODEL_ID)))
                         .withHeader("Authorization", equalTo("Bearer " + API_KEY))
                         .willReturn(aResponse()
                                 .withHeader("Content-Type", "application/json")
AiGeminiEmbeddingModelSmokeTest.java
@@ -42,7 +42,7 @@ public class AiGeminiEmbeddingModelSmokeTest extends WiremockAware {
     void testBatch() {
         wiremock().register(
                 post(urlEqualTo(
-                        String.format("/v1beta/models/%s:batchEmbedContents?key=%s",
+                        String.format("/v1/models/%s:batchEmbedContents?key=%s",
                                 EMBED_MODEL_ID, API_KEY)))
                         .willReturn(aResponse()
                                 .withHeader("Content-Type", "application/json")
@@ -87,7 +87,7 @@ void test() {
 
         wiremock().register(
                 post(urlEqualTo(
-                        String.format("/v1beta/models/%s:embedContent?key=%s",
+                        String.format("/v1/models/%s:embedContent?key=%s",
                                 EMBED_MODEL_ID, API_KEY)))
                         .willReturn(aResponse()
                                 .withHeader("Content-Type", "application/json")
AiGeminiEmbeddingModelV1BetaSmokeTest.java (new file)
@@ -0,0 +1,115 @@
package io.quarkiverse.langchain4j.ai.gemini.deployment;

import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.post;
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
import static org.assertj.core.api.Assertions.assertThat;

import java.util.List;

import jakarta.inject.Inject;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import io.quarkiverse.langchain4j.ai.runtime.gemini.AiGeminiEmbeddingModel;
import io.quarkiverse.langchain4j.testing.internal.WiremockAware;
import io.quarkus.arc.ClientProxy;
import io.quarkus.test.QuarkusUnitTest;

public class AiGeminiEmbeddingModelV1BetaSmokeTest extends WiremockAware {

    private static final String API_VERSION = "v1beta";
    private static final String API_KEY = "dummy";
    private static final String EMBED_MODEL_ID = "text-embedding-004";

    @RegisterExtension
    static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
            .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class))
            .overrideRuntimeConfigKey("quarkus.langchain4j.ai.gemini.base-url", WiremockAware.wiremockUrlForConfig())
            .overrideRuntimeConfigKey("quarkus.langchain4j.ai.gemini.api-version", API_VERSION)
            .overrideRuntimeConfigKey("quarkus.langchain4j.ai.gemini.api-key", API_KEY)
            .overrideRuntimeConfigKey("quarkus.langchain4j.ai.gemini.log-requests", "true");

    @Inject
    EmbeddingModel embeddingModel;

    @Test
    void testBatch() {
        wiremock().register(
                post(urlEqualTo(
                        String.format("/%s/models/%s:batchEmbedContents?key=%s",
                                API_VERSION, EMBED_MODEL_ID, API_KEY)))
                        .willReturn(aResponse()
                                .withHeader("Content-Type", "application/json")
                                .withBody("""
{
"embeddings": [
{
"values": [
-0.010632273,
0.019375853,
0.020965198,
0.0007706437,
-0.061464068,
-0.007153866,
-0.028534686
]
},
{
"values": [
0.018468002,
0.0054281265,
-0.017658807,
0.013859263,
0.05341865,
0.026714388,
0.0018762478
]
}
]
}
""")));

        List<TextSegment> textSegments = List.of(TextSegment.from("Hello"), TextSegment.from("Bye"));
        Response<List<Embedding>> response = embeddingModel.embedAll(textSegments);

        assertThat(response.content()).hasSize(2);
    }

    @Test
    void test() {
        assertThat(ClientProxy.unwrap(embeddingModel)).isInstanceOf(AiGeminiEmbeddingModel.class);

        wiremock().register(
                post(urlEqualTo(
                        String.format("/%s/models/%s:embedContent?key=%s",
                                API_VERSION, EMBED_MODEL_ID, API_KEY)))
                        .willReturn(aResponse()
                                .withHeader("Content-Type", "application/json")
                                .withBody("""
{
"embedding": {
"values": [
0.013168517,
-0.00871193,
-0.046782672,
0.00069969177,
-0.009518872,
-0.008720178,
0.06010358
]
}
}
""")));

        float[] response = embeddingModel.embed("Hello World").content().vector();
        assertThat(response).hasSize(7);
    }
}
AiGeminiChatLanguageModel.java
@@ -35,6 +35,10 @@ private AiGeminiChatLanguageModel(Builder builder) {
 
         try {
             String baseUrl = builder.baseUrl.orElse("https://generativelanguage.googleapis.com");
+            if (!baseUrl.endsWith("/")) {
+                baseUrl += "/";
+            }
+            baseUrl += builder.apiVersion;
             var restApiBuilder = QuarkusRestClientBuilder.newBuilder()
                     .baseUri(new URI(baseUrl))
                     .connectTimeout(builder.timeout.toSeconds(), TimeUnit.SECONDS)
@@ -70,6 +74,7 @@ public static final class Builder {
 
         private String configName;
         private Optional<String> baseUrl = Optional.empty();
+        private String apiVersion;
         private String modelId;
         private String key;
         private Double temperature;
@@ -92,6 +97,11 @@ public Builder baseUrl(Optional<String> baseUrl) {
             return this;
         }
 
+        public Builder apiVersion(String apiVersion) {
+            this.apiVersion = apiVersion;
+            return this;
+        }
+
        public Builder key(String key) {
            this.key = key;
            return this;
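With this change the version segment moves out of the JAX-RS interface's @Path (see AiGeminiRestApi below) and into the REST client's base URI. A minimal standalone sketch of the resulting composition, assuming the default base URL; the class name and printed values are illustrative only, not part of the extension:

import java.net.URI;

// Illustrative sketch: mirrors the base-URI logic added in the constructor above.
public class GeminiBaseUriSketch {

    // Same steps as the diff: ensure a trailing slash, then append the configured api version.
    static URI baseUri(String baseUrl, String apiVersion) {
        if (!baseUrl.endsWith("/")) {
            baseUrl += "/";
        }
        return URI.create(baseUrl + apiVersion);
    }

    public static void main(String[] args) {
        String defaultBase = "https://generativelanguage.googleapis.com";

        // Default api-version ("v1"): https://generativelanguage.googleapis.com/v1
        System.out.println(baseUri(defaultBase, "v1"));

        // With quarkus.langchain4j.ai.gemini.api-version=v1beta the previous paths stay reachable:
        // https://generativelanguage.googleapis.com/v1beta
        System.out.println(baseUri(defaultBase, "v1beta"));

        // The client then appends @Path("models/") plus "{modelId}:generateContent", e.g.
        // https://generativelanguage.googleapis.com/v1/models/gemini-1.5-flash:generateContent
    }
}

These are the same paths the updated smoke tests stub with WireMock.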
AiGeminiEmbeddingModel.java
@@ -32,6 +32,10 @@ public AiGeminiEmbeddingModel(Builder builder) {
 
         try {
             String baseUrl = builder.baseUrl.orElse("https://generativelanguage.googleapis.com");
+            if (!baseUrl.endsWith("/")) {
+                baseUrl += "/";
+            }
+            baseUrl += builder.apiVersion;
             var restApiBuilder = QuarkusRestClientBuilder.newBuilder()
                     .baseUri(new URI(baseUrl))
                     .connectTimeout(builder.timeout.toSeconds(), TimeUnit.SECONDS)
@@ -68,6 +72,7 @@ public static Builder builder() {
     public static final class Builder {
         private String configName;
         private Optional<String> baseUrl = Optional.empty();
+        private String apiVersion;
         private String modelId;
         private String key;
         private Integer dimension;
@@ -81,6 +86,11 @@ public Builder configName(String configName) {
             return this;
         }
 
+        public Builder apiVersion(String apiVersion) {
+            this.apiVersion = apiVersion;
+            return this;
+        }
+
         public Builder baseUrl(Optional<String> baseUrl) {
             this.baseUrl = baseUrl;
             return this;
(runtime recorder that builds the chat and embedding model beans)
@@ -41,6 +41,7 @@ public Function<SyntheticCreationalContext<EmbeddingModel>, EmbeddingModel> embe
         var builder = AiGeminiEmbeddingModel.builder()
                 .configName(configName)
                 .baseUrl(baseUrl)
+                .apiVersion(aiConfig.apiVersion())
                 .key(apiKey)
                 .modelId(embeddingModelConfig.modelId())
                 .logRequests(firstOrDefault(false, embeddingModelConfig.logRequests(), aiConfig.logRequests()))
@@ -84,6 +85,7 @@ public Function<SyntheticCreationalContext<ChatLanguageModel>, ChatLanguageModel
         String apiKey = aiConfig.apiKey().orElse(null);
         var builder = AiGeminiChatLanguageModel.builder()
                 .baseUrl(baseUrl)
+                .apiVersion(aiConfig.apiVersion())
                 .key(apiKey)
                 .modelId(chatModelConfig.modelId())
                 .maxOutputTokens(chatModelConfig.maxOutputTokens())
AiGeminiRestApi.java
@@ -27,7 +27,7 @@
 import io.vertx.core.http.HttpClientRequest;
 import io.vertx.core.http.HttpClientResponse;
 
-@Path("v1beta/models/")
+@Path("models/")
 public interface AiGeminiRestApi {
 
     @Path("{modelId}:batchEmbedContents")
AiGeminiConfig.java
@@ -53,6 +53,12 @@ interface AiGeminiConfig {
      */
     Optional<String> baseUrl();
 
+    /**
+     * The API version to use for this operation.
+     */
+    @WithDefault("v1")
+    String apiVersion();
+
     /**
      * Whether to enable the integration. Defaults to {@code true}, which means requests are made to the Vertex AI Gemini
      * provider.
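For users, the net result is a single new property, quarkus.langchain4j.ai.gemini.api-version, defaulting to v1. A hedged application.properties example (property names are taken from the config and tests above; the api-key value is a placeholder):

# Target the beta API surface instead of the default v1
quarkus.langchain4j.ai.gemini.api-version=v1beta

# Placeholder value; existing api-key configuration is unchanged by this PR
quarkus.langchain4j.ai.gemini.api-key=${GEMINI_API_KEY}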