Skip to content

Commit 4ba5f76

Browse files
mariofuscogeoand
andcommitted
Make MCP to work with langchain4j agents
Co-authored-by: Georgios Andrianakis <[email protected]>
1 parent 14a428f commit 4ba5f76

File tree

15 files changed

+356
-26
lines changed

15 files changed

+356
-26
lines changed

agentic/deployment/src/main/java/io/quarkiverse/langchain4j/agentic/deployment/AgenticProcessor.java

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,19 @@
2121
import org.jboss.jandex.DotName;
2222
import org.jboss.jandex.IndexView;
2323
import org.jboss.jandex.MethodInfo;
24+
import org.jboss.jandex.ParameterizedType;
25+
import org.jboss.jandex.Type;
2426

27+
import dev.langchain4j.service.IllegalConfigurationException;
2528
import io.quarkiverse.langchain4j.ModelName;
2629
import io.quarkiverse.langchain4j.agentic.runtime.AgenticRecorder;
2730
import io.quarkiverse.langchain4j.agentic.runtime.AiAgentCreateInfo;
31+
import io.quarkiverse.langchain4j.deployment.AnnotationsImpliesAiServiceBuildItem;
32+
import io.quarkiverse.langchain4j.deployment.DotNames;
33+
import io.quarkiverse.langchain4j.deployment.FallbackToDummyUserMessageBuildItem;
2834
import io.quarkiverse.langchain4j.deployment.PreventToolValidationErrorBuildItem;
2935
import io.quarkiverse.langchain4j.deployment.RequestChatModelBeanBuildItem;
36+
import io.quarkiverse.langchain4j.deployment.SkipOutputFormatInstructionsBuildItem;
3037
import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
3138
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
3239
import io.quarkus.deployment.annotations.BuildProducer;
@@ -54,7 +61,10 @@ void detectAgents(CombinedIndexBuildItem indexBuildItem, BuildProducer<DetectedA
5461
.filter(m -> Modifier.isStatic(m.flags()) && m.hasAnnotation(
5562
AgenticLangChain4jDotNames.CHAT_MODEL_SUPPLIER))
5663
.findFirst();
57-
producer.produce(new DetectedAiAgentBuildItem(classInfo, methods, chatModelSupplier.orElse(null)));
64+
65+
List<MethodInfo> mcpToolBoxMethods = methods.stream().filter(mi -> mi.hasAnnotation(DotNames.MCP_TOOLBOX)).toList();
66+
producer.produce(
67+
new DetectedAiAgentBuildItem(classInfo, methods, chatModelSupplier.orElse(null), mcpToolBoxMethods));
5868
});
5969
}
6070

@@ -73,6 +83,58 @@ public boolean test(ClassInfo classInfo) {
7383
});
7484
}
7585

