From b1c680b38833b5d1febf945543c89fad8c6a8e59 Mon Sep 17 00:00:00 2001 From: mariofusco Date: Wed, 10 Sep 2025 14:39:55 +0200 Subject: [PATCH] Make MCP to work with langchain4j agents Co-authored-by: Georgios Andrianakis --- .../agentic/deployment/AgenticProcessor.java | 67 +++++++++- .../deployment/DetectedAiAgentBuildItem.java | 8 +- agentic/runtime/pom.xml | 1 + .../agentic/runtime/AgenticRecorder.java | 54 +++++++- .../deployment/AiServicesProcessor.java | 55 ++++++-- .../AnnotationsImpliesAiServiceBuildItem.java | 23 ++++ .../langchain4j/deployment/DotNames.java | 4 + .../FallbackToDummyUserMessageBuildItem.java | 23 ++++ .../langchain4j/deployment/JandexUtil.java | 3 +- .../PreventToolValidationErrorBuildItem.java | 2 +- ...SkipOutputFormatInstructionsBuildItem.java | 23 ++++ mcp/deployment/pom.xml | 24 ++++ .../mcp/deployment/McpProcessor.java | 5 +- .../mcp/test/AgentMcpClientTest.java | 121 ++++++++++++++++++ pom.xml | 6 + 15 files changed, 393 insertions(+), 26 deletions(-) create mode 100644 core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AnnotationsImpliesAiServiceBuildItem.java create mode 100644 core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/FallbackToDummyUserMessageBuildItem.java create mode 100644 core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/SkipOutputFormatInstructionsBuildItem.java create mode 100644 mcp/deployment/src/test/java/io/quarkiverse/langchain4j/mcp/test/AgentMcpClientTest.java diff --git a/agentic/deployment/src/main/java/io/quarkiverse/langchain4j/agentic/deployment/AgenticProcessor.java b/agentic/deployment/src/main/java/io/quarkiverse/langchain4j/agentic/deployment/AgenticProcessor.java index 844b62d17..067f0113b 100644 --- a/agentic/deployment/src/main/java/io/quarkiverse/langchain4j/agentic/deployment/AgenticProcessor.java +++ b/agentic/deployment/src/main/java/io/quarkiverse/langchain4j/agentic/deployment/AgenticProcessor.java @@ -21,12 +21,19 @@ import org.jboss.jandex.DotName; import org.jboss.jandex.IndexView; import org.jboss.jandex.MethodInfo; +import org.jboss.jandex.ParameterizedType; +import org.jboss.jandex.Type; +import dev.langchain4j.service.IllegalConfigurationException; import io.quarkiverse.langchain4j.ModelName; import io.quarkiverse.langchain4j.agentic.runtime.AgenticRecorder; import io.quarkiverse.langchain4j.agentic.runtime.AiAgentCreateInfo; +import io.quarkiverse.langchain4j.deployment.AnnotationsImpliesAiServiceBuildItem; +import io.quarkiverse.langchain4j.deployment.DotNames; +import io.quarkiverse.langchain4j.deployment.FallbackToDummyUserMessageBuildItem; import io.quarkiverse.langchain4j.deployment.PreventToolValidationErrorBuildItem; import io.quarkiverse.langchain4j.deployment.RequestChatModelBeanBuildItem; +import io.quarkiverse.langchain4j.deployment.SkipOutputFormatInstructionsBuildItem; import io.quarkiverse.langchain4j.runtime.NamedConfigUtil; import io.quarkus.arc.deployment.SyntheticBeanBuildItem; import io.quarkus.deployment.annotations.BuildProducer; @@ -54,7 +61,10 @@ void detectAgents(CombinedIndexBuildItem indexBuildItem, BuildProducer Modifier.isStatic(m.flags()) && m.hasAnnotation( AgenticLangChain4jDotNames.CHAT_MODEL_SUPPLIER)) .findFirst(); - producer.produce(new DetectedAiAgentBuildItem(classInfo, methods, chatModelSupplier.orElse(null))); + + List mcpToolBoxMethods = methods.stream().filter(mi -> mi.hasAnnotation(DotNames.MCP_TOOLBOX)).toList(); + producer.produce( + new DetectedAiAgentBuildItem(classInfo, methods, chatModelSupplier.orElse(null), mcpToolBoxMethods)); }); } @@ -73,6 +83,58 @@ public boolean test(ClassInfo classInfo) { }); } + @BuildStep + AnnotationsImpliesAiServiceBuildItem implyAiService() { + return new AnnotationsImpliesAiServiceBuildItem(AgenticLangChain4jDotNames.ALL_AGENT_ANNOTATIONS); + } + + @BuildStep + SkipOutputFormatInstructionsBuildItem skipOutputInstructions() { + return new SkipOutputFormatInstructionsBuildItem(new Predicate<>() { + @Override + public boolean test(MethodInfo methodInfo) { + for (DotName dotName : AgenticLangChain4jDotNames.ALL_AGENT_ANNOTATIONS) { + if (methodInfo.hasAnnotation(dotName)) { + return true; + } + } + return false; + } + }); + } + + @BuildStep + FallbackToDummyUserMessageBuildItem fallbackToDummyUserMessage() { + return new FallbackToDummyUserMessageBuildItem(new Predicate<>() { + @Override + public boolean test(MethodInfo methodInfo) { + for (DotName dotName : AgenticLangChain4jDotNames.ALL_AGENT_ANNOTATIONS) { + if (methodInfo.hasAnnotation(dotName)) { + return true; + } + } + return false; + } + }); + } + + @BuildStep + @Record(ExecutionTime.STATIC_INIT) + void mcpToolBoxSupport(List detectedAgentBuildItems, AgenticRecorder recorder) { + Set agentsWithMcpToolBox = new HashSet<>(); + for (DetectedAiAgentBuildItem bi : detectedAgentBuildItems) { + if (!bi.getMcpToolBoxMethods().isEmpty()) { + if ((bi.getMcpToolBoxMethods().size() != 1) && (bi.getAgenticMethods().size() > 1)) { + throw new IllegalConfigurationException( + "Currently, @McpToolBox can only be used on an Agent if the agent has a single method. This restriction will be lifted in the future. Offending class is '" + + bi.getIface().name() + "'"); + } + agentsWithMcpToolBox.add(bi.getIface().name().toString()); + } + } + recorder.setAgentsWithMcpToolBox(agentsWithMcpToolBox); + } + @BuildStep @Record(ExecutionTime.RUNTIME_INIT) void cdiSupport(List detectedAiAgentBuildItems, AgenticRecorder recorder, @@ -87,6 +149,7 @@ void cdiSupport(List detectedAiAgentBuildItems, Agenti AiAgentCreateInfo.ChatModelInfo chatModelInfo = detectedAiAgentBuildItem.getChatModelSupplier() != null ? new AiAgentCreateInfo.ChatModelInfo.FromAnnotation() : new AiAgentCreateInfo.ChatModelInfo.FromBeanWithName(chatModelName); + SyntheticBeanBuildItem.ExtendedBeanConfigurator beanConfigurator = SyntheticBeanBuildItem .configure(detectedAiAgentBuildItem.getIface().name()) .forceApplicationClass() @@ -105,6 +168,8 @@ void cdiSupport(List detectedAiAgentBuildItems, Agenti beanConfigurator.addInjectionPoint( ClassType.create(DotName.createSimple(dev.langchain4j.model.chat.ChatModel.class)), qualifier); } + beanConfigurator.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE, + new Type[] { ClassType.create(DotNames.TOOL_PROVIDER) }, null)); syntheticBeanProducer.produce(beanConfigurator.done()); } requestedChatModelNames.forEach(name -> requestChatModelBeanProducer.produce(new RequestChatModelBeanBuildItem(name))); diff --git a/agentic/deployment/src/main/java/io/quarkiverse/langchain4j/agentic/deployment/DetectedAiAgentBuildItem.java b/agentic/deployment/src/main/java/io/quarkiverse/langchain4j/agentic/deployment/DetectedAiAgentBuildItem.java index 01f71810c..f3f0b7ad7 100644 --- a/agentic/deployment/src/main/java/io/quarkiverse/langchain4j/agentic/deployment/DetectedAiAgentBuildItem.java +++ b/agentic/deployment/src/main/java/io/quarkiverse/langchain4j/agentic/deployment/DetectedAiAgentBuildItem.java @@ -15,12 +15,14 @@ public final class DetectedAiAgentBuildItem extends MultiBuildItem { private final ClassInfo iface; private final List agenticMethods; private final MethodInfo chatModelSupplier; + private final List mcpToolBoxMethods; public DetectedAiAgentBuildItem(ClassInfo iface, List agenticMethods, - MethodInfo chatModelSupplier) { + MethodInfo chatModelSupplier, List mcpToolBoxMethods) { this.iface = iface; this.agenticMethods = agenticMethods; this.chatModelSupplier = chatModelSupplier; + this.mcpToolBoxMethods = mcpToolBoxMethods; } public ClassInfo getIface() { @@ -34,4 +36,8 @@ public List getAgenticMethods() { public MethodInfo getChatModelSupplier() { return chatModelSupplier; } + + public List getMcpToolBoxMethods() { + return mcpToolBoxMethods; + } } diff --git a/agentic/runtime/pom.xml b/agentic/runtime/pom.xml index a02557312..e8cba702b 100644 --- a/agentic/runtime/pom.xml +++ b/agentic/runtime/pom.xml @@ -17,6 +17,7 @@ dev.langchain4j langchain4j-agentic + ${langchain4j-agentic.version} diff --git a/agentic/runtime/src/main/java/io/quarkiverse/langchain4j/agentic/runtime/AgenticRecorder.java b/agentic/runtime/src/main/java/io/quarkiverse/langchain4j/agentic/runtime/AgenticRecorder.java index 1fe392937..94c5a067d 100644 --- a/agentic/runtime/src/main/java/io/quarkiverse/langchain4j/agentic/runtime/AgenticRecorder.java +++ b/agentic/runtime/src/main/java/io/quarkiverse/langchain4j/agentic/runtime/AgenticRecorder.java @@ -1,39 +1,56 @@ package io.quarkiverse.langchain4j.agentic.runtime; +import java.util.Collections; +import java.util.Set; +import java.util.function.Consumer; import java.util.function.Function; +import jakarta.enterprise.inject.Instance; +import jakarta.enterprise.util.TypeLiteral; + import org.jboss.logging.Logger; import dev.langchain4j.agentic.AgenticServices; import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.service.tool.ToolProvider; import io.quarkiverse.langchain4j.ModelName; import io.quarkiverse.langchain4j.runtime.NamedConfigUtil; import io.quarkus.arc.SyntheticCreationalContext; import io.quarkus.runtime.annotations.Recorder; +import io.quarkus.runtime.annotations.RuntimeInit; +import io.quarkus.runtime.annotations.StaticInit; @Recorder public class AgenticRecorder { private static final Logger log = Logger.getLogger(AgenticRecorder.class); + private static Set agentsWithMcpToolBox = Collections.emptySet(); + + @StaticInit + public void setAgentsWithMcpToolBox(Set agentsWithMcpToolBox) { + AgenticRecorder.agentsWithMcpToolBox = Collections.unmodifiableSet(agentsWithMcpToolBox); + } + @RuntimeInit public Function, Object> createAiAgent(AiAgentCreateInfo info) { return new Function<>() { @Override - public Object apply(SyntheticCreationalContext context) { + public Object apply(SyntheticCreationalContext cdiContext) { ChatModel chatModel; if (info.chatModelInfo() instanceof AiAgentCreateInfo.ChatModelInfo.FromAnnotation) { chatModel = null; } else if (info.chatModelInfo() instanceof AiAgentCreateInfo.ChatModelInfo.FromBeanWithName b) { if (NamedConfigUtil.isDefault(b.name())) { - chatModel = context.getInjectedReference(ChatModel.class); + chatModel = cdiContext.getInjectedReference(ChatModel.class); } else { - chatModel = context.getInjectedReference(ChatModel.class, ModelName.Literal.of(b.name())); + chatModel = cdiContext.getInjectedReference(ChatModel.class, ModelName.Literal.of(b.name())); } } else { throw new IllegalStateException("Unknown type: " + info.chatModelInfo().getClass()); } - return AgenticServices.createAgenticSystem(loadClassSafe(info), chatModel); + return AgenticServices.createAgenticSystem(loadClassSafe(info), chatModel, + new QuarkusAgenticContextConsumer(cdiContext, info)); } }; } @@ -46,4 +63,33 @@ private static Class loadClassSafe(AiAgentCreateInfo info) { throw new RuntimeException(e); } } + + private record QuarkusAgenticContextConsumer(SyntheticCreationalContext cdiContext, + AiAgentCreateInfo aiAgentCreateInfo) + implements + Consumer { + + private static final TypeLiteral> TOOL_PROVIDER_TYPE_LITERAL = new TypeLiteral<>() { + }; + + @Override + public void accept(AgenticServices.DeclarativeAgentCreationContext agenticContext) { + if (AgenticRecorder.agentsWithMcpToolBox.contains(agenticContext.agentServiceClass().getName())) { + Instance injectedReference = cdiContext.getInjectedReference(TOOL_PROVIDER_TYPE_LITERAL); + if (injectedReference.isResolvable()) { + agenticContext.agentBuilder().toolProvider(injectedReference.get()); + } + } + } + } + + private static final class NoOpConsumer implements Consumer { + + private static final NoOpConsumer INSTANCE = new NoOpConsumer(); + + @Override + public void accept(Object t) { + + } + } } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java index 7ae59a008..8ed3b6b06 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java @@ -1147,6 +1147,13 @@ public void markIgnoredAnnotations(BuildProducer metricsCapability, Capabilities capabilities, List tools, - List toolQualifierProviderItems) { + List toolQualifierProviderItems, + List annotationsImpliesAiServiceItems, + List skipOutputFormatInstructionsItems, + List fallbackToDummyUserMessageItems) { IndexView index = indexBuildItem.getIndex(); @@ -1211,7 +1221,8 @@ public void handleAiServices( Set detectedForCreate = new HashSet<>(nameToUsed.keySet()); addCreatedAware(index, detectedForCreate); - addIfacesWithMessageAnns(index, detectedForCreate); + addIfacesWithMessageAnns(index, annotationsImpliesAiServiceItems.stream() + .flatMap(bi -> bi.getAnnotationNames().stream()).collect(Collectors.toList()), detectedForCreate); Set registeredAiServiceClassNames = declarativeAiServiceItems.stream() .map(bi -> bi.getServiceClassInfo().name().toString()).collect( Collectors.toUnmodifiableSet()); @@ -1342,7 +1353,13 @@ public void handleAiServices( config.responseSchema(), allowedPredicates, ignoredPredicates, - tools, toolQualifierProviderItems); + tools, toolQualifierProviderItems, + skipOutputFormatInstructionsItems.stream().map( + SkipOutputFormatInstructionsBuildItem::getPredicate) + .reduce(mi -> false, Predicate::or), + fallbackToDummyUserMessageItems.stream().map( + FallbackToDummyUserMessageBuildItem::getPredicate) + .reduce(mi -> false, Predicate::or)); if (!methodCreateInfo.getToolClassInfo().isEmpty()) { if ((matchingBI != null) && matchingBI.getChatMemoryProviderSupplierClassDotName() == null) { @@ -1482,9 +1499,7 @@ private String createMethodId(MethodInfo methodInfo) { + Arrays.toString(methodInfo.parameters().stream().map(mp -> mp.type().name().toString()).toArray()) + ')'; } - private void addIfacesWithMessageAnns(IndexView index, Set detectedForCreate) { - List annotations = List.of(LangChain4jDotNames.SYSTEM_MESSAGE, LangChain4jDotNames.USER_MESSAGE, - LangChain4jDotNames.MODERATE); + private void addIfacesWithMessageAnns(IndexView index, List annotations, Set detectedForCreate) { for (DotName annotation : annotations) { Collection instances = index.getAnnotations(annotation); for (AnnotationInstance instance : instances) { @@ -1522,7 +1537,9 @@ private AiServiceMethodCreateInfo gatherMethodMetadata( Collection> allowedPredicates, Collection> ignoredPredicates, List tools, - List toolQualifierProviders) { + List toolQualifierProviders, + Predicate skipOutputFormatInstructionsPredicate, + Predicate fallbackToDummyUserMessagePredicate) { validateReturnType(method); boolean requiresModeration = method.hasAnnotation(LangChain4jDotNames.MODERATE); @@ -1532,15 +1549,17 @@ private AiServiceMethodCreateInfo gatherMethodMetadata( // TODO give user ability to provide custom OutputParser String outputFormatInstructions = ""; - Optional structuredOutputSchema = Optional.empty(); - if (!returnType.equals(Multi.class)) { - outputFormatInstructions = SERVICE_OUTPUT_PARSER.outputFormatInstructions(returnType); + if (!skipOutputFormatInstructionsPredicate.test(method)) { + Optional structuredOutputSchema = Optional.empty(); + if (!returnType.equals(Multi.class)) { + outputFormatInstructions = SERVICE_OUTPUT_PARSER.outputFormatInstructions(returnType); + } } List templateParams = gatherTemplateParamInfo(params, allowedPredicates, ignoredPredicates); Optional systemMessageInfo = gatherSystemMessageInfo(method, templateParams); AiServiceMethodCreateInfo.UserMessageInfo userMessageInfo = gatherUserMessageInfo(method, templateParams, - systemMessageInfo); + systemMessageInfo, fallbackToDummyUserMessagePredicate); AiServiceMethodCreateInfo.ResponseSchemaInfo responseSchemaInfo = ResponseSchemaInfo.of(generateResponseSchema, systemMessageInfo, @@ -1797,7 +1816,8 @@ private Optional gatherOverrideChatModelParameterPosition(MethodInfo me private AiServiceMethodCreateInfo.UserMessageInfo gatherUserMessageInfo(MethodInfo method, List templateParams, - Optional systemMessageInfo) { + Optional systemMessageInfo, + Predicate fallbackToDummyUserMesage) { Optional userNameParamPosition = method.annotations(LangChain4jDotNames.USER_NAME).stream().filter( IS_METHOD_PARAMETER_ANNOTATION).map(METHOD_PARAMETER_POSITION_FUNCTION).findFirst(); @@ -1874,6 +1894,14 @@ private AiServiceMethodCreateInfo.UserMessageInfo gatherUserMessageInfo(MethodIn return AiServiceMethodCreateInfo.UserMessageInfo.fromMethodParam(0, userNameParamPosition, imageParamPosition, audioParamPosition, pdfParamPosition); } + + if (fallbackToDummyUserMesage.test(method)) { + return AiServiceMethodCreateInfo.UserMessageInfo.fromTemplate( + AiServiceMethodCreateInfo.TemplateInfo.fromText("", Map.of()), Optional.empty(), + Optional.empty(), + Optional.empty(), Optional.empty()); + } + throw illegalConfigurationForMethod( "For methods with multiple parameters, each parameter must be annotated with @V (or match an template parameter by name), @UserMessage, @UserName or @MemoryId", method); @@ -2128,8 +2156,7 @@ private List gatherMethodToolClassNames(MethodInfo method) { } private List gatherMethodMcpClientNames(MethodInfo method) { - // Using the class name to keep the McpToolBox annotation in the mcp module - AnnotationInstance mcpToolBoxInstance = method.declaredAnnotation("io.quarkiverse.langchain4j.mcp.runtime.McpToolBox"); + AnnotationInstance mcpToolBoxInstance = method.declaredAnnotation(DotNames.MCP_TOOLBOX); if (mcpToolBoxInstance == null) { return null; } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AnnotationsImpliesAiServiceBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AnnotationsImpliesAiServiceBuildItem.java new file mode 100644 index 000000000..c3302fb9d --- /dev/null +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AnnotationsImpliesAiServiceBuildItem.java @@ -0,0 +1,23 @@ +package io.quarkiverse.langchain4j.deployment; + +import java.util.List; + +import org.jboss.jandex.DotName; + +import io.quarkus.builder.item.MultiBuildItem; + +/** + * A build item that can be used in order to make an interface an AiService automatically + */ +public final class AnnotationsImpliesAiServiceBuildItem extends MultiBuildItem { + + private final List annotationNames; + + public AnnotationsImpliesAiServiceBuildItem(List annotationNames) { + this.annotationNames = annotationNames; + } + + public List getAnnotationNames() { + return annotationNames; + } +} diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java index da6f08056..4d9434b7c 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java @@ -21,6 +21,7 @@ import dev.langchain4j.agent.tool.Tool; import dev.langchain4j.model.chat.listener.ChatModelListener; +import dev.langchain4j.service.tool.ToolProvider; import io.quarkiverse.langchain4j.auth.ModelAuthProvider; import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator; import io.quarkiverse.langchain4j.response.AiResponseAugmenter; @@ -80,6 +81,9 @@ public class DotNames { public static final DotName CHAT_MODEL_LISTENER = DotName.createSimple(ChatModelListener.class); public static final DotName MODEL_AUTH_PROVIDER = DotName.createSimple(ModelAuthProvider.class); public static final DotName TOOL = DotName.createSimple(Tool.class); + // Using the class name to keep the McpToolBox annotation in the mcp module + public static final DotName MCP_TOOLBOX = DotName.createSimple("io.quarkiverse.langchain4j.mcp.runtime.McpToolBox"); + public static final DotName TOOL_PROVIDER = DotName.createSimple(ToolProvider.class); public static final DotName CHAT_EVENT = DotName.createSimple(ChatEvent.class); public static final DotName REGISTER_REST_CLIENT = DotName.createSimple(RegisterRestClient.class); diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/FallbackToDummyUserMessageBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/FallbackToDummyUserMessageBuildItem.java new file mode 100644 index 000000000..08be4f8be --- /dev/null +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/FallbackToDummyUserMessageBuildItem.java @@ -0,0 +1,23 @@ +package io.quarkiverse.langchain4j.deployment; + +import java.util.function.Predicate; + +import org.jboss.jandex.MethodInfo; + +import io.quarkus.builder.item.MultiBuildItem; + +/** + * A build item which indicates that a dummy user message should be created if otherwise the processing would fail + */ +public final class FallbackToDummyUserMessageBuildItem extends MultiBuildItem { + + private final Predicate predicate; + + public FallbackToDummyUserMessageBuildItem(Predicate predicate) { + this.predicate = predicate; + } + + public Predicate getPredicate() { + return predicate; + } +} diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/JandexUtil.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/JandexUtil.java index 636d54a11..9211d1fb5 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/JandexUtil.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/JandexUtil.java @@ -46,8 +46,9 @@ static Collection getAllSuperinterfaces(ClassInfo interfaceName, Inde if (classInfo == null) { log.warn("'" + name + "' used for creating an AiService is not an interface. Attempting to create an AiService using this class will fail"); + } else { + directSuperInterfaces.add(classInfo); } - directSuperInterfaces.add(classInfo); } for (ClassInfo directSubInterface : directSuperInterfaces) { result.add(directSubInterface); diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/PreventToolValidationErrorBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/PreventToolValidationErrorBuildItem.java index 4b247501e..e8373f7a1 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/PreventToolValidationErrorBuildItem.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/PreventToolValidationErrorBuildItem.java @@ -7,7 +7,7 @@ import io.quarkus.builder.item.MultiBuildItem; /** - * TODO + * A build item that prevents the default validation exception from being thrown for invalid AiService methods */ public final class PreventToolValidationErrorBuildItem extends MultiBuildItem { diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/SkipOutputFormatInstructionsBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/SkipOutputFormatInstructionsBuildItem.java new file mode 100644 index 000000000..b41c245a4 --- /dev/null +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/SkipOutputFormatInstructionsBuildItem.java @@ -0,0 +1,23 @@ +package io.quarkiverse.langchain4j.deployment; + +import java.util.function.Predicate; + +import org.jboss.jandex.MethodInfo; + +import io.quarkus.builder.item.MultiBuildItem; + +/** + * A build item which indicated when an AI Service method should not have output instructions associated with it + */ +public final class SkipOutputFormatInstructionsBuildItem extends MultiBuildItem { + + private final Predicate predicate; + + public SkipOutputFormatInstructionsBuildItem(Predicate predicate) { + this.predicate = predicate; + } + + public Predicate getPredicate() { + return predicate; + } +} diff --git a/mcp/deployment/pom.xml b/mcp/deployment/pom.xml index 26f176c73..ce671af11 100644 --- a/mcp/deployment/pom.xml +++ b/mcp/deployment/pom.xml @@ -42,6 +42,30 @@ quarkus-junit5-internal test + + io.quarkiverse.langchain4j + quarkus-langchain4j-agentic + ${project.version} + test + + + io.quarkiverse.langchain4j + quarkus-langchain4j-openai-deployment + test + ${project.version} + + + io.quarkiverse.langchain4j + quarkus-langchain4j-openai-testing-internal + test + ${project.version} + + + org.wiremock + wiremock-standalone + ${wiremock.version} + compile + org.assertj assertj-core diff --git a/mcp/deployment/src/main/java/io/quarkiverse/langchain4j/mcp/deployment/McpProcessor.java b/mcp/deployment/src/main/java/io/quarkiverse/langchain4j/mcp/deployment/McpProcessor.java index dc6fcca52..646333d3e 100644 --- a/mcp/deployment/src/main/java/io/quarkiverse/langchain4j/mcp/deployment/McpProcessor.java +++ b/mcp/deployment/src/main/java/io/quarkiverse/langchain4j/mcp/deployment/McpProcessor.java @@ -26,7 +26,6 @@ import com.fasterxml.jackson.databind.ObjectMapper; import dev.langchain4j.mcp.client.McpClient; -import dev.langchain4j.service.tool.ToolProvider; import io.opentelemetry.api.trace.Tracer; import io.quarkiverse.langchain4j.deployment.DotNames; import io.quarkiverse.langchain4j.mcp.auth.McpClientAuthProvider; @@ -58,7 +57,6 @@ public class McpProcessor { private static final DotName MCP_CLIENT = DotName.createSimple(McpClient.class); private static final DotName MCP_CLIENT_NAME = DotName.createSimple(McpClientName.class); - private static final DotName TOOL_PROVIDER = DotName.createSimple(ToolProvider.class); private static final DotName TRACER = DotName.createSimple(Tracer.class); @SuppressWarnings({ "rawtypes", "unchecked" }) @@ -158,8 +156,7 @@ public void registerMcpClients(McpBuildTimeConfiguration mcpBuildTimeConfigurati // generate a tool provider if configured to do so if (mcpBuildTimeConfiguration.generateToolProvider().orElse(true)) { SyntheticBeanBuildItem.ExtendedBeanConfigurator configurator = SyntheticBeanBuildItem - .configure(TOOL_PROVIDER) - .addType(ClassType.create(TOOL_PROVIDER)) + .configure(DotNames.TOOL_PROVIDER) .setRuntimeInit() .defaultBean() .unremovable() diff --git a/mcp/deployment/src/test/java/io/quarkiverse/langchain4j/mcp/test/AgentMcpClientTest.java b/mcp/deployment/src/test/java/io/quarkiverse/langchain4j/mcp/test/AgentMcpClientTest.java new file mode 100644 index 000000000..e75c81f9a --- /dev/null +++ b/mcp/deployment/src/test/java/io/quarkiverse/langchain4j/mcp/test/AgentMcpClientTest.java @@ -0,0 +1,121 @@ +package io.quarkiverse.langchain4j.mcp.test; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.concurrent.CopyOnWriteArrayList; + +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; +import jakarta.inject.Singleton; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.asset.StringAsset; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.agentic.Agent; +import dev.langchain4j.agentic.declarative.SequenceAgent; +import dev.langchain4j.agentic.declarative.SubAgent; +import dev.langchain4j.model.chat.listener.ChatModelListener; +import dev.langchain4j.model.chat.listener.ChatModelRequestContext; +import dev.langchain4j.service.V; +import io.quarkiverse.langchain4j.mcp.runtime.McpToolBox; +import io.quarkiverse.langchain4j.openai.testing.internal.OpenAiBaseTest; +import io.quarkiverse.langchain4j.testing.internal.WiremockAware; +import io.quarkus.test.QuarkusUnitTest; + +/** + * Test MCP clients over an HTTP transport. + * This is a very rudimentary test that runs against a mock MCP server. The plan is + * to replace it with a more proper MCP server once we have an appropriate Java SDK ready for it. + */ +public class AgentMcpClientTest extends OpenAiBaseTest { + + @RegisterExtension + static QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(AbstractMockHttpMcpServer.class, MockHttpMcpServer.class, Sequence.class, + AgentWithMcpTools.class) + .addAsResource(new StringAsset(""" + quarkus.langchain4j.mcp.client1.transport-type=http + quarkus.langchain4j.mcp.client1.url=http://localhost:8081/mock-mcp/sse + quarkus.log.category."dev.langchain4j".level=DEBUG + quarkus.log.category."io.quarkiverse".level=DEBUG + quarkus.langchain4j.mcp.client1.tool-execution-timeout=1s + """), + "application.properties")) + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.api-key", "whatever") + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.base-url", + WiremockAware.wiremockUrlForConfig("/v1")); + + @Inject + Sequence sequence; + + public interface Sequence { + + @SequenceAgent(outputName = "toolsList", subAgents = { + @SubAgent(type = AgentWithMcpTools.class, outputName = "toolsList") + }) + String toolsList(@V("userMessage") String userMessage); + } + + public interface AgentWithMcpTools { + + @Agent + @McpToolBox + String toolsList(@V("userMessage") String userMessage); + } + + @Test + @ActivateRequestContext + public void agentHasTools() { + wiremock().register( + post(urlEqualTo("/v1/chat/completions")) + .willReturn( + aResponse() + .withHeader("Content-Type", "application/json") + .withBody(""" + { + "id": "chatcmpl-8GRu6o9Qf9JFAebDqpj76H5fl6Naz", + "object": "chat.completion", + "created": 1698931202, + "model": "gpt-3.5-turbo-0613", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "dummy" + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 159, + "completion_tokens": 13, + "total_tokens": 172 + } + } + \s"""))); + + var response = sequence.toolsList("test"); + assertEquals("dummy", response); + assertThat(ToolsInterceptor.TOOLS_NAMES).hasSize(3); + } + + @Singleton + public static class ToolsInterceptor implements ChatModelListener { + + public static final CopyOnWriteArrayList TOOLS_NAMES = new CopyOnWriteArrayList<>(); + + @Override + public void onRequest(ChatModelRequestContext requestContext) { + requestContext.chatRequest().toolSpecifications().forEach(ts -> TOOLS_NAMES.add(ts.name())); + } + } +} diff --git a/pom.xml b/pom.xml index 813bf6618..60decf515 100644 --- a/pom.xml +++ b/pom.xml @@ -39,6 +39,7 @@ ${quarkus.version} 1.4.0 1.4.0-beta10 + 1.4.1-beta10 1.4.0-beta10 1.4.0-beta10 2.2.0 @@ -68,6 +69,11 @@ pom import + + dev.langchain4j + langchain4j-agentic + ${langchain4j-agentic.version} + io.quarkus quarkus-extension-processor