Changes from 2 commits
21 changes: 20 additions & 1 deletion pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -5,11 +5,12 @@
 import hashlib
 from collections import defaultdict, deque
 from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
-from contextlib import asynccontextmanager, contextmanager
+from contextlib import asynccontextmanager, contextmanager, suppress
 from contextvars import ContextVar
 from dataclasses import field
 from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast

+from genai_prices import calc_price
 from opentelemetry.trace import Tracer
 from typing_extensions import TypeGuard, TypeVar, assert_never

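For orientation, pricing a usage snapshot directly looks roughly like the sketch below, mirroring the calc_price call this PR adds. The 'gpt-4o'/'openai' identifiers are illustrative assumptions, not values from the diff; per the suppress() usage below, calc_price raises LookupError when it has no price data for a model.

from genai_prices import calc_price
from pydantic_ai.usage import RequestUsage

# Illustrative model/provider identifiers; calc_price accepts a usage
# object (the diff passes pydantic-ai's own usage values) plus a model
# reference, and returns a price calculation with a total_price Decimal.
price = calc_price(
    RequestUsage(input_tokens=1_000, output_tokens=200),
    'gpt-4o',
    provider_id='openai',
)
print(price.total_price)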
@@ -309,6 +310,15 @@ async def stream(
         ) as streamed_response:
             self._did_stream = True
             ctx.state.usage.requests += 1
+
+            # If we can't calculate the price, we don't want to fail the run.
+            with suppress(LookupError):
+                ctx.state.usage.cost += calc_price(
+                    streamed_response.usage(),
+                    ctx.deps.model.model_name,
+                    provider_id=streamed_response.provider_name,
+                    genai_request_timestamp=streamed_response.timestamp,
+                ).total_price
             agent_stream = result.AgentStream[DepsT, T](
                 streamed_response,
                 ctx.deps.output_schema,
@@ -339,6 +349,15 @@ async def _make_request(
         model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
         ctx.state.usage.requests += 1

+        # If we can't calculate the price, we don't want to fail the run.
+        with suppress(LookupError):
+            ctx.state.usage.cost += calc_price(
+                model_response.usage,
Contributor comment on the line above: Can't you use ModelResponse.price()?
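A sketch of that suggestion, assuming ModelResponse.price() wraps the same genai_prices lookup and returns an object exposing total_price (an assumption; the source here only shows the reviewer naming the method):

# Hypothetical simplification per the review comment; assumes
# ModelResponse.price() performs the equivalent calc_price lookup:
with suppress(LookupError):
    ctx.state.usage.cost += model_response.price().total_price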

+                ctx.deps.model.model_name,
+                provider_id=model_response.provider_name,
+                genai_request_timestamp=model_response.timestamp,
+            ).total_price

         return self._finish_handling(ctx, model_response)

     async def _prepare_request(
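The net effect for callers is that a run's accumulated price becomes readable from its usage. A minimal sketch, assuming an OpenAI model and API key are configured and the model is priced:

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')
result = agent.run_sync('Hello')
# RunUsage.cost is a Decimal summed across all requests in the run;
# it stays at Decimal('0.0') when no price data was found.
print(result.usage().cost)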
12 changes: 12 additions & 0 deletions pydantic_ai_slim/pydantic_ai/usage.py
@@ -3,6 +3,7 @@
 import dataclasses
 from copy import copy
 from dataclasses import dataclass, fields
+from decimal import Decimal

 from typing_extensions import deprecated, overload

@@ -19,6 +20,7 @@ class UsageBase:

     cache_write_tokens: int = 0
     """Number of tokens written to the cache."""
+
     cache_read_tokens: int = 0
     """Number of tokens read from the cache."""

@@ -27,8 +29,10 @@ class UsageBase:

     input_audio_tokens: int = 0
     """Number of audio input tokens."""
+
     cache_audio_read_tokens: int = 0
     """Number of audio tokens read from the cache."""
+
     output_audio_tokens: int = 0
     """Number of audio output tokens."""

@@ -122,17 +126,22 @@ class RunUsage(UsageBase):

     cache_write_tokens: int = 0
     """Total number of tokens written to the cache."""

     cache_read_tokens: int = 0
     """Total number of tokens read from the cache."""

     input_audio_tokens: int = 0
     """Total number of audio input tokens."""
+
     cache_audio_read_tokens: int = 0
     """Total number of audio tokens read from the cache."""
+
     output_tokens: int = 0
     """Total number of text output/completion tokens."""

+    cost: Decimal = Decimal('0.0')
+    """Total cost of the run."""
+
     details: dict[str, int] = dataclasses.field(default_factory=dict)
     """Any extra details returned by the model."""

@@ -170,6 +179,9 @@ def _incr_usage_tokens(slf: RunUsage | RequestUsage, incr_usage: RunUsage | Requ
     slf.cache_audio_read_tokens += incr_usage.cache_audio_read_tokens
     slf.output_tokens += incr_usage.output_tokens

+    if isinstance(slf, RunUsage) and isinstance(incr_usage, RunUsage):
+        slf.cost += incr_usage.cost
+
     for key, value in incr_usage.details.items():
         slf.details[key] = slf.details.get(key, 0) + value

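A minimal sketch of the aggregation this enables, assuming RunUsage.incr() routes through _incr_usage_tokens as shown above:

from decimal import Decimal
from pydantic_ai.usage import RunUsage

total = RunUsage(requests=1, input_tokens=2, cost=Decimal('0.000015'))
total.incr(RunUsage(requests=1, input_tokens=4, cost=Decimal('0.00003')))
# Token counts and cost both accumulate; RequestUsage has no cost field,
# hence the isinstance guard above.
assert total.cost == Decimal('0.000045')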
@@ -109,39 +109,39 @@ interactions:
         parsed_body:
           data:
           - created: 0
-            id: llama-4-maverick-17b-128e-instruct
+            id: llama-4-scout-17b-16e-instruct
             object: model
             owned_by: Cerebras
           - created: 0
-            id: qwen-3-32b
+            id: gpt-oss-120b
             object: model
             owned_by: Cerebras
           - created: 0
-            id: qwen-3-235b-a22b-instruct-2507
+            id: qwen-3-235b-a22b-thinking-2507
             object: model
             owned_by: Cerebras
           - created: 0
-            id: llama-4-scout-17b-16e-instruct
+            id: qwen-3-32b
             object: model
             owned_by: Cerebras
           - created: 0
-            id: gpt-oss-120b
+            id: llama-3.3-70b
             object: model
             owned_by: Cerebras
           - created: 0
             id: qwen-3-coder-480b
             object: model
             owned_by: Cerebras
           - created: 0
-            id: llama-3.3-70b
+            id: llama3.1-8b
             object: model
             owned_by: Cerebras
           - created: 0
-            id: llama3.1-8b
+            id: qwen-3-235b-a22b-instruct-2507
             object: model
             owned_by: Cerebras
           - created: 0
-            id: qwen-3-235b-a22b-thinking-2507
+            id: llama-4-maverick-17b-128e-instruct
             object: model
             owned_by: Cerebras
           object: list
15 changes: 7 additions & 8 deletions tests/models/test_openai.py
@@ -4,6 +4,7 @@
 from collections.abc import Sequence
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
+from decimal import Decimal
 from enum import Enum
 from functools import cached_property
 from typing import Annotated, Any, Callable, Literal, Union, cast
@@ -236,13 +237,7 @@ async def test_request_simple_usage(allow_model_requests: None):

     result = await agent.run('Hello')
     assert result.output == 'world'
-    assert result.usage() == snapshot(
-        RunUsage(
-            requests=1,
-            input_tokens=2,
-            output_tokens=1,
-        )
-    )
+    assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=2, output_tokens=1, cost=Decimal('0.000015')))


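The snapshot value is consistent with rates of $2.50 per million input tokens and $10.00 per million output tokens; those per-token rates are an assumption about the priced model here, not values read from genai_prices:

from decimal import Decimal

# Back-of-envelope check at the assumed rates: 2 input tokens, 1 output token.
cost = (2 * Decimal('2.50') + 1 * Decimal('10.00')) / 1_000_000
assert cost == Decimal('0.000015')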
async def test_request_structured_response(allow_model_requests: None):
@@ -423,7 +418,9 @@ async def get_location(loc_name: str) -> str:
             ),
         ]
     )
-    assert result.usage() == snapshot(RunUsage(requests=3, cache_read_tokens=3, input_tokens=5, output_tokens=3))
+    assert result.usage() == snapshot(
+        RunUsage(requests=3, cache_read_tokens=3, input_tokens=5, output_tokens=3, cost=Decimal('0.00004625'))
+    )


FinishReason = Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call']
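Likewise, the tool-call snapshot above is consistent with those same rates plus $1.25 per million cache-read tokens, reading the cached tokens as priced separately from input_tokens; again an assumption for illustration, not a rate taken from the diff:

from decimal import Decimal

# 5 input, 3 cache-read, and 3 output tokens at the assumed rates:
cost = (5 * Decimal('2.50') + 3 * Decimal('1.25') + 3 * Decimal('10.00')) / 1_000_000
assert cost == Decimal('0.00004625')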
@@ -826,6 +823,7 @@ async def test_openai_audio_url_input(allow_model_requests: None, openai_api_key
                 'text_tokens': 72,
             },
             requests=1,
+            cost=Decimal('0.0008925'),
         )
     )

@@ -1024,6 +1022,7 @@ async def test_audio_as_binary_content_input(
                 'text_tokens': 9,
             },
             requests=1,
+            cost=Decimal('0.00020'),
         )
     )
