
Commit a2222e1

Add Bedrock/Anthropic prompt caching
1 parent 3c2f1cf commit a2222e1

12 files changed, +645 -35 lines changed

examples/pydantic_ai_examples/anthropic_prompt_caching.py

Lines changed: 481 additions & 0 deletions
Large diffs are not rendered by default.
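
The example diff is collapsed above. As a rough, hypothetical sketch of the usage pattern this commit enables (the model id and prompt text are illustrative assumptions, not taken from the example file):

import asyncio

from pydantic_ai import Agent
from pydantic_ai.messages import CachePoint

# Illustrative Bedrock model id; any caching-capable Claude model would do.
agent = Agent('bedrock:us.anthropic.claude-3-5-sonnet-20241022-v2:0')

async def main() -> None:
    result = await agent.run(
        [
            'A long, stable reference document worth caching...',  # placeholder content
            CachePoint(),  # everything before this marker becomes a cacheable prefix
            'Now a short question about the document.',
        ]
    )
    print(result.output)

asyncio.run(main())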

pydantic_ai_slim/pydantic_ai/_otel_messages.py

Lines changed: 7 additions & 1 deletion
@@ -46,7 +46,13 @@ class ThinkingPart(TypedDict):
     content: NotRequired[str]
 
 
-MessagePart: TypeAlias = 'TextPart | ToolCallPart | ToolCallResponsePart | MediaUrlPart | BinaryDataPart | ThinkingPart'
+class CachePointPart(TypedDict):
+    type: Literal['cache-point']
+
+
+MessagePart: TypeAlias = (
+    'TextPart | ToolCallPart | ToolCallResponsePart | MediaUrlPart | BinaryDataPart | ThinkingPart | CachePointPart'
+)
 
 
 Role = Literal['system', 'user', 'assistant']

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 15 additions & 1 deletion
@@ -520,8 +520,20 @@ def format(self) -> str:
     __repr__ = _utils.dataclasses_no_defaults_repr
 
 
+@dataclass
+class CachePoint:
+    """A cache point marker for prompt caching.
+
+    Can be inserted into UserPromptPart.content to mark cache boundaries.
+    Models that don't support caching will filter these out.
+    """
+
+    kind: Literal['cache-point'] = 'cache-point'
+    """Type identifier, this is available on all parts as a discriminator."""
+
+
 MultiModalContent = ImageUrl | AudioUrl | DocumentUrl | VideoUrl | BinaryContent
-UserContent: TypeAlias = str | MultiModalContent
+UserContent: TypeAlias = str | MultiModalContent | CachePoint
 
 
 @dataclass(repr=False)
@@ -637,6 +649,8 @@ def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_messages.MessagePart]:
             if settings.include_content and settings.include_binary_content:
                 converted_part['content'] = base64.b64encode(part.data).decode()
             parts.append(converted_part)
+        elif isinstance(part, CachePoint):
+            parts.append(_otel_messages.CachePointPart(type=part.kind))
         else:
             parts.append({'type': part.kind})  # pragma: no cover
     return parts
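
A quick illustration of the new UserContent union member (CachePoint and UserPromptPart come from pydantic_ai.messages as changed above; the content values are placeholders):

from pydantic_ai.messages import CachePoint, UserPromptPart

# CachePoint is a plain marker inside the content sequence; each model's
# mapping layer decides what to do with it (attach cache_control, emit a
# cachePoint block, or raise NotImplementedError).
part = UserPromptPart(
    content=[
        'Large, stable context that is worth caching...',
        CachePoint(),  # kind='cache-point' serves as the discriminator
        'The short, per-request question.',
    ]
)
assert part.content[1].kind == 'cache-point'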

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 32 additions & 2 deletions
@@ -20,6 +20,7 @@
     BinaryContent,
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
+    CachePoint,
     DocumentUrl,
     ImageUrl,
     ModelMessage,
@@ -47,6 +48,7 @@
 from anthropic.types.beta import (
     BetaBase64PDFBlockParam,
     BetaBase64PDFSourceParam,
+    BetaCacheControlEphemeralParam,
     BetaCitationsDelta,
     BetaCodeExecutionTool20250522Param,
     BetaCodeExecutionToolResultBlock,
@@ -107,6 +109,16 @@
 See [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/models) for a full list.
 """
 
+CacheableContentBlockParam = (
+    BetaTextBlockParam
+    | BetaToolUseBlockParam
+    | BetaServerToolUseBlockParam
+    | BetaImageBlockParam
+    | BetaToolResultBlockParam
+)
+"""Content block parameter types that support cache_control."""
+CACHEABLE_CONTENT_BLOCK_PARAM_TYPES = {'text', 'tool_use', 'server_tool_use', 'image', 'tool_result'}
+
 
 class AnthropicModelSettings(ModelSettings, total=False):
     """Settings used for an Anthropic model request."""
@@ -382,6 +394,19 @@ def _get_builtin_tools(
         )
         return tools, extra_headers
 
+    @staticmethod
+    def _add_cache_control_to_last_param(params: list[BetaContentBlockParam]) -> None:
+        if not params:
+            raise UserError(
+                'CachePoint cannot be the first content in a user message - there must be previous content to attach the CachePoint to.'
+            )
+
+        if params[-1]['type'] not in CACHEABLE_CONTENT_BLOCK_PARAM_TYPES:
+            raise UserError(f'Cache control not supported for param type: {params[-1]["type"]}')
+
+        cacheable_param = cast(CacheableContentBlockParam, params[-1])
+        cacheable_param['cache_control'] = BetaCacheControlEphemeralParam(type='ephemeral')
+
     async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]:  # noqa: C901
         """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
         system_prompt_parts: list[str] = []
@@ -394,7 +419,10 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]:
                     system_prompt_parts.append(request_part.content)
                 elif isinstance(request_part, UserPromptPart):
                     async for content in self._map_user_prompt(request_part):
-                        user_content_params.append(content)
+                        if isinstance(content, CachePoint):
+                            self._add_cache_control_to_last_param(user_content_params)
+                        else:
+                            user_content_params.append(content)
                 elif isinstance(request_part, ToolReturnPart):
                     tool_result_block_param = BetaToolResultBlockParam(
                         tool_use_id=_guard_tool_call_id(t=request_part),
@@ -483,7 +511,7 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]:
     @staticmethod
     async def _map_user_prompt(
         part: UserPromptPart,
-    ) -> AsyncGenerator[BetaContentBlockParam]:
+    ) -> AsyncGenerator[BetaContentBlockParam | CachePoint]:
         if isinstance(part.content, str):
             if part.content:  # Only yield non-empty text
                 yield BetaTextBlockParam(text=part.content, type='text')
@@ -524,6 +552,8 @@ async def _map_user_prompt(
                     )
                 else:  # pragma: no cover
                     raise RuntimeError(f'Unsupported media type: {item.media_type}')
+            elif isinstance(item, CachePoint):
+                yield item
             else:
                 raise RuntimeError(f'Unsupported content type: {type(item)}')  # pragma: no cover
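
In effect a CachePoint never reaches the request body: _map_user_prompt yields it through, and _map_message consumes it by stamping cache_control onto the preceding block. A minimal standalone sketch of that transformation, using plain dicts in place of anthropic's Beta*Param typed dicts:

CACHEABLE = {'text', 'tool_use', 'server_tool_use', 'image', 'tool_result'}

def add_cache_control_to_last(params: list[dict]) -> None:
    # Mirrors _add_cache_control_to_last_param above, minus the typed dicts.
    if not params:
        raise ValueError('CachePoint cannot be the first content in a user message')
    if params[-1]['type'] not in CACHEABLE:
        raise ValueError(f"Cache control not supported for param type: {params[-1]['type']}")
    params[-1]['cache_control'] = {'type': 'ephemeral'}

params = [{'type': 'text', 'text': 'long cached prefix'}]
add_cache_control_to_last(params)
assert params[-1]['cache_control'] == {'type': 'ephemeral'}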

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 48 additions & 9 deletions
@@ -21,6 +21,7 @@
     BinaryContent,
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
+    CachePoint,
     DocumentUrl,
     ImageUrl,
     ModelMessage,
@@ -296,9 +297,14 @@ async def _process_response(self, response: ConverseResponseTypeDef) -> ModelResponse:
                     tool_call_id=tool_use['toolUseId'],
                 ),
             )
+        cache_read_tokens = response['usage'].get('cacheReadInputTokens', 0)
+        cache_write_tokens = response['usage'].get('cacheWriteInputTokens', 0)
+        input_tokens = response['usage']['inputTokens'] + cache_read_tokens + cache_write_tokens
         u = usage.RequestUsage(
-            input_tokens=response['usage']['inputTokens'],
+            input_tokens=input_tokens,
             output_tokens=response['usage']['outputTokens'],
+            cache_read_tokens=cache_read_tokens,
+            cache_write_tokens=cache_write_tokens,
         )
         response_id = response.get('ResponseMetadata', {}).get('RequestId', None)
         return ModelResponse(
@@ -346,7 +352,12 @@ async def _messages_create(
             'inferenceConfig': inference_config,
         }
 
-        tool_config = self._map_tool_config(model_request_parameters)
+        tool_config = self._map_tool_config(
+            model_request_parameters,
+            should_add_cache_point=(
+                not system_prompt and BedrockModelProfile.from_profile(self.profile).bedrock_supports_prompt_caching
+            ),
+        )
         if tool_config:
             params['toolConfig'] = tool_config
 
@@ -395,11 +406,16 @@ def _map_inference_config(
 
         return inference_config
 
-    def _map_tool_config(self, model_request_parameters: ModelRequestParameters) -> ToolConfigurationTypeDef | None:
+    def _map_tool_config(
+        self, model_request_parameters: ModelRequestParameters, should_add_cache_point: bool = False
+    ) -> ToolConfigurationTypeDef | None:
         tools = self._get_tools(model_request_parameters)
         if not tools:
             return None
 
+        if should_add_cache_point:
+            tools[-1]['cachePoint'] = {'type': 'default'}
+
         tool_choice: ToolChoiceTypeDef
         if not model_request_parameters.allow_text_output:
             tool_choice = {'any': {}}
@@ -429,7 +445,12 @@ async def _map_messages(  # noqa: C901
                 if isinstance(part, SystemPromptPart) and part.content:
                     system_prompt.append({'text': part.content})
                 elif isinstance(part, UserPromptPart):
-                    bedrock_messages.extend(await self._map_user_prompt(part, document_count))
+                    has_leading_cache_point, user_messages = await self._map_user_prompt(
+                        part, document_count, profile.bedrock_supports_prompt_caching
+                    )
+                    if has_leading_cache_point:
+                        system_prompt.append({'cachePoint': {'type': 'default'}})
+                    bedrock_messages.extend(user_messages)
                 elif isinstance(part, ToolReturnPart):
                     assert part.tool_call_id is not None
                     bedrock_messages.append(
@@ -522,13 +543,22 @@ async def _map_messages(  # noqa: C901
 
         return system_prompt, processed_messages
 
-    @staticmethod
-    async def _map_user_prompt(part: UserPromptPart, document_count: Iterator[int]) -> list[MessageUnionTypeDef]:
+    async def _map_user_prompt(  # noqa: C901
+        self, part: UserPromptPart, document_count: Iterator[int], supports_caching: bool
+    ) -> tuple[bool, list[MessageUnionTypeDef]]:
         content: list[ContentBlockUnionTypeDef] = []
+        has_leading_cache_point = False
+
         if isinstance(part.content, str):
             content.append({'text': part.content})
         else:
-            for item in part.content:
+            if part.content and isinstance(part.content[0], CachePoint):
+                has_leading_cache_point = True
+                items_to_process = part.content[1:]
+            else:
+                items_to_process = part.content
+
+            for item in items_to_process:
                 if isinstance(item, str):
                     content.append({'text': item})
                 elif isinstance(item, BinaryContent):
@@ -578,11 +608,15 @@ async def _map_user_prompt(part: UserPromptPart, document_count: Iterator[int]) -> list[MessageUnionTypeDef]:
                     ), f'Unsupported video format: {format}'
                     video: VideoBlockTypeDef = {'format': format, 'source': {'bytes': downloaded_item['data']}}
                     content.append({'video': video})
+                elif isinstance(item, CachePoint):
+                    if supports_caching:
+                        content.append({'cachePoint': {'type': 'default'}})
+                    continue
                 elif isinstance(item, AudioUrl):  # pragma: no cover
                     raise NotImplementedError('Audio is not supported yet.')
                 else:
                     assert_never(item)
-        return [{'role': 'user', 'content': content}]
+        return has_leading_cache_point, [{'role': 'user', 'content': content}]
 
     @staticmethod
     def _map_tool_call(t: ToolCallPart) -> ContentBlockOutputTypeDef:
@@ -674,9 +708,14 @@ def timestamp(self) -> datetime:
         return self._timestamp
 
     def _map_usage(self, metadata: ConverseStreamMetadataEventTypeDef) -> usage.RequestUsage:
+        cache_read_tokens = metadata['usage'].get('cacheReadInputTokens', 0)
+        cache_write_tokens = metadata['usage'].get('cacheWriteInputTokens', 0)
+        input_tokens = metadata['usage']['inputTokens'] + cache_read_tokens + cache_write_tokens
         return usage.RequestUsage(
-            input_tokens=metadata['usage']['inputTokens'],
+            input_tokens=input_tokens,
             output_tokens=metadata['usage']['outputTokens'],
+            cache_write_tokens=cache_write_tokens,
+            cache_read_tokens=cache_read_tokens,
        )
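
Bedrock's Converse API reports cacheReadInputTokens and cacheWriteInputTokens separately from inputTokens, so the mapping above re-adds them to get the true total. A worked example of the arithmetic (numbers invented):

# Invented usage payload in the shape Bedrock's Converse API returns.
usage_payload = {
    'inputTokens': 120,  # uncached input tokens only
    'outputTokens': 80,
    'cacheReadInputTokens': 4000,
    'cacheWriteInputTokens': 0,
}

cache_read = usage_payload.get('cacheReadInputTokens', 0)
cache_write = usage_payload.get('cacheWriteInputTokens', 0)
total_input = usage_payload['inputTokens'] + cache_read + cache_write
print(total_input)  # 4120 = 120 uncached + 4000 read from cache + 0 written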

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 3 additions & 0 deletions
@@ -21,6 +21,7 @@
     BinaryContent,
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
+    CachePoint,
     FileUrl,
     ModelMessage,
     ModelRequest,
@@ -371,6 +372,8 @@ async def _map_user_prompt(self, part: UserPromptPart) -> list[_GeminiPartUnion]:
                 else:  # pragma: lax no cover
                     file_data = _GeminiFileDataPart(file_data={'file_uri': item.url, 'mime_type': item.media_type})
                     content.append(file_data)
+            elif isinstance(item, CachePoint):
+                raise NotImplementedError('CachePoint is not supported for Gemini')
             else:
                 assert_never(item)  # pragma: lax no cover
         return content

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 3 additions & 0 deletions
@@ -19,6 +19,7 @@
     BinaryContent,
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
+    CachePoint,
     FileUrl,
     ModelMessage,
     ModelRequest,
@@ -514,6 +515,8 @@ async def _map_user_prompt(self, part: UserPromptPart) -> list[PartDict]:
                 content.append(
                     {'file_data': {'file_uri': item.url, 'mime_type': item.media_type}}
                 )  # pragma: lax no cover
+            elif isinstance(item, CachePoint):
+                raise NotImplementedError('CachePoint is not supported for Google')
             else:
                 assert_never(item)
         return content

pydantic_ai_slim/pydantic_ai/models/huggingface.py

Lines changed: 3 additions & 0 deletions
@@ -19,6 +19,7 @@
     BinaryContent,
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
+    CachePoint,
     DocumentUrl,
     ImageUrl,
     ModelMessage,
@@ -426,6 +427,8 @@ async def _map_user_prompt(part: UserPromptPart) -> ChatCompletionInputMessage:
                 raise NotImplementedError('DocumentUrl is not supported for Hugging Face')
             elif isinstance(item, VideoUrl):
                 raise NotImplementedError('VideoUrl is not supported for Hugging Face')
+            elif isinstance(item, CachePoint):
+                raise NotImplementedError('CachePoint is not supported for Hugging Face')
             else:
                 assert_never(item)
         return ChatCompletionInputMessage(role='user', content=content)  # type: ignore

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 5 additions & 0 deletions
@@ -23,6 +23,7 @@
     BinaryContent,
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
+    CachePoint,
     DocumentUrl,
     ImageUrl,
     ModelMessage,
@@ -713,6 +714,8 @@ async def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessageParam:
                 content.append(file)
             elif isinstance(item, VideoUrl):  # pragma: no cover
                 raise NotImplementedError('VideoUrl is not supported for OpenAI')
+            elif isinstance(item, CachePoint):
+                raise NotImplementedError('CachePoint is not supported for OpenAI')
             else:
                 assert_never(item)
         return chat.ChatCompletionUserMessageParam(role='user', content=content)
@@ -1150,6 +1153,8 @@ async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessageParam:
                 )
             elif isinstance(item, VideoUrl):  # pragma: no cover
                 raise NotImplementedError('VideoUrl is not supported for OpenAI.')
+            elif isinstance(item, CachePoint):
+                raise NotImplementedError('CachePoint is not supported for OpenAI')
             else:
                 assert_never(item)
         return responses.EasyInputMessageParam(role='user', content=content)

pydantic_ai_slim/pydantic_ai/providers/bedrock.py

Lines changed: 32 additions & 8 deletions
@@ -38,13 +38,41 @@ class BedrockModelProfile(ModelProfile):
     bedrock_supports_tool_choice: bool = False
     bedrock_tool_result_format: Literal['text', 'json'] = 'text'
     bedrock_send_back_thinking_parts: bool = False
+    bedrock_supports_prompt_caching: bool = False
+
+
+# Supported models: https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
+ANTHROPIC_CACHING_SUPPORTED_MODELS = {'claude-3-5-sonnet', 'claude-3-5-haiku', 'claude-3-7-sonnet', 'claude-sonnet-4'}
+
+# Supported models: https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
+AMAZON_CACHING_SUPPORTED_MODELS = {'nova-micro', 'nova-lite', 'nova-pro', 'nova-premier'}
+
+
+def bedrock_anthropic_model_profile(model_name: str) -> ModelProfile | None:
+    """Create a Bedrock model profile for Anthropic models with caching support where applicable."""
+    return BedrockModelProfile(
+        bedrock_supports_tool_choice=True,
+        bedrock_send_back_thinking_parts=True,
+        bedrock_supports_prompt_caching=any(
+            supported in model_name for supported in ANTHROPIC_CACHING_SUPPORTED_MODELS
+        ),
+    ).update(anthropic_model_profile(model_name))
+
+
+def bedrock_mistral_model_profile(model_name: str) -> ModelProfile | None:
+    """Create a Bedrock model profile for Mistral models."""
+    return BedrockModelProfile(bedrock_tool_result_format='json').update(mistral_model_profile(model_name))
 
 
 def bedrock_amazon_model_profile(model_name: str) -> ModelProfile | None:
-    """Get the model profile for an Amazon model used via Bedrock."""
+    """Get the model profile for an Amazon model used via Bedrock with caching support where applicable."""
     profile = amazon_model_profile(model_name)
     if 'nova' in model_name:
-        return BedrockModelProfile(bedrock_supports_tool_choice=True).update(profile)
+        # Check if this Nova model supports prompt caching
+        supports_caching = any(supported in model_name for supported in AMAZON_CACHING_SUPPORTED_MODELS)
+        return BedrockModelProfile(
+            bedrock_supports_tool_choice=True, bedrock_supports_prompt_caching=supports_caching
+        ).update(profile)
     return profile
 
 
@@ -65,12 +93,8 @@ def client(self) -> BaseClient:
 
     def model_profile(self, model_name: str) -> ModelProfile | None:
         provider_to_profile: dict[str, Callable[[str], ModelProfile | None]] = {
-            'anthropic': lambda model_name: BedrockModelProfile(
-                bedrock_supports_tool_choice=True, bedrock_send_back_thinking_parts=True
-            ).update(anthropic_model_profile(model_name)),
-            'mistral': lambda model_name: BedrockModelProfile(bedrock_tool_result_format='json').update(
-                mistral_model_profile(model_name)
-            ),
+            'anthropic': bedrock_anthropic_model_profile,
+            'mistral': bedrock_mistral_model_profile,
             'cohere': cohere_model_profile,
             'amazon': bedrock_amazon_model_profile,
             'meta': meta_model_profile,
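
Caching support is decided by substring matching against the Bedrock model id. A standalone sketch of that check (the model ids passed in below are illustrative):

# Same substring check the profile functions above use.
ANTHROPIC_CACHING_SUPPORTED_MODELS = {'claude-3-5-sonnet', 'claude-3-5-haiku', 'claude-3-7-sonnet', 'claude-sonnet-4'}

def supports_prompt_caching(model_name: str) -> bool:
    # Bedrock ids embed the model family name, so a substring test suffices.
    return any(supported in model_name for supported in ANTHROPIC_CACHING_SUPPORTED_MODELS)

print(supports_prompt_caching('us.anthropic.claude-3-5-sonnet-20241022-v2:0'))  # True
print(supports_prompt_caching('anthropic.claude-3-opus-20240229-v1:0'))  # False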
