Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,19 @@
import org.jboss.jandex.DotName;
import org.jboss.jandex.IndexView;
import org.jboss.jandex.MethodInfo;
import org.jboss.jandex.ParameterizedType;
import org.jboss.jandex.Type;

import dev.langchain4j.service.IllegalConfigurationException;
import io.quarkiverse.langchain4j.ModelName;
import io.quarkiverse.langchain4j.agentic.runtime.AgenticRecorder;
import io.quarkiverse.langchain4j.agentic.runtime.AiAgentCreateInfo;
import io.quarkiverse.langchain4j.deployment.AnnotationsImpliesAiServiceBuildItem;
import io.quarkiverse.langchain4j.deployment.DotNames;
import io.quarkiverse.langchain4j.deployment.FallbackToDummyUserMessageBuildItem;
import io.quarkiverse.langchain4j.deployment.PreventToolValidationErrorBuildItem;
import io.quarkiverse.langchain4j.deployment.RequestChatModelBeanBuildItem;
import io.quarkiverse.langchain4j.deployment.SkipOutputFormatInstructionsBuildItem;
import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.deployment.annotations.BuildProducer;
Expand Down Expand Up @@ -54,7 +61,10 @@ void detectAgents(CombinedIndexBuildItem indexBuildItem, BuildProducer<DetectedA
.filter(m -> Modifier.isStatic(m.flags()) && m.hasAnnotation(
AgenticLangChain4jDotNames.CHAT_MODEL_SUPPLIER))
.findFirst();
producer.produce(new DetectedAiAgentBuildItem(classInfo, methods, chatModelSupplier.orElse(null)));

List<MethodInfo> mcpToolBoxMethods = methods.stream().filter(mi -> mi.hasAnnotation(DotNames.MCP_TOOLBOX)).toList();
producer.produce(
new DetectedAiAgentBuildItem(classInfo, methods, chatModelSupplier.orElse(null), mcpToolBoxMethods));
});
}

Expand All @@ -73,6 +83,58 @@ public boolean test(ClassInfo classInfo) {
});
}

@BuildStep
AnnotationsImpliesAiServiceBuildItem implyAiService() {
return new AnnotationsImpliesAiServiceBuildItem(AgenticLangChain4jDotNames.ALL_AGENT_ANNOTATIONS);
}

@BuildStep
SkipOutputFormatInstructionsBuildItem skipOutputInstructions() {
return new SkipOutputFormatInstructionsBuildItem(new Predicate<>() {
@Override
public boolean test(MethodInfo methodInfo) {
for (DotName dotName : AgenticLangChain4jDotNames.ALL_AGENT_ANNOTATIONS) {
if (methodInfo.hasAnnotation(dotName)) {
return true;
}
}
return false;
}
});
}

@BuildStep
FallbackToDummyUserMessageBuildItem fallbackToDummyUserMessage() {
return new FallbackToDummyUserMessageBuildItem(new Predicate<>() {
@Override
public boolean test(MethodInfo methodInfo) {
for (DotName dotName : AgenticLangChain4jDotNames.ALL_AGENT_ANNOTATIONS) {
if (methodInfo.hasAnnotation(dotName)) {
return true;
}
}
return false;
}
});
}

@BuildStep
@Record(ExecutionTime.STATIC_INIT)
void mcpToolBoxSupport(List<DetectedAiAgentBuildItem> detectedAgentBuildItems, AgenticRecorder recorder) {
Set<String> agentsWithMcpToolBox = new HashSet<>();
for (DetectedAiAgentBuildItem bi : detectedAgentBuildItems) {
if (!bi.getMcpToolBoxMethods().isEmpty()) {
if ((bi.getMcpToolBoxMethods().size() != 1) && (bi.getAgenticMethods().size() > 1)) {
throw new IllegalConfigurationException(
"Currently, @McpToolBox can only be used on an Agent if the agent has a single method. This restriction will be lifted in the future. Offending class is '"
+ bi.getIface().name() + "'");
}
agentsWithMcpToolBox.add(bi.getIface().name().toString());
}
}
recorder.setAgentsWithMcpToolBox(agentsWithMcpToolBox);
}