86+
@BuildStep
87+
AnnotationsImpliesAiServiceBuildItem implyAiService() {
88+
return new AnnotationsImpliesAiServiceBuildItem(AgenticLangChain4jDotNames.ALL_AGENT_ANNOTATIONS);
89+
}
90+
91+
@BuildStep
92+
SkipOutputFormatInstructionsBuildItem skipOutputInstructions() {
93+
return new SkipOutputFormatInstructionsBuildItem(new Predicate<>() {
94+
@Override
95+
public boolean test(MethodInfo methodInfo) {
96+
for (DotName dotName : AgenticLangChain4jDotNames.ALL_AGENT_ANNOTATIONS) {
97+
if (methodInfo.hasAnnotation(dotName)) {
98+
return true;
99+
}
100+
}
101+
return false;
102+
}
103+
});
104+
}
105+
106+
@BuildStep
107+
FallbackToDummyUserMessageBuildItem fallbackToDummyUserMessage() {
108+
return new FallbackToDummyUserMessageBuildItem(new Predicate<>() {
109+
@Override
110+
public boolean test(MethodInfo methodInfo) {
111+
for (DotName dotName : AgenticLangChain4jDotNames.ALL_AGENT_ANNOTATIONS) {
112+
if (methodInfo.hasAnnotation(dotName)) {
113+
return true;
114+
}
115+
}
116+
return false;
117+
}
118+
});
119+
}
120+
121+
@BuildStep
122+
@Record(ExecutionTime.STATIC_INIT)
123+
void mcpToolBoxSupport(List<DetectedAiAgentBuildItem> detectedAgentBuildItems, AgenticRecorder recorder) {
124+
Set<String> agentsWithMcpToolBox = new HashSet<>();
125+
for (DetectedAiAgentBuildItem bi : detectedAgentBuildItems) {
126+
if (!bi.getMcpToolBoxMethods().isEmpty()) {
127+
if ((bi.getMcpToolBoxMethods().size() != 1) && (bi.getAgenticMethods().size() > 1)) {
128+
throw new IllegalConfigurationException(
129+
"Currently, @McpToolBox can only be used on an Agent if the agent has a single method. This restriction will be lifted in the future. Offending class is '"
130+
+ bi.getIface().name() + "'");
131+
}
132+
agentsWithMcpToolBox.add(bi.getIface().name().toString());
133+
}
134+
}
135+
recorder.setAgentsWithMcpToolBox(agentsWithMcpToolBox);
136+
}
137+
76138
@BuildStep
77139
@Record(ExecutionTime.RUNTIME_INIT)
78140
void cdiSupport(List<DetectedAiAgentBuildItem> detectedAiAgentBuildItems, AgenticRecorder recorder,
@@ -87,6 +149,7 @@ void cdiSupport(List<DetectedAiAgentBuildItem> detectedAiAgentBuildItems, Agenti
87149
AiAgentCreateInfo.ChatModelInfo chatModelInfo = detectedAiAgentBuildItem.getChatModelSupplier() != null
88150
? new AiAgentCreateInfo.ChatModelInfo.FromAnnotation()
89151
: new AiAgentCreateInfo.ChatModelInfo.FromBeanWithName(chatModelName);
152+
90153
SyntheticBeanBuildItem.ExtendedBeanConfigurator beanConfigurator = SyntheticBeanBuildItem
91154
.configure(detectedAiAgentBuildItem.getIface().name())
92155
.forceApplicationClass()
@@ -105,6 +168,8 @@ void cdiSupport(List<DetectedAiAgentBuildItem> detectedAiAgentBuildItems, Agenti
105168
beanConfigurator.addInjectionPoint(
106169
ClassType.create(DotName.createSimple(dev.langchain4j.model.chat.ChatModel.class)), qualifier);
107170
}
171+
beanConfigurator.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
172+
new Type[] { ClassType.create(DotNames.TOOL_PROVIDER) }, null));
108173
syntheticBeanProducer.produce(beanConfigurator.done());
109174
}
110175
requestedChatModelNames.forEach(name -> requestChatModelBeanProducer.produce(new RequestChatModelBeanBuildItem(name)));

