diff --git a/core/src/main/java/com/google/adk/tools/mcp/McpSessionManager.java b/core/src/main/java/com/google/adk/tools/mcp/McpSessionManager.java index 4f6145d5..f430da65 100644 --- a/core/src/main/java/com/google/adk/tools/mcp/McpSessionManager.java +++ b/core/src/main/java/com/google/adk/tools/mcp/McpSessionManager.java @@ -37,29 +37,27 @@ // TODO(b/413489523): Implement this class. public class McpSessionManager { - private final Object connectionParams; // ServerParameters or SseServerParameters private static final Logger logger = LoggerFactory.getLogger(McpSessionManager.class); + private final McpClientTransport transport; public McpSessionManager(Object connectionParams) { - this.connectionParams = connectionParams; + this.transport = createTransport(connectionParams); + } + + public McpSessionManager(McpClientTransport transport) { + this.transport = transport; } public McpSyncClient createSession() { - return initializeSession(this.connectionParams); + return initializeSession(this.transport); } public static McpSyncClient initializeSession(Object connectionParams) { - McpClientTransport transport; - if (connectionParams instanceof ServerParameters serverParameters) { - transport = new StdioClientTransport(serverParameters); - } else if (connectionParams instanceof SseServerParameters sseServerParams) { - transport = - HttpClientSseClientTransport.builder(sseServerParams.url()).sseEndpoint("sse").build(); - } else { - throw new IllegalArgumentException( - "Connection parameters must be either ServerParameters or SseServerParameters, but got " - + connectionParams.getClass().getName()); - } + McpClientTransport transport = createTransport(connectionParams); + return initializeSession(transport); + } + + public static McpSyncClient initializeSession(McpClientTransport transport) { McpSyncClient client = McpClient.sync(transport) .requestTimeout(Duration.ofSeconds(10)) @@ -70,4 +68,16 @@ public static McpSyncClient initializeSession(Object connectionParams) { return client; } + + private static McpClientTransport createTransport(Object connectionParams) { + if (connectionParams instanceof ServerParameters serverParameters) { + return new StdioClientTransport(serverParameters); + } else if (connectionParams instanceof SseServerParameters sseServerParams) { + return HttpClientSseClientTransport.builder(sseServerParams.url()).sseEndpoint("sse").build(); + } else { + throw new IllegalArgumentException( + "Connection parameters must be either ServerParameters or SseServerParameters, but got " + + connectionParams.getClass().getName()); + } + } } diff --git a/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java b/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java index f512deb3..be158031 100644 --- a/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java +++ b/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java @@ -22,6 +22,7 @@ import com.google.adk.JsonBaseModel; import io.modelcontextprotocol.client.McpSyncClient; import io.modelcontextprotocol.client.transport.ServerParameters; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema.ListToolsResult; import java.util.List; import java.util.Objects; @@ -49,6 +50,29 @@ public class McpToolset implements AutoCloseable { private McpSyncClient mcpSession; private final ObjectMapper objectMapper; + /** + * Initializes the McpToolset with a custom McpClientTransport. + * + * @param transport The custom MCP client transport to use. + * @param objectMapper An ObjectMapper instance for parsing schemas. + */ + public McpToolset(McpClientTransport transport, ObjectMapper objectMapper) { + Objects.requireNonNull(transport, "transport cannot be null"); + Objects.requireNonNull(objectMapper, "objectMapper cannot be null"); + this.objectMapper = objectMapper; + this.mcpSessionManager = new McpSessionManager(transport); // Use the new constructor + } + + /** + * Initializes the McpToolset with a custom McpClientTransport, using the default ADK + * ObjectMapper. + * + * @param transport The custom MCP client transport to use. + */ + public McpToolset(McpClientTransport transport) { + this(transport, JsonBaseModel.getMapper()); + } + /** * Initializes the McpToolset with SSE server parameters. * @@ -56,8 +80,8 @@ public class McpToolset implements AutoCloseable { * @param objectMapper An ObjectMapper instance for parsing schemas. */ public McpToolset(SseServerParameters connectionParams, ObjectMapper objectMapper) { - Objects.requireNonNull(connectionParams); - Objects.requireNonNull(objectMapper); + Objects.requireNonNull(connectionParams, "connectionParams cannot be null"); + Objects.requireNonNull(objectMapper, "objectMapper cannot be null"); this.objectMapper = objectMapper; this.mcpSessionManager = new McpSessionManager(connectionParams); } @@ -69,7 +93,7 @@ public McpToolset(SseServerParameters connectionParams, ObjectMapper objectMappe * @param objectMapper An ObjectMapper instance for parsing schemas. */ public McpToolset(ServerParameters connectionParams, ObjectMapper objectMapper) { - Objects.requireNonNull(connectionParams); + Objects.requireNonNull(connectionParams, "connectionParams cannot be null"); this.objectMapper = objectMapper; this.mcpSessionManager = new McpSessionManager(connectionParams); }