     BinaryContent,
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
+    CachePoint,
     DocumentUrl,
     ImageUrl,
     ModelMessage,
@@ -296,9 +297,14 @@ async def _process_response(self, response: ConverseResponseTypeDef) -> ModelRes
                     tool_call_id=tool_use['toolUseId'],
                 ),
             )
+        cache_read_tokens = response['usage'].get('cacheReadInputTokens', 0)
+        cache_write_tokens = response['usage'].get('cacheWriteInputTokens', 0)
+        input_tokens = response['usage']['inputTokens'] + cache_read_tokens + cache_write_tokens
         u = usage.RequestUsage(
-            input_tokens=response['usage']['inputTokens'],
+            input_tokens=input_tokens,
             output_tokens=response['usage']['outputTokens'],
+            cache_read_tokens=cache_read_tokens,
+            cache_write_tokens=cache_write_tokens,
         )
         response_id = response.get('ResponseMetadata', {}).get('RequestId', None)
         return ModelResponse(
@@ -346,7 +352,12 @@ async def _messages_create(
             'inferenceConfig': inference_config,
         }

-        tool_config = self._map_tool_config(model_request_parameters)
+        tool_config = self._map_tool_config(
+            model_request_parameters,
+            should_add_cache_point=(
+                not system_prompt and BedrockModelProfile.from_profile(self.profile).bedrock_supports_prompt_caching
+            ),
+        )
         if tool_config:
             params['toolConfig'] = tool_config

@@ -395,11 +406,16 @@ def _map_inference_config(

         return inference_config

-    def _map_tool_config(self, model_request_parameters: ModelRequestParameters) -> ToolConfigurationTypeDef | None:
+    def _map_tool_config(
+        self, model_request_parameters: ModelRequestParameters, should_add_cache_point: bool = False
+    ) -> ToolConfigurationTypeDef | None:
         tools = self._get_tools(model_request_parameters)
         if not tools:
             return None

+        if should_add_cache_point:
+            tools[-1]['cachePoint'] = {'type': 'default'}
+
         tool_choice: ToolChoiceTypeDef
         if not model_request_parameters.allow_text_output:
             tool_choice = {'any': {}}
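Together with the `_messages_create` hunk earlier, this means that when a request has no system prompt but the model profile supports prompt caching, the cache point is attached to the last tool definition instead. A rough sketch of the resulting `toolConfig` after the mutation above, using a hypothetical tool (the tool shape is an assumption for illustration, not part of this diff):

```python
# Illustrative only: one hypothetical tool, with the cache point key
# merged into the last (here, only) entry as _map_tool_config does above.
tool_config = {
    'tools': [
        {
            'toolSpec': {
                'name': 'get_weather',
                'inputSchema': {'json': {'type': 'object', 'properties': {}}},
            },
            'cachePoint': {'type': 'default'},
        },
    ],
    'toolChoice': {'auto': {}},
}
```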
@@ -429,7 +445,12 @@ async def _map_messages(  # noqa: C901
                     if isinstance(part, SystemPromptPart) and part.content:
                         system_prompt.append({'text': part.content})
                     elif isinstance(part, UserPromptPart):
-                        bedrock_messages.extend(await self._map_user_prompt(part, document_count))
+                        has_leading_cache_point, user_messages = await self._map_user_prompt(
+                            part, document_count, profile.bedrock_supports_prompt_caching
+                        )
+                        if has_leading_cache_point:
+                            system_prompt.append({'cachePoint': {'type': 'default'}})
+                        bedrock_messages.extend(user_messages)
                     elif isinstance(part, ToolReturnPart):
                         assert part.tool_call_id is not None
                         bedrock_messages.append(
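In this hunk, a `CachePoint` at the very start of a user message is hoisted out of the message and appended to the `system` block, effectively marking the system prompt as the cacheable prefix. A sketch of the serialized `system` parameter under that reading, with placeholder text:

```python
# Hypothetical result after _map_messages hoists a leading CachePoint:
# the marker lands after the system text instead of inside the first
# user message's content list.
system_prompt = [
    {'text': 'You are a helpful assistant.'},
    {'cachePoint': {'type': 'default'}},
]
```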
@@ -522,13 +543,22 @@ async def _map_messages(  # noqa: C901

         return system_prompt, processed_messages

-    @staticmethod
-    async def _map_user_prompt(part: UserPromptPart, document_count: Iterator[int]) -> list[MessageUnionTypeDef]:
+    async def _map_user_prompt(  # noqa: C901
+        self, part: UserPromptPart, document_count: Iterator[int], supports_caching: bool
+    ) -> tuple[bool, list[MessageUnionTypeDef]]:
         content: list[ContentBlockUnionTypeDef] = []
+        has_leading_cache_point = False
+
         if isinstance(part.content, str):
             content.append({'text': part.content})
         else:
-            for item in part.content:
+            if part.content and isinstance(part.content[0], CachePoint):
+                has_leading_cache_point = True
+                items_to_process = part.content[1:]
+            else:
+                items_to_process = part.content
+
+            for item in items_to_process:
                 if isinstance(item, str):
                     content.append({'text': item})
                 elif isinstance(item, BinaryContent):
@@ -578,11 +608,15 @@ async def _map_user_prompt(part: UserPromptPart, document_count: Iterator[int])
                     ), f'Unsupported video format: {format}'
                     video: VideoBlockTypeDef = {'format': format, 'source': {'bytes': downloaded_item['data']}}
                     content.append({'video': video})
+                elif isinstance(item, CachePoint):
+                    if supports_caching:
+                        content.append({'cachePoint': {'type': 'default'}})
+                    continue
                 elif isinstance(item, AudioUrl):  # pragma: no cover
                     raise NotImplementedError('Audio is not supported yet.')
                 else:
                     assert_never(item)

-        return [{'role': 'user', 'content': content}]
+        return has_leading_cache_point, [{'role': 'user', 'content': content}]

     @staticmethod
     def _map_tool_call(t: ToolCallPart) -> ContentBlockOutputTypeDef:
@@ -674,9 +708,14 @@ def timestamp(self) -> datetime:
         return self._timestamp

     def _map_usage(self, metadata: ConverseStreamMetadataEventTypeDef) -> usage.RequestUsage:
+        cache_read_tokens = metadata['usage'].get('cacheReadInputTokens', 0)
+        cache_write_tokens = metadata['usage'].get('cacheWriteInputTokens', 0)
+        input_tokens = metadata['usage']['inputTokens'] + cache_read_tokens + cache_write_tokens
         return usage.RequestUsage(
-            input_tokens=metadata['usage']['inputTokens'],
+            input_tokens=input_tokens,
             output_tokens=metadata['usage']['outputTokens'],
+            cache_write_tokens=cache_write_tokens,
+            cache_read_tokens=cache_read_tokens,
         )

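End to end, the changes let callers mark a cacheable prefix with `CachePoint` and see the cache traffic reflected in usage. A hypothetical caller-side sketch (the model ID, agent wiring, and prompt text are assumptions for illustration, not part of this diff):

```python
from pydantic_ai import Agent
from pydantic_ai.messages import CachePoint
from pydantic_ai.models.bedrock import BedrockConverseModel

# Placeholder model ID: any Bedrock model whose profile enables
# bedrock_supports_prompt_caching should behave the same way.
model = BedrockConverseModel('us.anthropic.claude-sonnet-4-20250514-v1:0')
agent = Agent(model, system_prompt='You are a helpful assistant.')

result = agent.run_sync(
    [
        'Here is a long document: ...',
        CachePoint(),  # content before this marker becomes the cacheable prefix
        'Summarize the key points.',
    ]
)
# Cache reads/writes are folded into input_tokens and also reported
# separately as cache_read_tokens / cache_write_tokens.
print(result.usage())
```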