agentic/deployment/src/main/java/io/quarkiverse/langchain4j/agentic/deployment/DetectedAiAgentBuildItem.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@ public final class DetectedAiAgentBuildItem extends MultiBuildItem {
1515
private final ClassInfo iface;
1616
private final List<MethodInfo> agenticMethods;
1717
private final MethodInfo chatModelSupplier;
18+
private final List<MethodInfo> mcpToolBoxMethods;
1819

1920
public DetectedAiAgentBuildItem(ClassInfo iface, List<MethodInfo> agenticMethods,
20-
MethodInfo chatModelSupplier) {
21+
MethodInfo chatModelSupplier, List<MethodInfo> mcpToolBoxMethods) {
2122
this.iface = iface;
2223
this.agenticMethods = agenticMethods;
2324
this.chatModelSupplier = chatModelSupplier;
25+
this.mcpToolBoxMethods = mcpToolBoxMethods;
2426
}
2527

2628
public ClassInfo getIface() {
@@ -34,4 +36,8 @@ public List<MethodInfo> getAgenticMethods() {
3436
public MethodInfo getChatModelSupplier() {
3537
return chatModelSupplier;
3638
}
39+
40+
public List<MethodInfo> getMcpToolBoxMethods() {
41+
return mcpToolBoxMethods;
42+
}
3743
}

agentic/runtime/pom.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
<dependency>
1818
<groupId>dev.langchain4j</groupId>
1919
<artifactId>langchain4j-agentic</artifactId>
20+
<version>${langchain4j-agentic.version}</version>
2021
</dependency>
2122
</dependencies>
2223

agentic/runtime/src/main/java/io/quarkiverse/langchain4j/agentic/runtime/AgenticRecorder.java

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,56 @@
11
package io.quarkiverse.langchain4j.agentic.runtime;
22

3+
import java.util.Collections;
4+
import java.util.Set;
5+
import java.util.function.Consumer;
36
import java.util.function.Function;
47

8+
import jakarta.enterprise.inject.Instance;
9+
import jakarta.enterprise.util.TypeLiteral;
10+
511
import org.jboss.logging.Logger;
612

713
import dev.langchain4j.agentic.AgenticServices;
814
import dev.langchain4j.model.chat.ChatModel;
15+
import dev.langchain4j.service.tool.ToolProvider;
916
import io.quarkiverse.langchain4j.ModelName;
1017
import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
1118
import io.quarkus.arc.SyntheticCreationalContext;
1219
import io.quarkus.runtime.annotations.Recorder;
20+
import io.quarkus.runtime.annotations.RuntimeInit;
21+
import io.quarkus.runtime.annotations.StaticInit;
1322

1423
@Recorder
1524
public class AgenticRecorder {
1625

1726
private static final Logger log = Logger.getLogger(AgenticRecorder.class);
27+
private static Set<String> agentsWithMcpToolBox = Collections.emptySet();
28+
29+
@StaticInit
30+
public void setAgentsWithMcpToolBox(Set<String> agentsWithMcpToolBox) {
31+
AgenticRecorder.agentsWithMcpToolBox = Collections.unmodifiableSet(agentsWithMcpToolBox);
32+
}
1833

34+
@RuntimeInit
1935
public Function<SyntheticCreationalContext<Object>, Object> createAiAgent(AiAgentCreateInfo info) {
2036
return new Function<>() {
2137
@Override
22-
public Object apply(SyntheticCreationalContext<Object> context) {
38+
public Object apply(SyntheticCreationalContext<Object> cdiContext) {
2339
ChatModel chatModel;
2440
if (info.chatModelInfo() instanceof AiAgentCreateInfo.ChatModelInfo.FromAnnotation) {
2541
chatModel = null;
2642
} else if (info.chatModelInfo() instanceof AiAgentCreateInfo.ChatModelInfo.FromBeanWithName b) {
2743
if (NamedConfigUtil.isDefault(b.name())) {
28-
chatModel = context.getInjectedReference(ChatModel.class);
44+
chatModel = cdiContext.getInjectedReference(ChatModel.class);
2945
} else {
30-
chatModel = context.getInjectedReference(ChatModel.class, ModelName.Literal.of(b.name()));
46+
chatModel = cdiContext.getInjectedReference(ChatModel.class, ModelName.Literal.of(b.name()));
3147
}
3248
} else {
3349
throw new IllegalStateException("Unknown type: " + info.chatModelInfo().getClass());
3450
}
3551

36-
return AgenticServices.createAgenticSystem(loadClassSafe(info), chatModel);
52+
return AgenticServices.createAgenticSystem(loadClassSafe(info), chatModel,
53+
new QuarkusAgenticContextConsumer(cdiContext, info));
3754
}
3855
};
3956
}
@@ -46,4 +63,33 @@ private static Class<?> loadClassSafe(AiAgentCreateInfo info) {
4663
throw new RuntimeException(e);
4764
}
4865
}
66+
67+
private record QuarkusAgenticContextConsumer(SyntheticCreationalContext<Object> cdiContext,
68+
AiAgentCreateInfo aiAgentCreateInfo)
69+
implements
70+
Consumer<AgenticServices.DeclarativeAgentCreationContext> {
71+
72+
private static final TypeLiteral<Instance<ToolProvider>> TOOL_PROVIDER_TYPE_LITERAL = new TypeLiteral<>() {
73+
};
74+
75+
@Override
76+
public void accept(AgenticServices.DeclarativeAgentCreationContext agenticContext) {
77+
if (AgenticRecorder.agentsWithMcpToolBox.contains(agenticContext.agentServiceClass().getName())) {
78+
Instance<ToolProvider> injectedReference = cdiContext.getInjectedReference(TOOL_PROVIDER_TYPE_LITERAL);
79+
if (injectedReference.isResolvable()) {
80+
agenticContext.agentBuilder().toolProvider(injectedReference.get());
81+
}
82+
}
83+
}
84+
}
85+
86+
private static final class NoOpConsumer implements Consumer {
87+
88+
private static final NoOpConsumer INSTANCE = new NoOpConsumer();
89+
90+
@Override
91+
public void accept(Object t) {
92+
93+
}
94+
}
4995
}

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,6 +1147,13 @@ public void markIgnoredAnnotations(BuildProducer<MethodParameterIgnoredAnnotatio
11471147
}));
11481148
}
11491149

