src/judgeval/tracer/__init__.py (25 additions, 0 deletions)
@@ -983,6 +983,31 @@ def wrap(client: ApiClient) -> ApiClient:
stacklevel=2,
)

# Check if the client is a TrainableModel instance
try:
from judgeval.trainer.trainable_model import TrainableModel

if isinstance(client, TrainableModel):
# Register wrapper functions for all active tracers
def wrapper_func(model_instance):
wrapped_instance = model_instance
for tracer in Tracer._active_tracers:
wrapped_instance = wrap_provider(tracer, wrapped_instance)
return wrapped_instance

client._register_tracer_wrapper(wrapper_func)

# Wrap the current model instance with all active tracers
current_model = client.get_current_model()
if current_model:
for tracer in Tracer._active_tracers:
wrap_provider(tracer, current_model)

return client
except ImportError:
# TrainableModel not available, continue with normal wrapping
pass

wrapped_client = client
for tracer in Tracer._active_tracers:
wrapped_client = tracer.wrap(wrapped_client)
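
The `_register_tracer_wrapper` hook matters because a TrainableModel can swap its underlying client between training steps, and each new snapshot must be re-instrumented by every active tracer. A self-contained toy of that callback pattern (illustrative names only, not the judgeval implementation):

# Toy sketch of the register-a-wrapper pattern used above (hypothetical
# names; judgeval's real TrainableModel will differ in detail).
class ToyModel:
    def __init__(self, client):
        self._client = client
        self._wrapper = None

    def _register_tracer_wrapper(self, fn):
        # Remember how to instrument any client installed later.
        self._wrapper = fn

    def _swap_client(self, new_client):
        # Called when e.g. a training step produces a new model snapshot.
        self._client = self._wrapper(new_client) if self._wrapper else new_client

model = ToyModel("raw-client")
model._register_tracer_wrapper(lambda c: f"traced({c})")
model._swap_client("raw-client-v2")
print(model._client)  # -> traced(raw-client-v2)
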
src/judgeval/tracer/llm/__init__.py (226 additions, 0 deletions)
@@ -24,6 +24,7 @@
HAS_ANTHROPIC,
HAS_GOOGLE_GENAI,
HAS_GROQ,
HAS_FIREWORKS,
ApiClient,
)
from judgeval.tracer.managers import sync_span_context, async_span_context
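
`HAS_FIREWORKS` presumably follows the same optional-dependency pattern as the sibling `HAS_ANTHROPIC`/`HAS_GROQ` flags; a sketch of how such a flag is typically defined (an assumption about the providers module, not a copy of it):

# Sketch of the optional-import pattern assumed to back HAS_FIREWORKS
# (mirroring HAS_ANTHROPIC/HAS_GROQ; the real providers module may differ).
try:
    from fireworks.client import Fireworks as fireworks_Client
    from fireworks.client import AsyncFireworks as fireworks_AsyncClient
except ImportError:
    fireworks_Client = None
    fireworks_AsyncClient = None

try:
    from fireworks import LLM as fireworks_LLM  # newer SDK surface
except ImportError:
    fireworks_LLM = None

HAS_FIREWORKS = any(
    cls is not None for cls in (fireworks_Client, fireworks_AsyncClient, fireworks_LLM)
)
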
@@ -43,6 +44,7 @@ class ProviderType(Enum):
TOGETHER = "together"
GOOGLE = "google"
GROQ = "groq"
FIREWORKS = "fireworks"
DEFAULT = "default"


@@ -118,6 +120,22 @@ def _detect_provider(client: ApiClient) -> ProviderType:
if isinstance(client, (groq_Groq, groq_AsyncGroq)):
return ProviderType.GROQ

if HAS_FIREWORKS:
from judgeval.tracer.llm.providers import (
fireworks_Client,
fireworks_AsyncClient,
fireworks_LLM,
)

# Check for fireworks.client classes (Fireworks, AsyncFireworks)
if fireworks_Client is not None and fireworks_AsyncClient is not None:
if isinstance(client, (fireworks_Client, fireworks_AsyncClient)):
return ProviderType.FIREWORKS

if fireworks_LLM is not None:
if isinstance(client, fireworks_LLM):
return ProviderType.FIREWORKS

return ProviderType.DEFAULT
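
Unknown client types deliberately fall through to `ProviderType.DEFAULT`, which the handlers below treat as OpenAI-compatible. A quick check of that fall-through (illustrative only):

# Any object matching no known provider class detects as DEFAULT.
class NotAClient:
    pass

assert _detect_provider(NotAClient()) == ProviderType.DEFAULT
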


@@ -165,6 +183,19 @@ def _extract_groq_content(chunk) -> str:
return ""


def _extract_fireworks_content(chunk) -> str:
"""Extract content from Fireworks streaming chunk."""
if (
hasattr(chunk, "choices")
and chunk.choices
and hasattr(chunk.choices[0], "delta")
):
delta_content = getattr(chunk.choices[0].delta, "content", None)
if delta_content:
return delta_content
return ""


# Provider-specific chunk usage extraction handlers
def _extract_openai_chunk_usage(chunk) -> Any:
"""Extract usage data from OpenAI streaming chunk."""
@@ -203,6 +234,13 @@ def _extract_groq_chunk_usage(chunk) -> Any:
return None


def _extract_fireworks_chunk_usage(chunk) -> Any:
"""Extract usage data from Fireworks streaming chunk."""
if hasattr(chunk, "usage") and chunk.usage:
return chunk.usage
return None
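
Both extractors assume OpenAI-style chunk shapes and degrade to empty values rather than raising. A self-contained check with stand-in objects (illustrative, not Fireworks SDK types):

from types import SimpleNamespace

chunk = SimpleNamespace(
    choices=[SimpleNamespace(delta=SimpleNamespace(content="Hello"))],
    usage=SimpleNamespace(prompt_tokens=12, completion_tokens=3),
)
assert _extract_fireworks_content(chunk) == "Hello"
assert _extract_fireworks_chunk_usage(chunk).completion_tokens == 3

empty = SimpleNamespace(choices=[], usage=None)  # e.g. a keep-alive chunk
assert _extract_fireworks_content(empty) == ""
assert _extract_fireworks_chunk_usage(empty) is None
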


# Provider-specific token extraction handlers
def _extract_openai_tokens(usage_data) -> tuple[int, int, int, int]:
"""Extract token counts from OpenAI usage data."""
@@ -296,6 +334,22 @@ def _extract_groq_tokens(usage_data) -> tuple[int, int, int, int]:
return prompt_tokens, completion_tokens, cache_read_input_tokens, 0


def _extract_fireworks_tokens(usage_data) -> tuple[int, int, int, int]:
"""Extract token counts from Fireworks usage data."""
prompt_tokens = (
usage_data.prompt_tokens
if hasattr(usage_data, "prompt_tokens") and usage_data.prompt_tokens is not None
else 0
)
completion_tokens = (
usage_data.completion_tokens
if hasattr(usage_data, "completion_tokens")
and usage_data.completion_tokens is not None
else 0
)
return prompt_tokens, completion_tokens, 0, 0
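
The four-slot return mirrors the other token extractors; judging by the Groq handler above, the last two slots appear to carry cache-read and cache-creation counts, which Fireworks usage does not report, hence the hard-coded zeros. Missing or None fields degrade to zero (quick check with stand-in data):

from types import SimpleNamespace

assert _extract_fireworks_tokens(
    SimpleNamespace(prompt_tokens=128, completion_tokens=42)
) == (128, 42, 0, 0)
assert _extract_fireworks_tokens(SimpleNamespace(prompt_tokens=None)) == (0, 0, 0, 0)
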


# Provider-specific output formatting handlers
def _format_openai_output(response: Any) -> tuple[Optional[str], Optional[TraceUsage]]:
"""Format output data from OpenAI response."""
@@ -543,6 +597,35 @@ def _format_groq_output(response: Any) -> tuple[Optional[str], Optional[TraceUsage]]:
return None, None


def _format_fireworks_output(
response: Any,
) -> tuple[Optional[str], Optional[TraceUsage]]:
"""Format output data from Fireworks response."""
model_name = getattr(response, "model", "") or ""
usage = getattr(response, "usage", None)
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
completion_tokens = getattr(usage, "completion_tokens", 0) or 0

message_content = None
choices = getattr(response, "choices", None)
if choices:
message = getattr(choices[0], "message", None)
if message:
message_content = getattr(message, "content", None)

if model_name:
model_name = "fireworks_ai/" + model_name
return message_content, _create_usage(
model_name,
prompt_tokens,
completion_tokens,
0,
0,
)

return None, None
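
One quirk worth noting: a response carrying message content but no `model` field returns `(None, None)`, dropping the content. Otherwise the model name is namespaced with `fireworks_ai/`, mirroring the `groq/` prefix handling, presumably for downstream cost lookup. A stand-in check (the model id is an illustrative Fireworks-style name):

from types import SimpleNamespace

response = SimpleNamespace(
    model="accounts/fireworks/models/llama-v3p1-8b-instruct",
    usage=SimpleNamespace(prompt_tokens=10, completion_tokens=5),
    choices=[SimpleNamespace(message=SimpleNamespace(content="Hi there"))],
)
content, usage = _format_fireworks_output(response)
assert content == "Hi there"
# usage comes from _create_usage(...) with the model name prefixed as
# "fireworks_ai/accounts/fireworks/models/llama-v3p1-8b-instruct".
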


class _TracedGeneratorBase:
"""Base class with common logic for parsing stream chunks."""

@@ -588,6 +671,8 @@ def _extract_content(self, chunk) -> str:
return _extract_together_content(chunk)
elif self.provider_type == ProviderType.GROQ:
return _extract_groq_content(chunk)
elif self.provider_type == ProviderType.FIREWORKS:
return _extract_fireworks_content(chunk)
else:
# Default case - assume OpenAI-compatible for unknown providers
return _extract_openai_content(chunk)
@@ -796,6 +881,8 @@ def _extract_chunk_usage(client: ApiClient, chunk) -> Any:
return _extract_together_chunk_usage(chunk)
elif provider_type == ProviderType.GROQ:
return _extract_groq_chunk_usage(chunk)
elif provider_type == ProviderType.FIREWORKS:
return _extract_fireworks_chunk_usage(chunk)
else:
# Default case - assume OpenAI-compatible for unknown providers
return _extract_openai_chunk_usage(chunk)
@@ -813,6 +900,8 @@ def _extract_usage_tokens(client: ApiClient, usage_data) -> tuple[int, int, int, int]:
return _extract_together_tokens(usage_data)
elif provider_type == ProviderType.GROQ:
return _extract_groq_tokens(usage_data)
elif provider_type == ProviderType.FIREWORKS:
return _extract_fireworks_tokens(usage_data)
else:
# Default case - assume OpenAI-compatible for unknown providers
return _extract_openai_tokens(usage_data)
@@ -846,6 +935,12 @@ def _process_usage_data(
and not final_model_name.startswith("groq/")
):
final_model_name = "groq/" + final_model_name
elif (
provider_type == ProviderType.FIREWORKS
and final_model_name
and not final_model_name.startswith("fireworks_ai/")
):
final_model_name = "fireworks_ai/" + final_model_name

usage = _create_usage(
final_model_name,
@@ -906,6 +1001,30 @@ def wrapper(*args, **kwargs):
):
model_name = "groq/" + model_name

# Add provider prefix for Fireworks clients
if HAS_FIREWORKS:
from judgeval.tracer.llm.providers import (
fireworks_Client,
fireworks_AsyncClient,
fireworks_LLM,
)

fireworks_instances = []
if fireworks_Client is not None:
fireworks_instances.append(fireworks_Client)
if fireworks_AsyncClient is not None:
fireworks_instances.append(fireworks_AsyncClient)
if fireworks_LLM is not None:
fireworks_instances.append(fireworks_LLM)

if (
fireworks_instances
and isinstance(client, tuple(fireworks_instances))
and model_name
and not model_name.startswith("fireworks_ai/")
):
model_name = "fireworks_ai/" + model_name

response = function(*args, **kwargs)
return TracedGenerator(tracer, response, client, span, model_name)
else:
@@ -955,6 +1074,30 @@ async def wrapper(*args, **kwargs):
):
model_name = "groq/" + model_name

# Add provider prefix for Fireworks clients
if HAS_FIREWORKS:
from judgeval.tracer.llm.providers import (
fireworks_Client,
fireworks_AsyncClient,
fireworks_LLM,
)

fireworks_instances = []
if fireworks_Client is not None:
fireworks_instances.append(fireworks_Client)
if fireworks_AsyncClient is not None:
fireworks_instances.append(fireworks_AsyncClient)
if fireworks_LLM is not None:
fireworks_instances.append(fireworks_LLM)

if (
fireworks_instances
and isinstance(client, tuple(fireworks_instances))
and model_name
and not model_name.startswith("fireworks_ai/")
):
model_name = "fireworks_ai/" + model_name

response = await function(*args, **kwargs)
return TracedAsyncGenerator(tracer, response, client, span, model_name)
else:
@@ -1005,6 +1148,20 @@ def wrapper(*args, **kwargs):
):
model_name = "groq/" + model_name

# Add provider prefix for Fireworks clients
if HAS_FIREWORKS:
from judgeval.tracer.llm.providers import (
fireworks_Client,
fireworks_AsyncClient,
)

# Guard against a partially available SDK (either class may be None),
# matching the None checks used in the streaming wrappers above.
if (
fireworks_Client is not None
and fireworks_AsyncClient is not None
and isinstance(client, (fireworks_Client, fireworks_AsyncClient))
and model_name
and not model_name.startswith("fireworks_ai/")
):
model_name = "fireworks_ai/" + model_name

original_context_manager = function(*args, **kwargs)
return TracedSyncContextManager(
tracer, original_context_manager, client, span, model_name
@@ -1037,6 +1194,20 @@ def wrapper(*args, **kwargs):
):
model_name = "groq/" + model_name

# Add provider prefix for Fireworks clients
if HAS_FIREWORKS:
from judgeval.tracer.llm.providers import (
fireworks_Client,
fireworks_AsyncClient,
)

# Guard against a partially available SDK (either class may be None),
# matching the None checks used in the streaming wrappers above.
if (
fireworks_Client is not None
and fireworks_AsyncClient is not None
and isinstance(client, (fireworks_Client, fireworks_AsyncClient))
and model_name
and not model_name.startswith("fireworks_ai/")
):
model_name = "fireworks_ai/" + model_name

original_context_manager = function(*args, **kwargs)
return TracedAsyncContextManager(
tracer, original_context_manager, client, span, model_name
@@ -1176,6 +1347,59 @@ def wrapper(*args, **kwargs):
wrapped_async(client.chat.completions.create, span_name),
)

if HAS_FIREWORKS:
from judgeval.tracer.llm.providers import (
fireworks_Client,
fireworks_AsyncClient,
fireworks_LLM,
)

span_name = "FIREWORKS_API_CALL"

# Handle fireworks.client classes (Fireworks, AsyncFireworks)
if fireworks_Client is not None and fireworks_AsyncClient is not None:
if isinstance(client, fireworks_Client):
setattr(
client.chat.completions,
"create",
wrapped(client.chat.completions.create, span_name),
)
elif isinstance(client, fireworks_AsyncClient):
setattr(
client.chat.completions,
"create",
wrapped_async(client.chat.completions.create, span_name),
)

if fireworks_LLM is not None and isinstance(client, fireworks_LLM):
if hasattr(client, "chat") and hasattr(client.chat, "completions"):
if hasattr(client.chat.completions, "create"):
setattr(
client.chat.completions,
"create",
wrapped(client.chat.completions.create, span_name),
)
if hasattr(client.chat.completions, "acreate"):
setattr(
client.chat.completions,
"acreate",
wrapped_async(client.chat.completions.acreate, span_name),
)

if hasattr(client, "completions"):
if hasattr(client.completions, "create"):
setattr(
client.completions,
"create",
wrapped(client.completions.create, span_name),
)
if hasattr(client.completions, "acreate"):
setattr(
client.completions,
"acreate",
wrapped_async(client.completions.acreate, span_name),
)

return client
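
Hypothetical end-to-end usage once this lands (not part of the diff; the Fireworks constructor, model id, and the public `wrap` entry point are assumptions on my part):

from fireworks.client import Fireworks
from judgeval.tracer import Tracer, wrap

tracer = Tracer(project_name="demo")      # illustrative arguments
client = wrap(Fireworks(api_key="..."))   # detection + method patching above

resp = client.chat.completions.create(
    model="accounts/fireworks/models/llama-v3p1-8b-instruct",
    messages=[{"role": "user", "content": "Say hi"}],
)
# The call now runs inside a FIREWORKS_API_CALL span, and token usage is
# recorded under the "fireworks_ai/" model namespace.
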


@@ -1195,6 +1419,8 @@ def _format_output_data(
return _format_google_output(response)
elif provider_type == ProviderType.GROQ:
return _format_groq_output(response)
elif provider_type == ProviderType.FIREWORKS:
return _format_fireworks_output(response)
else:
# Default case - assume OpenAI-compatible for unknown providers
judgeval_logger.info(