@BuildStep
@Record(ExecutionTime.RUNTIME_INIT)
void cdiSupport(List<DetectedAiAgentBuildItem> detectedAiAgentBuildItems, AgenticRecorder recorder,
Expand All @@ -87,6 +149,7 @@ void cdiSupport(List<DetectedAiAgentBuildItem> detectedAiAgentBuildItems, Agenti
AiAgentCreateInfo.ChatModelInfo chatModelInfo = detectedAiAgentBuildItem.getChatModelSupplier() != null
? new AiAgentCreateInfo.ChatModelInfo.FromAnnotation()
: new AiAgentCreateInfo.ChatModelInfo.FromBeanWithName(chatModelName);

SyntheticBeanBuildItem.ExtendedBeanConfigurator beanConfigurator = SyntheticBeanBuildItem
.configure(detectedAiAgentBuildItem.getIface().name())
.forceApplicationClass()
Expand All @@ -105,6 +168,8 @@ void cdiSupport(List<DetectedAiAgentBuildItem> detectedAiAgentBuildItems, Agenti
beanConfigurator.addInjectionPoint(
ClassType.create(DotName.createSimple(dev.langchain4j.model.chat.ChatModel.class)), qualifier);
}
beanConfigurator.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
new Type[] { ClassType.create(DotNames.TOOL_PROVIDER) }, null));
syntheticBeanProducer.produce(beanConfigurator.done());
}
requestedChatModelNames.forEach(name -> requestChatModelBeanProducer.produce(new RequestChatModelBeanBuildItem(name)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ public final class DetectedAiAgentBuildItem extends MultiBuildItem {
private final ClassInfo iface;
private final List<MethodInfo> agenticMethods;
private final MethodInfo chatModelSupplier;
private final List<MethodInfo> mcpToolBoxMethods;

public DetectedAiAgentBuildItem(ClassInfo iface, List<MethodInfo> agenticMethods,
MethodInfo chatModelSupplier) {
MethodInfo chatModelSupplier, List<MethodInfo> mcpToolBoxMethods) {
this.iface = iface;
this.agenticMethods = agenticMethods;
this.chatModelSupplier = chatModelSupplier;
this.mcpToolBoxMethods = mcpToolBoxMethods;
}

public ClassInfo getIface() {
Expand All @@ -34,4 +36,8 @@ public List<MethodInfo> getAgenticMethods() {
public MethodInfo getChatModelSupplier() {
return chatModelSupplier;
}

public List<MethodInfo> getMcpToolBoxMethods() {
return mcpToolBoxMethods;
}
}
1 change: 1 addition & 0 deletions agentic/runtime/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-agentic</artifactId>
<version>${langchain4j-agentic.version}</version>
</dependency>
</dependencies>

Expand Down
Original file line number Diff line number Diff line change
@@ -1,39 +1,56 @@
package io.quarkiverse.langchain4j.agentic.runtime;

import java.util.Collections;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;

import jakarta.enterprise.inject.Instance;
import jakarta.enterprise.util.TypeLiteral;

import org.jboss.logging.Logger;

import dev.langchain4j.agentic.AgenticServices;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.service.tool.ToolProvider;
import io.quarkiverse.langchain4j.ModelName;
import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
import io.quarkus.arc.SyntheticCreationalContext;
import io.quarkus.runtime.annotations.Recorder;
import io.quarkus.runtime.annotations.RuntimeInit;
import io.quarkus.runtime.annotations.StaticInit;

@Recorder
public class AgenticRecorder {

private static final Logger log = Logger.getLogger(AgenticRecorder.class);
private static Set<String> agentsWithMcpToolBox = Collections.emptySet();

@StaticInit
public void setAgentsWithMcpToolBox(Set<String> agentsWithMcpToolBox) {
AgenticRecorder.agentsWithMcpToolBox = Collections.unmodifiableSet(agentsWithMcpToolBox);
}

@RuntimeInit
public Function<SyntheticCreationalContext<Object>, Object> createAiAgent(AiAgentCreateInfo info) {
return new Function<>() {
@Override
public Object apply(SyntheticCreationalContext<Object> context) {
public Object apply(SyntheticCreationalContext<Object> cdiContext) {
ChatModel chatModel;
if (info.chatModelInfo() instanceof AiAgentCreateInfo.ChatModelInfo.FromAnnotation) {
chatModel = null;
} else if (info.chatModelInfo() instanceof AiAgentCreateInfo.ChatModelInfo.FromBeanWithName b) {
if (NamedConfigUtil.isDefault(b.name())) {
chatModel = context.getInjectedReference(ChatModel.class);
chatModel = cdiContext.getInjectedReference(ChatModel.class);
} else {
chatModel = context.getInjectedReference(ChatModel.class, ModelName.Literal.of(b.name()));
chatModel = cdiContext.getInjectedReference(ChatModel.class, ModelName.Literal.of(b.name()));
}
} else {
throw new IllegalStateException("Unknown type: " + info.chatModelInfo().getClass());
}

return AgenticServices.createAgenticSystem(loadClassSafe(info), chatModel);
return AgenticServices.createAgenticSystem(loadClassSafe(info), chatModel,
new QuarkusAgenticContextConsumer(cdiContext, info));
}
};
}
Expand All @@ -46,4 +63,33 @@ private static Class<?> loadClassSafe(AiAgentCreateInfo info) {
throw new RuntimeException(e);
}
}

private record QuarkusAgenticContextConsumer(SyntheticCreationalContext<Object> cdiContext,
AiAgentCreateInfo aiAgentCreateInfo)
implements
Consumer<AgenticServices.DeclarativeAgentCreationContext> {

private static final TypeLiteral<Instance<ToolProvider>> TOOL_PROVIDER_TYPE_LITERAL = new TypeLiteral<>() {
};

@Override
public void accept(AgenticServices.DeclarativeAgentCreationContext agenticContext) {
if (AgenticRecorder.agentsWithMcpToolBox.contains(agenticContext.agentServiceClass().getName())) {
Instance<ToolProvider> injectedReference = cdiContext.getInjectedReference(TOOL_PROVIDER_TYPE_LITERAL);
if (injectedReference.isResolvable()) {
agenticContext.agentBuilder().toolProvider(injectedReference.get());
}
}
}
}

private static final class NoOpConsumer implements Consumer {

private static final NoOpConsumer INSTANCE = new NoOpConsumer();

@Override
public void accept(Object t) {

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1147,6 +1147,13 @@ public void markIgnoredAnnotations(BuildProducer<MethodParameterIgnoredAnnotatio
}));
}

@BuildStep
AnnotationsImpliesAiServiceBuildItem implyAiService() {
return new AnnotationsImpliesAiServiceBuildItem(
List.of(LangChain4jDotNames.SYSTEM_MESSAGE, LangChain4jDotNames.USER_MESSAGE,
LangChain4jDotNames.MODERATE));
}

@BuildStep
@Record(ExecutionTime.STATIC_INIT)
public void handleAiServices(
Expand All @@ -1167,7 +1174,10 @@ public void handleAiServices(
Optional<MetricsCapabilityBuildItem> metricsCapability,
Capabilities capabilities,
List<ToolMethodBuildItem> tools,
List<ToolQualifierProvider.BuildItem> toolQualifierProviderItems) {
List<ToolQualifierProvider.BuildItem> toolQualifierProviderItems,
List<AnnotationsImpliesAiServiceBuildItem> annotationsImpliesAiServiceItems,
List<SkipOutputFormatInstructionsBuildItem> skipOutputFormatInstructionsItems,
List<FallbackToDummyUserMessageBuildItem> fallbackToDummyUserMessageItems) {

IndexView index = indexBuildItem.getIndex();

Expand Down Expand Up @@ -1211,7 +1221,8 @@ public void handleAiServices(

Set<String> detectedForCreate = new HashSet<>(nameToUsed.keySet());
addCreatedAware(index, detectedForCreate);
addIfacesWithMessageAnns(index, detectedForCreate);
addIfacesWithMessageAnns(index, annotationsImpliesAiServiceItems.stream()
.flatMap(bi -> bi.getAnnotationNames().stream()).collect(Collectors.toList()), detectedForCreate);
Set<String> registeredAiServiceClassNames = declarativeAiServiceItems.stream()
.map(bi -> bi.getServiceClassInfo().name().toString()).collect(
Collectors.toUnmodifiableSet());
Expand Down Expand Up @@ -1342,7 +1353,13 @@ public void handleAiServices(
config.responseSchema(),
allowedPredicates,
ignoredPredicates,
tools, toolQualifierProviderItems);
tools, toolQualifierProviderItems,
skipOutputFormatInstructionsItems.stream().map(
SkipOutputFormatInstructionsBuildItem::getPredicate)
.reduce(mi -> false, Predicate::or),
fallbackToDummyUserMessageItems.stream().map(
FallbackToDummyUserMessageBuildItem::getPredicate)
.reduce(mi -> false, Predicate::or));
if (!methodCreateInfo.getToolClassInfo().isEmpty()) {
if ((matchingBI != null)
&& matchingBI.getChatMemoryProviderSupplierClassDotName() == null) {
Expand Down Expand Up @@ -1482,9 +1499,7 @@ private String createMethodId(MethodInfo methodInfo) {
+ Arrays.toString(methodInfo.parameters().stream().map(mp -> mp.type().name().toString()).toArray()) + ')';
}

private void addIfacesWithMessageAnns(IndexView index, Set<String> detectedForCreate) {
List<DotName> annotations = List.of(LangChain4jDotNames.SYSTEM_MESSAGE, LangChain4jDotNames.USER_MESSAGE,
LangChain4jDotNames.MODERATE);
private void addIfacesWithMessageAnns(IndexView index, List<DotName> annotations, Set<String> detectedForCreate) {
for (DotName annotation : annotations) {
Collection<AnnotationInstance> instances = index.getAnnotations(annotation);
for (AnnotationInstance instance : instances) {
Expand Down Expand Up @@ -1522,7 +1537,9 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(
Collection<Predicate<AnnotationInstance>> allowedPredicates,
Collection<Predicate<AnnotationInstance>> ignoredPredicates,
List<ToolMethodBuildItem> tools,
List<ToolQualifierProvider.BuildItem> toolQualifierProviders) {
List<ToolQualifierProvider.BuildItem> toolQualifierProviders,
Predicate<MethodInfo> skipOutputFormatInstructionsPredicate,
Predicate<MethodInfo> fallbackToDummyUserMessagePredicate) {
validateReturnType(method);

boolean requiresModeration = method.hasAnnotation(LangChain4jDotNames.MODERATE);
Expand All @@ -1532,15 +1549,17 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(

// TODO give user ability to provide custom OutputParser
String outputFormatInstructions = "";
Optional<JsonSchema> structuredOutputSchema = Optional.empty();
if (!returnType.equals(Multi.class)) {
outputFormatInstructions = SERVICE_OUTPUT_PARSER.outputFormatInstructions(returnType);
if (!skipOutputFormatInstructionsPredicate.test(method)) {
Optional<JsonSchema> structuredOutputSchema = Optional.empty();
if (!returnType.equals(Multi.class)) {
outputFormatInstructions = SERVICE_OUTPUT_PARSER.outputFormatInstructions(returnType);
}
}

List<TemplateParameterInfo> templateParams = gatherTemplateParamInfo(params, allowedPredicates, ignoredPredicates);
Optional<AiServiceMethodCreateInfo.TemplateInfo> systemMessageInfo = gatherSystemMessageInfo(method, templateParams);
AiServiceMethodCreateInfo.UserMessageInfo userMessageInfo = gatherUserMessageInfo(method, templateParams,
systemMessageInfo);
systemMessageInfo, fallbackToDummyUserMessagePredicate);

AiServiceMethodCreateInfo.ResponseSchemaInfo responseSchemaInfo = ResponseSchemaInfo.of(generateResponseSchema,
systemMessageInfo,
Expand Down Expand Up @@ -1797,7 +1816,8 @@ private Optional<Integer> gatherOverrideChatModelParameterPosition(MethodInfo me

private AiServiceMethodCreateInfo.UserMessageInfo gatherUserMessageInfo(MethodInfo method,
List<TemplateParameterInfo> templateParams,
Optional<AiServiceMethodCreateInfo.TemplateInfo> systemMessageInfo) {
Optional<AiServiceMethodCreateInfo.TemplateInfo> systemMessageInfo,
Predicate<MethodInfo> fallbackToDummyUserMesage) {

Optional<Integer> userNameParamPosition = method.annotations(LangChain4jDotNames.USER_NAME).stream().filter(
IS_METHOD_PARAMETER_ANNOTATION).map(METHOD_PARAMETER_POSITION_FUNCTION).findFirst();
Expand Down Expand Up @@ -1874,6 +1894,14 @@ private AiServiceMethodCreateInfo.UserMessageInfo gatherUserMessageInfo(MethodIn
return AiServiceMethodCreateInfo.UserMessageInfo.fromMethodParam(0, userNameParamPosition,
imageParamPosition, audioParamPosition, pdfParamPosition);
}

if (fallbackToDummyUserMesage.test(method)) {
return AiServiceMethodCreateInfo.UserMessageInfo.fromTemplate(
AiServiceMethodCreateInfo.TemplateInfo.fromText("", Map.of()), Optional.empty(),
Optional.empty(),
Optional.empty(), Optional.empty());
}

throw illegalConfigurationForMethod(
"For methods with multiple parameters, each parameter must be annotated with @V (or match an template parameter by name), @UserMessage, @UserName or @MemoryId",
method);
Expand Down Expand Up @@ -2128,8 +2156,7 @@ private List<String> gatherMethodToolClassNames(MethodInfo method) {
}

private List<String> gatherMethodMcpClientNames(MethodInfo method) {
// Using the class name to keep the McpToolBox annotation in the mcp module
AnnotationInstance mcpToolBoxInstance = method.declaredAnnotation("io.quarkiverse.langchain4j.mcp.runtime.McpToolBox");
AnnotationInstance mcpToolBoxInstance = method.declaredAnnotation(DotNames.MCP_TOOLBOX);
if (mcpToolBoxInstance == null) {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package io.quarkiverse.langchain4j.deployment;

import java.util.List;

import org.jboss.jandex.DotName;

import io.quarkus.builder.item.MultiBuildItem;

/**
* A build item that can be used in order to make an interface an AiService automatically
*/
public final class AnnotationsImpliesAiServiceBuildItem extends MultiBuildItem {

private final List<DotName> annotationNames;

public AnnotationsImpliesAiServiceBuildItem(List<DotName> annotationNames) {
this.annotationNames = annotationNames;
}

public List<DotName> getAnnotationNames() {
return annotationNames;
}
}
Loading
Loading