Skip to content

Commit 8e10df2

Browse files
google-genai-botcopybara-github
authored andcommitted
fix!: Allow beforeModelCallback to modify the LLM request
PiperOrigin-RevId: 804942304
1 parent c150d52 commit 8e10df2

File tree

4 files changed

+37
-11
lines changed

4 files changed

+37
-11
lines changed

core/src/main/java/com/google/adk/agents/Callbacks.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ public interface BeforeModelCallback extends BeforeModelCallbackBase {
3636
* Async callback before LLM invocation.
3737
*
3838
* @param callbackContext Callback context.
39-
* @param llmRequest LLM request.
39+
* @param llmRequestBuilder LLM request builder.
4040
* @return response override, or empty to continue.
4141
*/
42-
Maybe<LlmResponse> call(CallbackContext callbackContext, LlmRequest llmRequest);
42+
Maybe<LlmResponse> call(CallbackContext callbackContext, LlmRequest.Builder llmRequestBuilder);
4343
}
4444

4545
/**
@@ -48,7 +48,8 @@ public interface BeforeModelCallback extends BeforeModelCallbackBase {
4848
*/
4949
@FunctionalInterface
5050
public interface BeforeModelCallbackSync extends BeforeModelCallbackBase {
51-
Optional<LlmResponse> call(CallbackContext callbackContext, LlmRequest llmRequest);
51+
Optional<LlmResponse> call(
52+
CallbackContext callbackContext, LlmRequest.Builder llmRequestBuilder);
5253
}
5354

5455
interface AfterModelCallbackBase {}

core/src/main/java/com/google/adk/agents/LlmAgent.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -360,9 +360,10 @@ public Builder beforeModelCallback(List<BeforeModelCallbackBase> beforeModelCall
360360
} else if (callback instanceof BeforeModelCallbackSync beforeModelCallbackSyncInstance) {
361361
builder.add(
362362
(BeforeModelCallback)
363-
(callbackContext, llmRequest) ->
363+
(callbackContext, llmRequestBuilder) ->
364364
Maybe.fromOptional(
365-
beforeModelCallbackSyncInstance.call(callbackContext, llmRequest)));
365+
beforeModelCallbackSyncInstance.call(
366+
callbackContext, llmRequestBuilder)));
366367
} else {
367368
logger.warn(
368369
"Invalid beforeModelCallback callback type: %s. Ignoring this callback.",
@@ -379,8 +380,9 @@ public Builder beforeModelCallback(List<BeforeModelCallbackBase> beforeModelCall
379380
public Builder beforeModelCallbackSync(BeforeModelCallbackSync beforeModelCallbackSync) {
380381
this.beforeModelCallback =
381382
ImmutableList.of(
382-
(callbackContext, llmRequest) ->
383-
Maybe.fromOptional(beforeModelCallbackSync.call(callbackContext, llmRequest)));
383+
(callbackContext, llmRequestBuilder) ->
384+
Maybe.fromOptional(
385+
beforeModelCallbackSync.call(callbackContext, llmRequestBuilder)));
384386
return this;
385387
}
386388

core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,9 @@ private Flowable<LlmResponse> callLlm(
209209
InvocationContext context, LlmRequest llmRequest, Event eventForCallbackUsage) {
210210
LlmAgent agent = (LlmAgent) context.agent();
211211

212-
return handleBeforeModelCallback(context, llmRequest, eventForCallbackUsage)
212+
LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder();
213+
214+
return handleBeforeModelCallback(context, llmRequestBuilder, eventForCallbackUsage)
213215
.flatMapPublisher(
214216
beforeResponse -> {
215217
if (beforeResponse.isPresent()) {
@@ -226,7 +228,7 @@ private Flowable<LlmResponse> callLlm(
226228

227229
try (Scope scope = llmCallSpan.makeCurrent()) {
228230
return llm.generateContent(
229-
llmRequest,
231+
llmRequestBuilder.build(),
230232
context.runConfig().streamingMode() == StreamingMode.SSE)
231233
.doOnNext(
232234
llmResp -> {
@@ -257,7 +259,7 @@ private Flowable<LlmResponse> callLlm(
257259
* @return A {@link Single} with the callback result or {@link Optional#empty()}.
258260
*/
259261
private Single<Optional<LlmResponse>> handleBeforeModelCallback(
260-
InvocationContext context, LlmRequest llmRequest, Event modelResponseEvent) {
262+
InvocationContext context, LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent) {
261263
LlmAgent agent = (LlmAgent) context.agent();
262264

263265
Optional<List<BeforeModelCallback>> callbacksOpt = agent.beforeModelCallback();
@@ -274,7 +276,7 @@ private Single<Optional<LlmResponse>> handleBeforeModelCallback(
274276
CallbackContext callbackContext =
275277
new CallbackContext(context, callbackEvent.actions());
276278
return callback
277-
.call(callbackContext, llmRequest)
279+
.call(callbackContext, llmRequestBuilder)
278280
.map(Optional::of)
279281
.defaultIfEmpty(Optional.empty());
280282
})

core/src/test/java/com/google/adk/agents/CallbacksTest.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,27 @@ public void testRun_withBeforeModelCallback_returnsResponseFromCallback() {
396396
assertThat(getOnlyElement(events).content()).hasValue(callbackContent);
397397
}
398398

399+
@Test
400+
public void testRun_withBeforeModelCallback_usesModifiedRequestFromCallback() {
401+
TestLlm testLlm = createTestLlm(createLlmResponse(Content.builder().build()));
402+
LlmAgent agent =
403+
createTestAgentBuilder(testLlm)
404+
.beforeModelCallback(
405+
(context, requestBuilder) -> {
406+
requestBuilder.contents(
407+
ImmutableList.of(Content.fromParts(Part.fromText("Modified request"))));
408+
return Maybe.empty();
409+
})
410+
.build();
411+
InvocationContext invocationContext = createInvocationContext(agent);
412+
413+
List<Event> unused = agent.runAsync(invocationContext).toList().blockingGet();
414+
415+
assertThat(testLlm.getRequests()).hasSize(1);
416+
assertThat(testLlm.getRequests().get(0).contents())
417+
.containsExactly(Content.fromParts(Part.fromText("Modified request")));
418+
}
419+
399420
@Test
400421
public void testRun_withAfterModelCallback_returnsResponseFromCallback() {
401422
Part textPartFromModel = Part.fromText("Real LLM response");

0 commit comments

Comments
 (0)