Add Bedrock/Anthropic prompt caching #2560
messages.py

```diff
@@ -520,8 +520,20 @@ def format(self) -> str:
     __repr__ = _utils.dataclasses_no_defaults_repr


+@dataclass
+class CachePoint:
+    """A cache point marker for prompt caching.
+
+    Can be inserted into UserPromptPart.content to mark cache boundaries.
+    Models that don't support caching will filter these out.
+    """
+
+    kind: Literal['cache-point'] = 'cache-point'
+    """Type identifier, this is available on all parts as a discriminator."""
+
+
 MultiModalContent = ImageUrl | AudioUrl | DocumentUrl | VideoUrl | BinaryContent
-UserContent: TypeAlias = str | MultiModalContent
+UserContent: TypeAlias = str | MultiModalContent | CachePoint


 @dataclass(repr=False)
```
```diff
@@ -637,6 +649,8 @@ def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_me
                 if settings.include_content and settings.include_binary_content:
                     converted_part['content'] = base64.b64encode(part.data).decode()
                 parts.append(converted_part)
+            elif isinstance(part, CachePoint):
+                parts.append(_otel_messages.CachePointPart(type=part.kind))
```
Review comment: See above, I think we should skip it instead -- unless you've confirmed that the OTel spec does have cache points.
```diff
             else:
                 parts.append({'type': part.kind})  # pragma: no cover
         return parts
```
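For context, a minimal sketch of how the new marker is meant to be used from the public API; the model identifier is an assumption for illustration, not taken from this diff:

```python
from pydantic_ai import Agent
from pydantic_ai.messages import CachePoint

agent = Agent('bedrock:anthropic.claude-3-5-sonnet-20240620-v1:0')  # assumed model id

result = agent.run_sync(
    [
        'Long, stable context that is worth caching...',
        CachePoint(),  # everything above this marker becomes the cacheable prefix
        'Short question that changes per request.',
    ]
)
```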
models/anthropic.py
```diff
@@ -20,6 +20,7 @@
     BinaryContent,
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
+    CachePoint,
     DocumentUrl,
     ImageUrl,
     ModelMessage,
```
```diff
@@ -47,6 +48,7 @@
 from anthropic.types.beta import (
     BetaBase64PDFBlockParam,
     BetaBase64PDFSourceParam,
+    BetaCacheControlEphemeralParam,
     BetaCitationsDelta,
     BetaCodeExecutionTool20250522Param,
     BetaCodeExecutionToolResultBlock,
```
```diff
@@ -107,6 +109,16 @@
 See [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/models) for a full list.
 """

+CacheableContentBlockParam = (
+    BetaTextBlockParam
+    | BetaToolUseBlockParam
+    | BetaServerToolUseBlockParam
+    | BetaImageBlockParam
+    | BetaToolResultBlockParam
+)
+"""Content block parameter types that support cache_control."""
```
Review comment: Can you please link to the doc this came from?
```diff
+CACHEABLE_CONTENT_BLOCK_PARAM_TYPES = {'text', 'tool_use', 'server_tool_use', 'image', 'tool_result'}


 class AnthropicModelSettings(ModelSettings, total=False):
     """Settings used for an Anthropic model request."""
```
```diff
@@ -382,6 +394,19 @@ def _get_builtin_tools(
         )
         return tools, extra_headers

+    @staticmethod
+    def _add_cache_control_to_last_param(params: list[BetaContentBlockParam]) -> None:
+        if not params:
+            raise UserError(
+                'CachePoint cannot be the first content in a user message - there must be previous content to attach the CachePoint to.'
```
Review comment: Copying in context from https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#what-can-be-cached: […] Annoying as it will be to implement, I think we can support […]
```diff
+            )
+
+        if params[-1]['type'] not in CACHEABLE_CONTENT_BLOCK_PARAM_TYPES:
+            raise UserError(f'Cache control not supported for param type: {params[-1]["type"]}')
+
+        cacheable_param = cast(CacheableContentBlockParam, params[-1])
+        cacheable_param['cache_control'] = BetaCacheControlEphemeralParam(type='ephemeral')
+
     async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]:  # noqa: C901
         """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
         system_prompt_parts: list[str] = []
```
```diff
@@ -394,7 +419,10 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]:
                     system_prompt_parts.append(request_part.content)
                 elif isinstance(request_part, UserPromptPart):
                     async for content in self._map_user_prompt(request_part):
-                        user_content_params.append(content)
+                        if isinstance(content, CachePoint):
+                            self._add_cache_control_to_last_param(user_content_params)
+                        else:
+                            user_content_params.append(content)
                 elif isinstance(request_part, ToolReturnPart):
                     tool_result_block_param = BetaToolResultBlockParam(
                         tool_use_id=_guard_tool_call_id(t=request_part),
```
```diff
@@ -483,7 +511,7 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]:
     @staticmethod
     async def _map_user_prompt(
         part: UserPromptPart,
-    ) -> AsyncGenerator[BetaContentBlockParam]:
+    ) -> AsyncGenerator[BetaContentBlockParam | CachePoint]:
         if isinstance(part.content, str):
             if part.content:  # Only yield non-empty text
                 yield BetaTextBlockParam(text=part.content, type='text')
```
```diff
@@ -524,6 +552,8 @@ async def _map_user_prompt(
                     )
                 else:  # pragma: no cover
                     raise RuntimeError(f'Unsupported media type: {item.media_type}')
+            elif isinstance(item, CachePoint):
+                yield item
             else:
                 raise RuntimeError(f'Unsupported content type: {type(item)}')  # pragma: no cover
```
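To make the control flow concrete: when `_map_message` sees a yielded `CachePoint`, it calls `_add_cache_control_to_last_param` on the blocks collected so far. A dict-based sketch of that attachment, with plain dicts standing in for the `Beta*BlockParam` TypedDicts:

```python
CACHEABLE_TYPES = {'text', 'tool_use', 'server_tool_use', 'image', 'tool_result'}

def add_cache_control(params: list[dict]) -> None:
    """Attach ephemeral cache_control to the last content block."""
    if not params:
        raise ValueError('CachePoint cannot be the first content in a user message')
    if params[-1]['type'] not in CACHEABLE_TYPES:
        raise ValueError(f"Cache control not supported for param type: {params[-1]['type']}")
    params[-1]['cache_control'] = {'type': 'ephemeral'}

params = [{'type': 'text', 'text': 'long stable context...'}]
add_cache_control(params)
assert params[-1] == {
    'type': 'text',
    'text': 'long stable context...',
    'cache_control': {'type': 'ephemeral'},
}
```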
models/bedrock.py
```diff
@@ -21,6 +21,7 @@
     BinaryContent,
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
+    CachePoint,
     DocumentUrl,
     ImageUrl,
     ModelMessage,
```
```diff
@@ -296,9 +297,14 @@ async def _process_response(self, response: ConverseResponseTypeDef) -> ModelResponse:
                     tool_call_id=tool_use['toolUseId'],
                 ),
             )
+        cache_read_tokens = response['usage'].get('cacheReadInputTokens', 0)
+        cache_write_tokens = response['usage'].get('cacheWriteInputTokens', 0)
+        input_tokens = response['usage']['inputTokens'] + cache_read_tokens + cache_write_tokens
         u = usage.RequestUsage(
-            input_tokens=response['usage']['inputTokens'],
+            input_tokens=input_tokens,
            output_tokens=response['usage']['outputTokens'],
+            cache_read_tokens=cache_read_tokens,
+            cache_write_tokens=cache_write_tokens,
        )
        response_id = response.get('ResponseMetadata', {}).get('RequestId', None)
        return ModelResponse(
```
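A quick worked example of the accounting change, with made-up numbers and assuming, as the diff does, that the Converse usage block reports cached reads/writes separately from `inputTokens`:

```python
usage_block = {
    'inputTokens': 100,            # non-cached input tokens, per the diff's assumption
    'outputTokens': 42,
    'cacheReadInputTokens': 1200,  # tokens served from the prompt cache
    'cacheWriteInputTokens': 0,
}

cache_read = usage_block.get('cacheReadInputTokens', 0)
cache_write = usage_block.get('cacheWriteInputTokens', 0)
input_tokens = usage_block['inputTokens'] + cache_read + cache_write
assert input_tokens == 1300  # totals now include cached tokens
```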
```diff
@@ -346,7 +352,12 @@ async def _messages_create(
             'inferenceConfig': inference_config,
         }

-        tool_config = self._map_tool_config(model_request_parameters)
+        tool_config = self._map_tool_config(
+            model_request_parameters,
+            should_add_cache_point=(
```
Review comment: This is now always adding a cache point, independent of […]
```diff
+                not system_prompt and BedrockModelProfile.from_profile(self.profile).bedrock_supports_prompt_caching
+            ),
+        )
         if tool_config:
             params['toolConfig'] = tool_config
```
```diff
@@ -395,11 +406,16 @@ def _map_inference_config(

         return inference_config

-    def _map_tool_config(self, model_request_parameters: ModelRequestParameters) -> ToolConfigurationTypeDef | None:
+    def _map_tool_config(
+        self, model_request_parameters: ModelRequestParameters, should_add_cache_point: bool = False
+    ) -> ToolConfigurationTypeDef | None:
         tools = self._get_tools(model_request_parameters)
         if not tools:
             return None

+        if should_add_cache_point:
+            tools[-1]['cachePoint'] = {'type': 'default'}
```
Review comment: This is causing […] It sounds like the […]
```diff
+
         tool_choice: ToolChoiceTypeDef
         if not model_request_parameters.allow_text_output:
             tool_choice = {'any': {}}
```
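To illustrate the review thread's concern: an unverified sketch contrasting the shape `tools[-1]['cachePoint'] = ...` produces with the shape the Converse API appears to expect, where the cache point is its own entry in the tools list. The tool name is hypothetical:

```python
# What the code above produces: a cachePoint key merged into the last toolSpec entry.
produced = [
    {'toolSpec': {'name': 'my_tool', 'inputSchema': {'json': {}}},
     'cachePoint': {'type': 'default'}},
]

# What the review comment suggests Converse expects: cachePoint as a
# standalone member of the tools list, after the tool definitions.
expected = [
    {'toolSpec': {'name': 'my_tool', 'inputSchema': {'json': {}}}},
    {'cachePoint': {'type': 'default'}},
]
```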
```diff
@@ -429,7 +445,12 @@ async def _map_messages(  # noqa: C901
                 if isinstance(part, SystemPromptPart) and part.content:
                     system_prompt.append({'text': part.content})
                 elif isinstance(part, UserPromptPart):
-                    bedrock_messages.extend(await self._map_user_prompt(part, document_count))
+                    has_leading_cache_point, user_messages = await self._map_user_prompt(
+                        part, document_count, profile.bedrock_supports_prompt_caching
+                    )
+                    if has_leading_cache_point:
+                        system_prompt.append({'cachePoint': {'type': 'default'}})
```
Review comment: We should only add it to the system prompt if this is the first […]
```diff
+                    bedrock_messages.extend(user_messages)
                 elif isinstance(part, ToolReturnPart):
                     assert part.tool_call_id is not None
                     bedrock_messages.append(
```
```diff
@@ -522,13 +543,22 @@ async def _map_messages(  # noqa: C901

         return system_prompt, processed_messages

-    @staticmethod
-    async def _map_user_prompt(part: UserPromptPart, document_count: Iterator[int]) -> list[MessageUnionTypeDef]:
+    async def _map_user_prompt(  # noqa: C901
+        self, part: UserPromptPart, document_count: Iterator[int], supports_caching: bool
+    ) -> tuple[bool, list[MessageUnionTypeDef]]:
         content: list[ContentBlockUnionTypeDef] = []
+        has_leading_cache_point = False

         if isinstance(part.content, str):
             content.append({'text': part.content})
         else:
-            for item in part.content:
+            if part.content and isinstance(part.content[0], CachePoint):
```
Review comment: You can move this to the […]
```diff
+                has_leading_cache_point = True
```
Review comment: Are we sure that bedrock needs this special behavior of moving this to the system prompt? It doesn't support a […]
```diff
+                items_to_process = part.content[1:]
+            else:
+                items_to_process = part.content
+
+            for item in items_to_process:
                 if isinstance(item, str):
                     content.append({'text': item})
                 elif isinstance(item, BinaryContent):
```
```diff
@@ -578,11 +608,15 @@ async def _map_user_prompt(part: UserPromptPart, document_count: Iterator[int])
                     ), f'Unsupported video format: {format}'
                     video: VideoBlockTypeDef = {'format': format, 'source': {'bytes': downloaded_item['data']}}
                     content.append({'video': video})
+                elif isinstance(item, CachePoint):
+                    if supports_caching:
+                        content.append({'cachePoint': {'type': 'default'}})
+                    continue
```
Review comment: I don't think we need this […]
```diff
                 elif isinstance(item, AudioUrl):  # pragma: no cover
                     raise NotImplementedError('Audio is not supported yet.')
                 else:
                     assert_never(item)
-        return [{'role': 'user', 'content': content}]
+        return has_leading_cache_point, [{'role': 'user', 'content': content}]

     @staticmethod
     def _map_tool_call(t: ToolCallPart) -> ContentBlockOutputTypeDef:
```
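Putting the mapping together, a sketch of the Converse message `_map_user_prompt` would build for a prompt with a mid-content `CachePoint` when `supports_caching` is true; the text values are illustrative:

```python
message = {
    'role': 'user',
    'content': [
        {'text': 'long stable context...'},
        {'cachePoint': {'type': 'default'}},  # the CachePoint, mapped to a Bedrock block
        {'text': 'varying question'},
    ],
}
```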
```diff
@@ -674,9 +708,14 @@ def timestamp(self) -> datetime:
         return self._timestamp

     def _map_usage(self, metadata: ConverseStreamMetadataEventTypeDef) -> usage.RequestUsage:
+        cache_read_tokens = metadata['usage'].get('cacheReadInputTokens', 0)
+        cache_write_tokens = metadata['usage'].get('cacheWriteInputTokens', 0)
+        input_tokens = metadata['usage']['inputTokens'] + cache_read_tokens + cache_write_tokens
         return usage.RequestUsage(
-            input_tokens=metadata['usage']['inputTokens'],
+            input_tokens=input_tokens,
             output_tokens=metadata['usage']['outputTokens'],
+            cache_write_tokens=cache_write_tokens,
+            cache_read_tokens=cache_read_tokens,
         )
```
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need this here, as it mentions at the top of the file, these are supposed to match the OpenTelemetry GenAI spec, which I don't think has cache points.