481 changes: 481 additions & 0 deletions examples/pydantic_ai_examples/anthropic_prompt_caching.py

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion pydantic_ai_slim/pydantic_ai/_otel_messages.py
@@ -46,7 +46,13 @@ class ThinkingPart(TypedDict):
content: NotRequired[str]


MessagePart: TypeAlias = 'TextPart | ToolCallPart | ToolCallResponsePart | MediaUrlPart | BinaryDataPart | ThinkingPart'
class CachePointPart(TypedDict):
Collaborator:
I don't think we need this here: as mentioned at the top of the file, these types are supposed to match the OpenTelemetry GenAI spec, which I don't think has cache points.

type: Literal['cache-point']


MessagePart: TypeAlias = (
'TextPart | ToolCallPart | ToolCallResponsePart | MediaUrlPart | BinaryDataPart | ThinkingPart | CachePointPart'
)


Role = Literal['system', 'user', 'assistant']
16 changes: 15 additions & 1 deletion pydantic_ai_slim/pydantic_ai/messages.py
@@ -520,8 +520,20 @@ def format(self) -> str:
__repr__ = _utils.dataclasses_no_defaults_repr


@dataclass
class CachePoint:
"""A cache point marker for prompt caching.

Can be inserted into UserPromptPart.content to mark cache boundaries.
Models that don't support caching will filter these out.
"""

kind: Literal['cache-point'] = 'cache-point'
"""Type identifier, this is available on all parts as a discriminator."""


MultiModalContent = ImageUrl | AudioUrl | DocumentUrl | VideoUrl | BinaryContent
UserContent: TypeAlias = str | MultiModalContent
UserContent: TypeAlias = str | MultiModalContent | CachePoint


@dataclass(repr=False)
@@ -637,6 +649,8 @@ def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_me
if settings.include_content and settings.include_binary_content:
converted_part['content'] = base64.b64encode(part.data).decode()
parts.append(converted_part)
elif isinstance(part, CachePoint):
parts.append(_otel_messages.CachePointPart(type=part.kind))
Collaborator:
See above; I think we should skip it instead -- unless you've confirmed that the OTel spec does have cache points.
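
If we go that route, a minimal sketch of the skip (hypothetical; it simply drops the part instead of emitting a custom OTel type):

elif isinstance(part, CachePoint):
    # CachePoint has no OpenTelemetry GenAI equivalent, so emit nothing for it
    pass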

else:
parts.append({'type': part.kind}) # pragma: no cover
return parts
34 changes: 32 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -20,6 +20,7 @@
BinaryContent,
BuiltinToolCallPart,
BuiltinToolReturnPart,
CachePoint,
DocumentUrl,
ImageUrl,
ModelMessage,
@@ -47,6 +48,7 @@
from anthropic.types.beta import (
BetaBase64PDFBlockParam,
BetaBase64PDFSourceParam,
BetaCacheControlEphemeralParam,
BetaCitationsDelta,
BetaCodeExecutionTool20250522Param,
BetaCodeExecutionToolResultBlock,
@@ -107,6 +109,16 @@
See [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/models) for a full list.
"""

CacheableContentBlockParam = (
BetaTextBlockParam
| BetaToolUseBlockParam
| BetaServerToolUseBlockParam
| BetaImageBlockParam
| BetaToolResultBlockParam
)
"""Content block parameter types that support cache_control."""
Collaborator:
Can you please link to the doc this came from?

CACHEABLE_CONTENT_BLOCK_PARAM_TYPES = {'text', 'tool_use', 'server_tool_use', 'image', 'tool_result'}


class AnthropicModelSettings(ModelSettings, total=False):
"""Settings used for an Anthropic model request."""
@@ -382,6 +394,19 @@ def _get_builtin_tools(
)
return tools, extra_headers

@staticmethod
def _add_cache_control_to_last_param(params: list[BetaContentBlockParam]) -> None:
if not params:
raise UserError(
'CachePoint cannot be the first content in a user message - there must be previous content to attach the CachePoint to.'
Collaborator:
Copying in context from https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#what-can-be-cached:

Tools: Tool definitions in the tools array
System messages: Content blocks in the system array
Text messages: Content blocks in the messages.content array, for both user and assistant turns
Images & Documents: Content blocks in the messages.content array, in user turns
Tool use and tool results: Content blocks in the messages.content array, in both user and assistant turns

Annoying as it will be to implement, I think we can support CachePoint as the first content in a user message by attaching it to whatever came before it: the system prompt, the last tool definition, or the last block of the assistant output, similar to what we're doing for Bedrock.
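
A rough sketch of that fallback inside _map_message (hypothetical; assumes the anthropic_messages list of BetaMessageParam the method already builds, and that the system-prompt / tool-definition fallbacks would need those to be structured params rather than plain strings):

async for content in self._map_user_prompt(request_part):
    if isinstance(content, CachePoint):
        if user_content_params:
            self._add_cache_control_to_last_param(user_content_params)
        elif anthropic_messages and isinstance(anthropic_messages[-1]['content'], list):
            # attach to the last block of the previous (assistant or user) message
            self._add_cache_control_to_last_param(anthropic_messages[-1]['content'])
        # else: fall back to the system prompt or the last tool definition
    else:
        user_content_params.append(content)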

)

if params[-1]['type'] not in CACHEABLE_CONTENT_BLOCK_PARAM_TYPES:
raise UserError(f'Cache control not supported for param type: {params[-1]["type"]}')

cacheable_param = cast(CacheableContentBlockParam, params[-1])
cacheable_param['cache_control'] = BetaCacheControlEphemeralParam(type='ephemeral')

async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]: # noqa: C901
"""Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
system_prompt_parts: list[str] = []
@@ -394,7 +419,10 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[Be
system_prompt_parts.append(request_part.content)
elif isinstance(request_part, UserPromptPart):
async for content in self._map_user_prompt(request_part):
user_content_params.append(content)
if isinstance(content, CachePoint):
self._add_cache_control_to_last_param(user_content_params)
else:
user_content_params.append(content)
elif isinstance(request_part, ToolReturnPart):
tool_result_block_param = BetaToolResultBlockParam(
tool_use_id=_guard_tool_call_id(t=request_part),
@@ -483,7 +511,7 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[Be
@staticmethod
async def _map_user_prompt(
part: UserPromptPart,
) -> AsyncGenerator[BetaContentBlockParam]:
) -> AsyncGenerator[BetaContentBlockParam | CachePoint]:
if isinstance(part.content, str):
if part.content: # Only yield non-empty text
yield BetaTextBlockParam(text=part.content, type='text')
@@ -524,6 +552,8 @@ async def _map_user_prompt(
)
else: # pragma: no cover
raise RuntimeError(f'Unsupported media type: {item.media_type}')
elif isinstance(item, CachePoint):
yield item
else:
raise RuntimeError(f'Unsupported content type: {type(item)}') # pragma: no cover

57 changes: 48 additions & 9 deletions pydantic_ai_slim/pydantic_ai/models/bedrock.py
@@ -21,6 +21,7 @@
BinaryContent,
BuiltinToolCallPart,
BuiltinToolReturnPart,
CachePoint,
DocumentUrl,
ImageUrl,
ModelMessage,
@@ -296,9 +297,14 @@ async def _process_response(self, response: ConverseResponseTypeDef) -> ModelRes
tool_call_id=tool_use['toolUseId'],
),
)
cache_read_tokens = response['usage'].get('cacheReadInputTokens', 0)
cache_write_tokens = response['usage'].get('cacheWriteInputTokens', 0)
input_tokens = response['usage']['inputTokens'] + cache_read_tokens + cache_write_tokens
u = usage.RequestUsage(
input_tokens=response['usage']['inputTokens'],
input_tokens=input_tokens,
output_tokens=response['usage']['outputTokens'],
cache_read_tokens=cache_read_tokens,
cache_write_tokens=cache_write_tokens,
)
response_id = response.get('ResponseMetadata', {}).get('RequestId', None)
return ModelResponse(
@@ -346,7 +352,12 @@ async def _messages_create(
'inferenceConfig': inference_config,
}

tool_config = self._map_tool_config(model_request_parameters)
tool_config = self._map_tool_config(
model_request_parameters,
should_add_cache_point=(
Collaborator:
This now always adds a cache point, independent of has_leading_cache_point. I think we should have _map_messages return that value (if it wasn't already consumed by adding a cache point to the system prompt) and then use it here. That may require some reordering.
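
Something along these lines in _messages_create, assuming _map_messages is changed to also return a flag for a leading cache point it couldn't place (names are hypothetical):

system_prompt, bedrock_messages, leftover_cache_point = await self._map_messages(messages)
tool_config = self._map_tool_config(
    model_request_parameters,
    should_add_cache_point=(
        leftover_cache_point
        and BedrockModelProfile.from_profile(self.profile).bedrock_supports_prompt_caching
    ),
)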

not system_prompt and BedrockModelProfile.from_profile(self.profile).bedrock_supports_prompt_caching
),
)
if tool_config:
params['toolConfig'] = tool_config

@@ -395,11 +406,16 @@ def _map_inference_config(

return inference_config

def _map_tool_config(self, model_request_parameters: ModelRequestParameters) -> ToolConfigurationTypeDef | None:
def _map_tool_config(
self, model_request_parameters: ModelRequestParameters, should_add_cache_point: bool = False
) -> ToolConfigurationTypeDef | None:
tools = self._get_tools(model_request_parameters)
if not tools:
return None

if should_add_cache_point:
tools[-1]['cachePoint'] = {'type': 'default'}
Collaborator:
This is causing test_bedrock_anthropic_tool_with_thinking to fail with this error:

Invalid number of parameters set for tagged union structure toolConfig.tools[0]. Can only set one of the following keys: toolSpec, cachePoint.

It sounds like the cachePoint has to be a separate item in tools, not a field on the last item.
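
i.e. roughly this instead of setting cachePoint on tools[-1]:

if should_add_cache_point:
    tools.append({'cachePoint': {'type': 'default'}})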


tool_choice: ToolChoiceTypeDef
if not model_request_parameters.allow_text_output:
tool_choice = {'any': {}}
@@ -429,7 +445,12 @@ async def _map_messages( # noqa: C901
if isinstance(part, SystemPromptPart) and part.content:
system_prompt.append({'text': part.content})
elif isinstance(part, UserPromptPart):
bedrock_messages.extend(await self._map_user_prompt(part, document_count))
has_leading_cache_point, user_messages = await self._map_user_prompt(
part, document_count, profile.bedrock_supports_prompt_caching
)
if has_leading_cache_point:
system_prompt.append({'cachePoint': {'type': 'default'}})
Collaborator:
We should only add it to the system prompt if this is the first ModelRequest -- if there were previous ModelResponses from the assistant, we should add it to the last assistant message, I think.
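
Roughly (a sketch; assumes cachePoint blocks are allowed in assistant content, per the caching docs quoted above, and that the last mapped message's content is a mutable list at runtime):

if has_leading_cache_point:
    if bedrock_messages and bedrock_messages[-1]['role'] == 'assistant':
        # attach to the last assistant message rather than the system prompt
        bedrock_messages[-1]['content'].append({'cachePoint': {'type': 'default'}})
    else:
        system_prompt.append({'cachePoint': {'type': 'default'}})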

bedrock_messages.extend(user_messages)
elif isinstance(part, ToolReturnPart):
assert part.tool_call_id is not None
bedrock_messages.append(
@@ -522,13 +543,22 @@ async def _map_messages( # noqa: C901

return system_prompt, processed_messages

@staticmethod
async def _map_user_prompt(part: UserPromptPart, document_count: Iterator[int]) -> list[MessageUnionTypeDef]:
async def _map_user_prompt( # noqa: C901
self, part: UserPromptPart, document_count: Iterator[int], supports_caching: bool
) -> tuple[bool, list[MessageUnionTypeDef]]:
content: list[ContentBlockUnionTypeDef] = []
has_leading_cache_point = False

if isinstance(part.content, str):
content.append({'text': part.content})
else:
for item in part.content:
if part.content and isinstance(part.content[0], CachePoint):
Collaborator:
You can move this into the isinstance(item, CachePoint) branch below by changing for item in part.content to for i, item in enumerate(part.content) and then checking i == 0 to know whether it was the first element.
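
i.e. something like (only the relevant branches shown):

for i, item in enumerate(part.content):
    if isinstance(item, str):
        content.append({'text': item})
    elif isinstance(item, CachePoint):
        if i == 0:
            # report a leading cache point back to the caller instead of emitting it here
            has_leading_cache_point = True
        elif supports_caching:
            content.append({'cachePoint': {'type': 'default'}})
    # ... remaining branches unchanged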

has_leading_cache_point = True
Collaborator:
Are we sure that Bedrock needs this special behavior of moving the cache point to the system prompt? Does it not support a cachePoint at the start of user content?

items_to_process = part.content[1:]
else:
items_to_process = part.content

for item in items_to_process:
if isinstance(item, str):
content.append({'text': item})
elif isinstance(item, BinaryContent):
@@ -578,11 +608,15 @@ async def _map_user_prompt(part: UserPromptPart, document_count: Iterator[int])
), f'Unsupported video format: {format}'
video: VideoBlockTypeDef = {'format': format, 'source': {'bytes': downloaded_item['data']}}
content.append({'video': video})
elif isinstance(item, CachePoint):
if supports_caching:
content.append({'cachePoint': {'type': 'default'}})
continue
Collaborator:
I don't think we need this continue.

elif isinstance(item, AudioUrl): # pragma: no cover
raise NotImplementedError('Audio is not supported yet.')
else:
assert_never(item)
return [{'role': 'user', 'content': content}]
return has_leading_cache_point, [{'role': 'user', 'content': content}]

@staticmethod
def _map_tool_call(t: ToolCallPart) -> ContentBlockOutputTypeDef:
@@ -674,9 +708,14 @@ def timestamp(self) -> datetime:
return self._timestamp

def _map_usage(self, metadata: ConverseStreamMetadataEventTypeDef) -> usage.RequestUsage:
cache_read_tokens = metadata['usage'].get('cacheReadInputTokens', 0)
cache_write_tokens = metadata['usage'].get('cacheWriteInputTokens', 0)
input_tokens = metadata['usage']['inputTokens'] + cache_read_tokens + cache_write_tokens
return usage.RequestUsage(
input_tokens=metadata['usage']['inputTokens'],
input_tokens=input_tokens,
output_tokens=metadata['usage']['outputTokens'],
cache_write_tokens=cache_write_tokens,
cache_read_tokens=cache_read_tokens,
)


3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -21,6 +21,7 @@
BinaryContent,
BuiltinToolCallPart,
BuiltinToolReturnPart,
CachePoint,
FileUrl,
ModelMessage,
ModelRequest,
@@ -371,6 +372,8 @@ async def _map_user_prompt(self, part: UserPromptPart) -> list[_GeminiPartUnion]
else: # pragma: lax no cover
file_data = _GeminiFileDataPart(file_data={'file_uri': item.url, 'mime_type': item.media_type})
content.append(file_data)
elif isinstance(item, CachePoint):
raise NotImplementedError('CachePoint is not supported for Gemini')
else:
assert_never(item) # pragma: lax no cover
return content
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/google.py
@@ -19,6 +19,7 @@
BinaryContent,
BuiltinToolCallPart,
BuiltinToolReturnPart,
CachePoint,
FileUrl,
ModelMessage,
ModelRequest,
@@ -514,6 +515,8 @@ async def _map_user_prompt(self, part: UserPromptPart) -> list[PartDict]:
content.append(
{'file_data': {'file_uri': item.url, 'mime_type': item.media_type}}
) # pragma: lax no cover
elif isinstance(item, CachePoint):
raise NotImplementedError('CachePoint is not supported for Google')
else:
assert_never(item)
return content
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/huggingface.py
@@ -19,6 +19,7 @@
BinaryContent,
BuiltinToolCallPart,
BuiltinToolReturnPart,
CachePoint,
DocumentUrl,
ImageUrl,
ModelMessage,
@@ -426,6 +427,8 @@ async def _map_user_prompt(part: UserPromptPart) -> ChatCompletionInputMessage:
raise NotImplementedError('DocumentUrl is not supported for Hugging Face')
elif isinstance(item, VideoUrl):
raise NotImplementedError('VideoUrl is not supported for Hugging Face')
elif isinstance(item, CachePoint):
raise NotImplementedError('CachePoint is not supported for Hugging Face')
else:
assert_never(item)
return ChatCompletionInputMessage(role='user', content=content) # type: ignore
5 changes: 5 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -23,6 +23,7 @@
BinaryContent,
BuiltinToolCallPart,
BuiltinToolReturnPart,
CachePoint,
DocumentUrl,
ImageUrl,
ModelMessage,
@@ -713,6 +714,8 @@ async def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessa
content.append(file)
elif isinstance(item, VideoUrl): # pragma: no cover
raise NotImplementedError('VideoUrl is not supported for OpenAI')
elif isinstance(item, CachePoint):
raise NotImplementedError('CachePoint is not supported for OpenAI')
else:
assert_never(item)
return chat.ChatCompletionUserMessageParam(role='user', content=content)
@@ -1150,6 +1153,8 @@ async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessagePa
)
elif isinstance(item, VideoUrl): # pragma: no cover
raise NotImplementedError('VideoUrl is not supported for OpenAI.')
elif isinstance(item, CachePoint):
raise NotImplementedError('CachePoint is not supported for OpenAI')
else:
assert_never(item)
return responses.EasyInputMessageParam(role='user', content=content)
40 changes: 32 additions & 8 deletions pydantic_ai_slim/pydantic_ai/providers/bedrock.py
@@ -38,13 +38,41 @@ class BedrockModelProfile(ModelProfile):
bedrock_supports_tool_choice: bool = False
bedrock_tool_result_format: Literal['text', 'json'] = 'text'
bedrock_send_back_thinking_parts: bool = False
bedrock_supports_prompt_caching: bool = False


# Supported models: https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
ANTHROPIC_CACHING_SUPPORTED_MODELS = {'claude-3-5-sonnet', 'claude-3-5-haiku', 'claude-3-7-sonnet', 'claude-sonnet-4'}

# Supported models: https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
AMAZON_CACHING_SUPPORTED_MODELS = {'nova-micro', 'nova-lite', 'nova-pro', 'nova-premier'}


def bedrock_anthropic_model_profile(model_name: str) -> ModelProfile | None:
"""Create a Bedrock model profile for Anthropic models with caching support where applicable."""
return BedrockModelProfile(
bedrock_supports_tool_choice=True,
bedrock_send_back_thinking_parts=True,
bedrock_supports_prompt_caching=any(
supported in model_name for supported in ANTHROPIC_CACHING_SUPPORTED_MODELS
),
).update(anthropic_model_profile(model_name))


def bedrock_mistral_model_profile(model_name: str) -> ModelProfile | None:
"""Create a Bedrock model profile for Mistral models."""
return BedrockModelProfile(bedrock_tool_result_format='json').update(mistral_model_profile(model_name))


def bedrock_amazon_model_profile(model_name: str) -> ModelProfile | None:
"""Get the model profile for an Amazon model used via Bedrock."""
"""Get the model profile for an Amazon model used via Bedrock with caching support where applicable."""
profile = amazon_model_profile(model_name)
if 'nova' in model_name:
return BedrockModelProfile(bedrock_supports_tool_choice=True).update(profile)
# Check if this Nova model supports prompt caching
supports_caching = any(supported in model_name for supported in AMAZON_CACHING_SUPPORTED_MODELS)
return BedrockModelProfile(
bedrock_supports_tool_choice=True, bedrock_supports_prompt_caching=supports_caching
).update(profile)
return profile


@@ -65,12 +93,8 @@ def client(self) -> BaseClient:

def model_profile(self, model_name: str) -> ModelProfile | None:
provider_to_profile: dict[str, Callable[[str], ModelProfile | None]] = {
'anthropic': lambda model_name: BedrockModelProfile(
bedrock_supports_tool_choice=True, bedrock_send_back_thinking_parts=True
).update(anthropic_model_profile(model_name)),
'mistral': lambda model_name: BedrockModelProfile(bedrock_tool_result_format='json').update(
mistral_model_profile(model_name)
),
'anthropic': bedrock_anthropic_model_profile,
'mistral': bedrock_mistral_model_profile,
'cohere': cohere_model_profile,
'amazon': bedrock_amazon_model_profile,
'meta': meta_model_profile,