1150+
@BuildStep
1151+
AnnotationsImpliesAiServiceBuildItem implyAiService() {
1152+
return new AnnotationsImpliesAiServiceBuildItem(
1153+
List.of(LangChain4jDotNames.SYSTEM_MESSAGE, LangChain4jDotNames.USER_MESSAGE,
1154+
LangChain4jDotNames.MODERATE));
1155+
}
1156+
11501157
@BuildStep
11511158
@Record(ExecutionTime.STATIC_INIT)
11521159
public void handleAiServices(
@@ -1167,7 +1174,10 @@ public void handleAiServices(
11671174
Optional<MetricsCapabilityBuildItem> metricsCapability,
11681175
Capabilities capabilities,
11691176
List<ToolMethodBuildItem> tools,
1170-
List<ToolQualifierProvider.BuildItem> toolQualifierProviderItems) {
1177+
List<ToolQualifierProvider.BuildItem> toolQualifierProviderItems,
1178+
List<AnnotationsImpliesAiServiceBuildItem> annotationsImpliesAiServiceItems,
1179+
List<SkipOutputFormatInstructionsBuildItem> skipOutputFormatInstructionsItems,
1180+
List<FallbackToDummyUserMessageBuildItem> fallbackToDummyUserMessageItems) {
11711181

11721182
IndexView index = indexBuildItem.getIndex();
11731183

@@ -1211,7 +1221,8 @@ public void handleAiServices(
12111221

12121222
Set<String> detectedForCreate = new HashSet<>(nameToUsed.keySet());
12131223
addCreatedAware(index, detectedForCreate);
1214-
addIfacesWithMessageAnns(index, detectedForCreate);
1224+
addIfacesWithMessageAnns(index, annotationsImpliesAiServiceItems.stream()
1225+
.flatMap(bi -> bi.getAnnotationNames().stream()).collect(Collectors.toList()), detectedForCreate);
12151226
Set<String> registeredAiServiceClassNames = declarativeAiServiceItems.stream()
12161227
.map(bi -> bi.getServiceClassInfo().name().toString()).collect(
12171228
Collectors.toUnmodifiableSet());
@@ -1342,7 +1353,13 @@ public void handleAiServices(
13421353
config.responseSchema(),
13431354
allowedPredicates,
13441355
ignoredPredicates,
1345-
tools, toolQualifierProviderItems);
1356+
tools, toolQualifierProviderItems,
1357+
skipOutputFormatInstructionsItems.stream().map(
1358+
SkipOutputFormatInstructionsBuildItem::getPredicate)
1359+
.reduce(mi -> false, Predicate::or),
1360+
fallbackToDummyUserMessageItems.stream().map(
1361+
FallbackToDummyUserMessageBuildItem::getPredicate)
1362+
.reduce(mi -> false, Predicate::or));
13461363
if (!methodCreateInfo.getToolClassInfo().isEmpty()) {
13471364
if ((matchingBI != null)
13481365
&& matchingBI.getChatMemoryProviderSupplierClassDotName() == null) {
@@ -1482,9 +1499,7 @@ private String createMethodId(MethodInfo methodInfo) {
14821499
+ Arrays.toString(methodInfo.parameters().stream().map(mp -> mp.type().name().toString()).toArray()) + ')';
14831500
}
14841501

1485-
private void addIfacesWithMessageAnns(IndexView index, Set<String> detectedForCreate) {
1486-
List<DotName> annotations = List.of(LangChain4jDotNames.SYSTEM_MESSAGE, LangChain4jDotNames.USER_MESSAGE,
1487-
LangChain4jDotNames.MODERATE);
1502+
private void addIfacesWithMessageAnns(IndexView index, List<DotName> annotations, Set<String> detectedForCreate) {
14881503
for (DotName annotation : annotations) {
14891504
Collection<AnnotationInstance> instances = index.getAnnotations(annotation);
14901505
for (AnnotationInstance instance : instances) {
@@ -1522,7 +1537,9 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(
15221537
Collection<Predicate<AnnotationInstance>> allowedPredicates,
15231538
Collection<Predicate<AnnotationInstance>> ignoredPredicates,
15241539
List<ToolMethodBuildItem> tools,
1525-
List<ToolQualifierProvider.BuildItem> toolQualifierProviders) {
1540+
List<ToolQualifierProvider.BuildItem> toolQualifierProviders,
1541+
Predicate<MethodInfo> skipOutputFormatInstructionsPredicate,
1542+
Predicate<MethodInfo> fallbackToDummyUserMessagePredicate) {
15261543
validateReturnType(method);
15271544

15281545
boolean requiresModeration = method.hasAnnotation(LangChain4jDotNames.MODERATE);
@@ -1532,15 +1549,17 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(
15321549

15331550
// TODO give user ability to provide custom OutputParser
15341551
String outputFormatInstructions = "";
1535-
Optional<JsonSchema> structuredOutputSchema = Optional.empty();
1536-
if (!returnType.equals(Multi.class)) {
1537-
outputFormatInstructions = SERVICE_OUTPUT_PARSER.outputFormatInstructions(returnType);
1552+
if (!skipOutputFormatInstructionsPredicate.test(method)) {
1553+
Optional<JsonSchema> structuredOutputSchema = Optional.empty();
1554+
if (!returnType.equals(Multi.class)) {
1555+
outputFormatInstructions = SERVICE_OUTPUT_PARSER.outputFormatInstructions(returnType);
1556+
}
15381557
}
15391558

15401559
List<TemplateParameterInfo> templateParams = gatherTemplateParamInfo(params, allowedPredicates, ignoredPredicates);
15411560
Optional<AiServiceMethodCreateInfo.TemplateInfo> systemMessageInfo = gatherSystemMessageInfo(method, templateParams);
15421561
AiServiceMethodCreateInfo.UserMessageInfo userMessageInfo = gatherUserMessageInfo(method, templateParams,
1543-
systemMessageInfo);
1562+
systemMessageInfo, fallbackToDummyUserMessagePredicate);
15441563

15451564
AiServiceMethodCreateInfo.ResponseSchemaInfo responseSchemaInfo = ResponseSchemaInfo.of(generateResponseSchema,
15461565
systemMessageInfo,
@@ -1797,7 +1816,8 @@ private Optional<Integer> gatherOverrideChatModelParameterPosition(MethodInfo me
17971816

17981817
private AiServiceMethodCreateInfo.UserMessageInfo gatherUserMessageInfo(MethodInfo method,
17991818
List<TemplateParameterInfo> templateParams,
1800-
Optional<AiServiceMethodCreateInfo.TemplateInfo> systemMessageInfo) {
1819+
Optional<AiServiceMethodCreateInfo.TemplateInfo> systemMessageInfo,
1820+
Predicate<MethodInfo> fallbackToDummyUserMesage) {
18011821

18021822
Optional<Integer> userNameParamPosition = method.annotations(LangChain4jDotNames.USER_NAME).stream().filter(
18031823
IS_METHOD_PARAMETER_ANNOTATION).map(METHOD_PARAMETER_POSITION_FUNCTION).findFirst();
@@ -1874,6 +1894,14 @@ private AiServiceMethodCreateInfo.UserMessageInfo gatherUserMessageInfo(MethodIn
18741894
return AiServiceMethodCreateInfo.UserMessageInfo.fromMethodParam(0, userNameParamPosition,
18751895
imageParamPosition, audioParamPosition, pdfParamPosition);
18761896
}
1897+
1898+
if (fallbackToDummyUserMesage.test(method)) {
1899+
return AiServiceMethodCreateInfo.UserMessageInfo.fromTemplate(
1900+
AiServiceMethodCreateInfo.TemplateInfo.fromText("", Map.of()), Optional.empty(),
1901+
Optional.empty(),
1902+
Optional.empty(), Optional.empty());
1903+
}
1904+
18771905
throw illegalConfigurationForMethod(
18781906
"For methods with multiple parameters, each parameter must be annotated with @V (or match an template parameter by name), @UserMessage, @UserName or @MemoryId",
18791907
method);
@@ -2128,8 +2156,7 @@ private List<String> gatherMethodToolClassNames(MethodInfo method) {
21282156
}
21292157

21302158
private List<String> gatherMethodMcpClientNames(MethodInfo method) {
2131-
// Using the class name to keep the McpToolBox annotation in the mcp module
2132-
AnnotationInstance mcpToolBoxInstance = method.declaredAnnotation("io.quarkiverse.langchain4j.mcp.runtime.McpToolBox");
2159+
AnnotationInstance mcpToolBoxInstance = method.declaredAnnotation(DotNames.MCP_TOOLBOX);
21332160
if (mcpToolBoxInstance == null) {
21342161
return null;
21352162
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package io.quarkiverse.langchain4j.deployment;
2+
3+
import java.util.List;
4+
5+
import org.jboss.jandex.DotName;
6+
7+
import io.quarkus.builder.item.MultiBuildItem;
8+
9+
/**
10+
* A build item that can be used in order to make an interface an AiService automatically
11+
*/
12+
public final class AnnotationsImpliesAiServiceBuildItem extends MultiBuildItem {
13+
14+
private final List<DotName> annotationNames;
15+
16+
public AnnotationsImpliesAiServiceBuildItem(List<DotName> annotationNames) {
17+
this.annotationNames = annotationNames;
18+
}
19+
20+
public List<DotName> getAnnotationNames() {
21+
return annotationNames;
22+
}
23+
}

0 commit comments

Comments
 (0)