Skip to content

Commit 2162f89

Browse files
Doris26copybara-github
authored andcommitted
feat: Introduce ExampleTool for few-shot examples in LlmAgent
This change adds a new ExampleTool to the ADK, allowing users to inject few-shot examples into the LLM request's system instructions. Examples can be provided inline within the agent's YAML configuration or by referencing a BaseExampleProvider instance. This tool enhances the ability to guide the LLM's behavior by providing contextual examples. The CL includes the tool implementation, updates to the configuration handling, new example agent configurations, and comprehensive unit tests for the new functionality. PiperOrigin-RevId: 805017868
1 parent 8e10df2 commit 2162f89

File tree

4 files changed

+625
-0
lines changed

4 files changed

+625
-0
lines changed

core/src/main/java/com/google/adk/agents/LlmAgent.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ public enum IncludeContents {
100100
private final List<Object> toolsUnion;
101101
private final ImmutableList<BaseToolset> toolsets;
102102
private final Optional<GenerateContentConfig> generateContentConfig;
103+
// TODO: Remove exampleProvider field - examples should only be provided via ExampleTool
103104
private final Optional<BaseExampleProvider> exampleProvider;
104105
private final IncludeContents includeContents;
105106

@@ -280,6 +281,8 @@ public Builder generateContentConfig(GenerateContentConfig generateContentConfig
280281
return this;
281282
}
282283

284+
// TODO: Remove these example provider methods and only use ExampleTool for providing examples.
285+
// Direct example methods should be deprecated in favor of using ExampleTool consistently.
283286
@CanIgnoreReturnValue
284287
public Builder exampleProvider(BaseExampleProvider exampleProvider) {
285288
this.exampleProvider = exampleProvider;
@@ -789,6 +792,7 @@ public Optional<GenerateContentConfig> generateContentConfig() {
789792
return generateContentConfig;
790793
}
791794

795+
// TODO: Remove this getter - examples should only be provided via ExampleTool
792796
public Optional<BaseExampleProvider> exampleProvider() {
793797
return exampleProvider;
794798
}
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.adk.tools;
18+
19+
import static com.google.common.base.Strings.isNullOrEmpty;
20+
import static com.google.common.collect.ImmutableList.toImmutableList;
21+
22+
import com.fasterxml.jackson.databind.ObjectMapper;
23+
import com.google.adk.JsonBaseModel;
24+
import com.google.adk.agents.ConfigAgentUtils.ConfigurationException;
25+
import com.google.adk.examples.BaseExampleProvider;
26+
import com.google.adk.examples.Example;
27+
import com.google.adk.examples.ExampleUtils;
28+
import com.google.adk.models.LlmRequest;
29+
import com.google.common.collect.ImmutableList;
30+
import com.google.errorprone.annotations.CanIgnoreReturnValue;
31+
import com.google.genai.types.Content;
32+
import io.reactivex.rxjava3.core.Completable;
33+
import java.lang.reflect.Field;
34+
import java.lang.reflect.Modifier;
35+
import java.util.ArrayList;
36+
import java.util.List;
37+
import java.util.Map;
38+
import java.util.Optional;
39+
40+
/**
41+
* A tool that injects (few-shot) examples into the outgoing LLM request as system instructions.
42+
*
43+
* <p>Configuration (args) options for YAML:
44+
*
45+
* <ul>
46+
* <li><b>examples</b>: Either a fully-qualified reference to a {@link BaseExampleProvider}
47+
* instance (e.g., <code>com.example.MyExamples.INSTANCE</code>) or a list of examples with
48+
* fields <code>input</code> and <code>output</code> (array of messages).
49+
* </ul>
50+
*/
51+
public final class ExampleTool extends BaseTool {
52+
53+
private static final ObjectMapper MAPPER = JsonBaseModel.getMapper();
54+
55+
private final Optional<BaseExampleProvider> exampleProvider;
56+
private final Optional<List<Example>> examples;
57+
58+
/** Single private constructor; create via builder or fromConfig. */
59+
private ExampleTool(Builder builder) {
60+
super(
61+
isNullOrEmpty(builder.name) ? "example_tool" : builder.name,
62+
isNullOrEmpty(builder.description)
63+
? "Adds few-shot examples to the request"
64+
: builder.description);
65+
this.exampleProvider = builder.provider;
66+
this.examples = builder.examples.isEmpty() ? Optional.empty() : Optional.of(builder.examples);
67+
}
68+
69+
@Override
70+
public Completable processLlmRequest(
71+
LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) {
72+
// Do not add anything if no user text
73+
String query =
74+
toolContext
75+
.userContent()
76+
.flatMap(content -> content.parts().flatMap(parts -> parts.stream().findFirst()))
77+
.flatMap(part -> part.text())
78+
.orElse("");
79+
if (query.isEmpty()) {
80+
return Completable.complete();
81+
}
82+
83+
final String examplesBlock;
84+
if (exampleProvider.isPresent()) {
85+
examplesBlock = ExampleUtils.buildExampleSi(exampleProvider.get(), query);
86+
} else if (examples.isPresent()) {
87+
// Adapter provider that returns a fixed list irrespective of query
88+
BaseExampleProvider provider = q -> examples.get();
89+
examplesBlock = ExampleUtils.buildExampleSi(provider, query);
90+
} else {
91+
return Completable.complete();
92+
}
93+
94+
llmRequestBuilder.appendInstructions(ImmutableList.of(examplesBlock));
95+
// Delegate to BaseTool to keep any declaration bookkeeping (none for this tool)
96+
return super.processLlmRequest(llmRequestBuilder, toolContext);
97+
}
98+
99+
/** Factory from YAML tool args. */
100+
public static ExampleTool fromConfig(ToolArgsConfig args, String configAbsPath)
101+
throws ConfigurationException {
102+
if (args == null || args.isEmpty()) {
103+
throw new ConfigurationException("ExampleTool requires 'examples' argument");
104+
}
105+
Object examplesArg = args.get("examples");
106+
if (examplesArg == null) {
107+
throw new ConfigurationException("ExampleTool missing 'examples' argument");
108+
}
109+
110+
try {
111+
if (examplesArg instanceof String string) {
112+
BaseExampleProvider provider = resolveExampleProvider(string);
113+
return ExampleTool.builder().setExampleProvider(provider).build();
114+
}
115+
if (examplesArg instanceof List) {
116+
@SuppressWarnings("unchecked")
117+
List<Object> rawList = (List<Object>) examplesArg;
118+
List<Example> examples = new ArrayList<>();
119+
for (Object o : rawList) {
120+
if (!(o instanceof Map)) {
121+
throw new ConfigurationException(
122+
"Invalid example entry. Expected a map with 'input' and 'output'.");
123+
}
124+
@SuppressWarnings("unchecked")
125+
Map<String, Object> m = (Map<String, Object>) o;
126+
Object inputObj = m.get("input");
127+
Object outputObj = m.get("output");
128+
if (inputObj == null || outputObj == null) {
129+
throw new ConfigurationException("Each example must include 'input' and 'output'.");
130+
}
131+
Content input = MAPPER.convertValue(inputObj, Content.class);
132+
@SuppressWarnings("unchecked")
133+
List<Object> outList = (List<Object>) outputObj;
134+
ImmutableList<Content> outputs =
135+
outList.stream()
136+
.map(e -> MAPPER.convertValue(e, Content.class))
137+
.collect(toImmutableList());
138+
examples.add(Example.builder().input(input).output(outputs).build());
139+
}
140+
Builder b = ExampleTool.builder();
141+
for (Example ex : examples) {
142+
b.addExample(ex);
143+
}
144+
return b.build();
145+
}
146+
} catch (RuntimeException e) {
147+
throw new ConfigurationException("Failed to parse ExampleTool examples", e);
148+
}
149+
throw new ConfigurationException(
150+
"Unsupported 'examples' type. Provide a string provider ref or list of examples.");
151+
}
152+
153+
/** Overload to match resolver which passes only ToolArgsConfig. */
154+
public static ExampleTool fromConfig(ToolArgsConfig args) throws ConfigurationException {
155+
return fromConfig(args, /* configAbsPath= */ "");
156+
}
157+
158+
private static BaseExampleProvider resolveExampleProvider(String ref)
159+
throws ConfigurationException {
160+
int lastDot = ref.lastIndexOf('.');
161+
if (lastDot <= 0) {
162+
throw new ConfigurationException(
163+
"Invalid example provider reference: " + ref + ". Expected ClassName.FIELD");
164+
}
165+
String className = ref.substring(0, lastDot);
166+
String fieldName = ref.substring(lastDot + 1);
167+
try {
168+
Class<?> clazz = Thread.currentThread().getContextClassLoader().loadClass(className);
169+
Field field = clazz.getField(fieldName);
170+
if (!Modifier.isStatic(field.getModifiers())) {
171+
throw new ConfigurationException(
172+
"Field '" + fieldName + "' in class '" + className + "' is not static");
173+
}
174+
Object instance = field.get(null);
175+
if (instance instanceof BaseExampleProvider provider) {
176+
return provider;
177+
}
178+
throw new ConfigurationException(
179+
"Field '" + fieldName + "' in class '" + className + "' is not a BaseExampleProvider");
180+
} catch (NoSuchFieldException e) {
181+
throw new ConfigurationException(
182+
"Field '" + fieldName + "' not found in class '" + className + "'", e);
183+
} catch (ClassNotFoundException e) {
184+
throw new ConfigurationException("Example provider class not found: " + className, e);
185+
} catch (IllegalAccessException e) {
186+
throw new ConfigurationException("Cannot access example provider field: " + ref, e);
187+
}
188+
}
189+
190+
public static Builder builder() {
191+
return new Builder();
192+
}
193+
194+
public static final class Builder {
195+
private final List<Example> examples = new ArrayList<>();
196+
private String name = "example_tool";
197+
private String description = "Adds few-shot examples to the request";
198+
private Optional<BaseExampleProvider> provider = Optional.empty();
199+
200+
@CanIgnoreReturnValue
201+
public Builder setName(String name) {
202+
this.name = name;
203+
return this;
204+
}
205+
206+
@CanIgnoreReturnValue
207+
public Builder setDescription(String description) {
208+
this.description = description;
209+
return this;
210+
}
211+
212+
@CanIgnoreReturnValue
213+
public Builder addExample(Example ex) {
214+
this.examples.add(ex);
215+
return this;
216+
}
217+
218+
@CanIgnoreReturnValue
219+
public Builder setExampleProvider(BaseExampleProvider provider) {
220+
this.provider = Optional.ofNullable(provider);
221+
return this;
222+
}
223+
224+
public ExampleTool build() {
225+
return new ExampleTool(this);
226+
}
227+
}
228+
}

core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,17 @@
2020
import static org.junit.Assert.assertThrows;
2121

2222
import com.google.adk.agents.ConfigAgentUtils.ConfigurationException;
23+
import com.google.adk.examples.Example;
24+
import com.google.adk.models.LlmRequest;
25+
import com.google.adk.testing.TestUtils;
26+
import com.google.adk.tools.ExampleTool;
27+
import com.google.adk.tools.ToolContext;
2328
import com.google.adk.tools.mcp.McpToolset;
29+
import com.google.adk.utils.ComponentRegistry;
30+
import com.google.common.collect.ImmutableList;
31+
import com.google.genai.types.Content;
2432
import com.google.genai.types.GenerateContentConfig;
33+
import com.google.genai.types.Part;
2534
import java.io.File;
2635
import java.io.IOException;
2736
import java.nio.file.Files;
@@ -814,4 +823,102 @@ public void fromConfig_withOutputKeyAndOtherFields_parsesAllFields()
814823
assertThat(llmAgent.disallowTransferToPeers()).isFalse();
815824
assertThat(llmAgent.model()).isPresent();
816825
}
826+
827+
@Test
828+
public void fromConfig_withGenerateContentConfigSafetySettings()
829+
throws IOException, ConfigurationException {
830+
File configFile = tempFolder.newFile("generate_content_config_safety.yaml");
831+
Files.writeString(
832+
configFile.toPath(),
833+
"""
834+
agent_class: LlmAgent
835+
model: gemini-2.5-flash
836+
name: root_agent
837+
description: dice agent
838+
instruction: You are a helpful assistant
839+
generate_content_config:
840+
safety_settings:
841+
- category: HARM_CATEGORY_DANGEROUS_CONTENT
842+
threshold: 'OFF'
843+
""");
844+
String configPath = configFile.getAbsolutePath();
845+
846+
BaseAgent agent = ConfigAgentUtils.fromConfig(configPath);
847+
848+
assertThat(agent).isInstanceOf(LlmAgent.class);
849+
LlmAgent llmAgent = (LlmAgent) agent;
850+
assertThat(llmAgent.name()).isEqualTo("root_agent");
851+
assertThat(llmAgent.description()).isEqualTo("dice agent");
852+
assertThat(llmAgent.model()).isPresent();
853+
assertThat(llmAgent.model().get().modelName()).hasValue("gemini-2.5-flash");
854+
855+
assertThat(llmAgent.generateContentConfig()).isPresent();
856+
GenerateContentConfig config = llmAgent.generateContentConfig().get();
857+
assertThat(config).isNotNull();
858+
assertThat(config.safetySettings()).isPresent();
859+
assertThat(config.safetySettings().get()).hasSize(1);
860+
861+
// Verify the safety settings are parsed correctly
862+
assertThat(config.safetySettings().get().get(0).category()).isPresent();
863+
assertThat(config.safetySettings().get().get(0).category().get().toString())
864+
.isEqualTo("HARM_CATEGORY_DANGEROUS_CONTENT");
865+
assertThat(config.safetySettings().get().get(0).threshold()).isPresent();
866+
assertThat(config.safetySettings().get().get(0).threshold().get().toString()).isEqualTo("OFF");
867+
}
868+
869+
@Test
870+
public void fromConfig_withExamplesList_appendsExamplesInFlow()
871+
throws IOException, ConfigurationException {
872+
// Register an ExampleTool instance under short name used by YAML
873+
ComponentRegistry originalRegistry = ComponentRegistry.getInstance();
874+
class TestRegistry extends ComponentRegistry {
875+
TestRegistry() {
876+
super();
877+
}
878+
}
879+
ComponentRegistry testRegistry = new TestRegistry();
880+
Example example =
881+
Example.builder()
882+
.input(Content.fromParts(Part.fromText("qin")))
883+
.output(ImmutableList.of(Content.fromParts(Part.fromText("qout"))))
884+
.build();
885+
testRegistry.register(
886+
"multi_agent_llm_config.example_tool", ExampleTool.builder().addExample(example).build());
887+
ComponentRegistry.setInstance(testRegistry);
888+
File configFile = tempFolder.newFile("with_examples.yaml");
889+
Files.writeString(
890+
configFile.toPath(),
891+
"""
892+
name: examples_agent
893+
description: Agent with examples configured via tool
894+
instruction: You are a test agent
895+
agent_class: LlmAgent
896+
model: gemini-2.0-flash
897+
tools:
898+
- name: multi_agent_llm_config.example_tool
899+
""");
900+
String configPath = configFile.getAbsolutePath();
901+
902+
BaseAgent agent;
903+
try {
904+
agent = ConfigAgentUtils.fromConfig(configPath);
905+
} finally {
906+
ComponentRegistry.setInstance(originalRegistry);
907+
}
908+
909+
assertThat(agent).isInstanceOf(LlmAgent.class);
910+
LlmAgent llmAgent = (LlmAgent) agent;
911+
912+
// Process tools to verify ExampleTool appends the examples to the request
913+
LlmRequest.Builder requestBuilder = LlmRequest.builder().model("gemini-2.0-flash");
914+
InvocationContext context = TestUtils.createInvocationContext(agent);
915+
llmAgent
916+
.canonicalTools(new ReadonlyContext(context))
917+
.concatMapCompletable(
918+
tool -> tool.processLlmRequest(requestBuilder, ToolContext.builder(context).build()))
919+
.blockingAwait();
920+
LlmRequest updated = requestBuilder.build();
921+
// Verify ExampleTool appended a system instruction with examples
922+
assertThat(updated.getSystemInstructions()).isNotEmpty();
923+
}
817924
}

0 commit comments

Comments
 (0)