Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 24 additions & 14 deletions core/src/main/java/com/google/adk/tools/mcp/McpSessionManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,29 +37,27 @@
// TODO(b/413489523): Implement this class.
public class McpSessionManager {

private final Object connectionParams; // ServerParameters or SseServerParameters
private static final Logger logger = LoggerFactory.getLogger(McpSessionManager.class);
private final McpClientTransport transport;

public McpSessionManager(Object connectionParams) {
this.connectionParams = connectionParams;
this.transport = createTransport(connectionParams);
}

public McpSessionManager(McpClientTransport transport) {
this.transport = transport;
}

public McpSyncClient createSession() {
return initializeSession(this.connectionParams);
return initializeSession(this.transport);
}

public static McpSyncClient initializeSession(Object connectionParams) {
McpClientTransport transport;
if (connectionParams instanceof ServerParameters serverParameters) {
transport = new StdioClientTransport(serverParameters);
} else if (connectionParams instanceof SseServerParameters sseServerParams) {
transport =
HttpClientSseClientTransport.builder(sseServerParams.url()).sseEndpoint("sse").build();
} else {
throw new IllegalArgumentException(
"Connection parameters must be either ServerParameters or SseServerParameters, but got "
+ connectionParams.getClass().getName());
}
McpClientTransport transport = createTransport(connectionParams);
return initializeSession(transport);
}

public static McpSyncClient initializeSession(McpClientTransport transport) {
McpSyncClient client =
McpClient.sync(transport)
.requestTimeout(Duration.ofSeconds(10))
Expand All @@ -70,4 +68,16 @@ public static McpSyncClient initializeSession(Object connectionParams) {

return client;
}

private static McpClientTransport createTransport(Object connectionParams) {
if (connectionParams instanceof ServerParameters serverParameters) {
return new StdioClientTransport(serverParameters);
} else if (connectionParams instanceof SseServerParameters sseServerParams) {
return HttpClientSseClientTransport.builder(sseServerParams.url()).sseEndpoint("sse").build();
} else {
throw new IllegalArgumentException(
"Connection parameters must be either ServerParameters or SseServerParameters, but got "
+ connectionParams.getClass().getName());
}
}
}
30 changes: 27 additions & 3 deletions core/src/main/java/com/google/adk/tools/mcp/McpToolset.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.google.adk.JsonBaseModel;
import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.client.transport.ServerParameters;
import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpSchema.ListToolsResult;
import java.util.List;
import java.util.Objects;
Expand Down Expand Up @@ -49,15 +50,38 @@ public class McpToolset implements AutoCloseable {
private McpSyncClient mcpSession;
private final ObjectMapper objectMapper;

/**
* Initializes the McpToolset with a custom McpClientTransport.
*
* @param transport The custom MCP client transport to use.
* @param objectMapper An ObjectMapper instance for parsing schemas.
*/
public McpToolset(McpClientTransport transport, ObjectMapper objectMapper) {
Objects.requireNonNull(transport, "transport cannot be null");
Objects.requireNonNull(objectMapper, "objectMapper cannot be null");
this.objectMapper = objectMapper;
this.mcpSessionManager = new McpSessionManager(transport); // Use the new constructor
}

/**
* Initializes the McpToolset with a custom McpClientTransport, using the default ADK
* ObjectMapper.
*
* @param transport The custom MCP client transport to use.
*/
public McpToolset(McpClientTransport transport) {
this(transport, JsonBaseModel.getMapper());
}

/**
* Initializes the McpToolset with SSE server parameters.
*
* @param connectionParams The SSE connection parameters to the MCP server.
* @param objectMapper An ObjectMapper instance for parsing schemas.
*/
public McpToolset(SseServerParameters connectionParams, ObjectMapper objectMapper) {
Objects.requireNonNull(connectionParams);
Objects.requireNonNull(objectMapper);
Objects.requireNonNull(connectionParams, "connectionParams cannot be null");
Objects.requireNonNull(objectMapper, "objectMapper cannot be null");
this.objectMapper = objectMapper;
this.mcpSessionManager = new McpSessionManager(connectionParams);
}
Expand All @@ -69,7 +93,7 @@ public McpToolset(SseServerParameters connectionParams, ObjectMapper objectMappe
* @param objectMapper An ObjectMapper instance for parsing schemas.
*/
public McpToolset(ServerParameters connectionParams, ObjectMapper objectMapper) {
Objects.requireNonNull(connectionParams);
Objects.requireNonNull(connectionParams, "connectionParams cannot be null");
this.objectMapper = objectMapper;
this.mcpSessionManager = new McpSessionManager(connectionParams);
}
Expand Down