Skip to content

Commit be52c68

Browse files
sjrlanakin87
andauthored
refactor: Refactor Agent logic for easier readability (#9726)
* Start refactor * Update run_async to use the new code * Slight updates * Refactoring of tests * Remove messages from execution context * Cleanup * More cleanup * Formatting * Fix some typing * ignore typing issues * Add reno * Adding docstrings * Small changes * docstrings * Updates * Update haystack/components/agents/agent.py Co-authored-by: Stefano Fiorucci <[email protected]> * PR comments * PR comments --------- Co-authored-by: Stefano Fiorucci <[email protected]>
1 parent 3008056 commit be52c68

File tree

9 files changed

+631
-732
lines changed

9 files changed

+631
-732
lines changed

haystack/components/agents/agent.py

Lines changed: 327 additions & 277 deletions
Large diffs are not rendered by default.

haystack/core/pipeline/breakpoint.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -340,15 +340,15 @@ def _validate_tool_breakpoint_is_valid(agent_breakpoint: AgentBreakpoint, tools:
340340
raise ValueError(f"Tool '{tool_breakpoint.tool_name}' is not available in the agent's tools")
341341

342342

343-
def _check_chat_generator_breakpoint(
343+
def _trigger_chat_generator_breakpoint(
344344
*, agent_snapshot: AgentSnapshot, parent_snapshot: Optional[PipelineSnapshot]
345345
) -> None:
346346
"""
347-
Check for breakpoint before calling the ChatGenerator.
347+
Trigger a breakpoint before ChatGenerator execution in Agent.
348348
349349
:param agent_snapshot: AgentSnapshot object containing the agent's state and breakpoints
350350
:param parent_snapshot: Optional parent snapshot containing the state of the pipeline that houses the agent.
351-
:raises BreakpointException: If a breakpoint is triggered
351+
:raises BreakpointException: Always raised when this function is called, indicating a breakpoint has been triggered.
352352
"""
353353

354354
break_point = agent_snapshot.break_point.break_point
@@ -381,16 +381,17 @@ def _check_chat_generator_breakpoint(
381381
)
382382

383383

384-
def _check_tool_invoker_breakpoint(
384+
def _handle_tool_invoker_breakpoint(
385385
*, llm_messages: list[ChatMessage], agent_snapshot: AgentSnapshot, parent_snapshot: Optional[PipelineSnapshot]
386386
) -> None:
387387
"""
388-
Check for breakpoint before calling the ToolInvoker.
388+
Check if a tool call breakpoint should be triggered before executing the tool invoker.
389389
390390
:param llm_messages: List of ChatMessage objects containing potential tool calls.
391391
:param agent_snapshot: AgentSnapshot object containing the agent's state and breakpoints.
392392
:param parent_snapshot: Optional parent snapshot containing the state of the pipeline that houses the agent.
393-
:raises BreakpointException: If a breakpoint is triggered
393+
:raises BreakpointException: If the breakpoint is triggered, indicating a breakpoint has been reached for a tool
394+
call.
394395
"""
395396
if not isinstance(agent_snapshot.break_point.break_point, ToolBreakpoint):
396397
return
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
enhancements:
3+
- |
4+
The internal Agent logic was refactored to help with readability and maintanability. This should help developers understand and extend the internal Agent logic moving forward.

test/components/agents/test_agent.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,12 @@
1212
from openai import Stream
1313
from openai.types.chat import ChatCompletionChunk, chat_completion_chunk
1414

15-
from haystack import Pipeline, tracing
15+
from haystack import Pipeline, component, tracing
1616
from haystack.components.agents import Agent
1717
from haystack.components.agents.state import merge_lists
1818
from haystack.components.builders.chat_prompt_builder import ChatPromptBuilder
1919
from haystack.components.builders.prompt_builder import PromptBuilder
2020
from haystack.components.generators.chat.openai import OpenAIChatGenerator
21-
from haystack.components.generators.chat.types import ChatGenerator
2221
from haystack.core.component.types import OutputSocket
2322
from haystack.dataclasses import ChatMessage, ToolCall
2423
from haystack.dataclasses.chat_message import ChatRole, TextContent
@@ -101,59 +100,56 @@ def openai_mock_chat_completion_chunk():
101100
yield mock_chat_completion_create
102101

103102

104-
class MockChatGeneratorWithoutTools(ChatGenerator):
103+
@component
104+
class MockChatGeneratorWithoutTools:
105105
"""A mock chat generator that implements ChatGenerator protocol but doesn't support tools."""
106106

107-
__haystack_input__ = MagicMock(_sockets_dict={})
108-
__haystack_output__ = MagicMock(_sockets_dict={})
109-
110107
def to_dict(self) -> dict[str, Any]:
111108
return {"type": "MockChatGeneratorWithoutTools", "data": {}}
112109

113110
@classmethod
114111
def from_dict(cls, data: dict[str, Any]) -> "MockChatGeneratorWithoutTools":
115112
return cls()
116113

114+
@component.output_types(replies=list[ChatMessage])
117115
def run(self, messages: list[ChatMessage]) -> dict[str, Any]:
118116
return {"replies": [ChatMessage.from_assistant("Hello")]}
119117

120118

121-
class MockChatGeneratorWithoutRunAsync(ChatGenerator):
119+
@component
120+
class MockChatGeneratorWithoutRunAsync:
122121
"""A mock chat generator that implements ChatGenerator protocol but doesn't have run_async method."""
123122

124-
__haystack_input__ = MagicMock(_sockets_dict={})
125-
__haystack_output__ = MagicMock(_sockets_dict={})
126-
127123
def to_dict(self) -> dict[str, Any]:
128124
return {"type": "MockChatGeneratorWithoutRunAsync", "data": {}}
129125

130126
@classmethod
131127
def from_dict(cls, data: dict[str, Any]) -> "MockChatGeneratorWithoutRunAsync":
132128
return cls()
133129

130+
@component.output_types(replies=list[ChatMessage])
134131
def run(
135132
self, messages: list[ChatMessage], tools: Optional[Union[list[Tool], Toolset]] = None, **kwargs
136133
) -> dict[str, Any]:
137134
return {"replies": [ChatMessage.from_assistant("Hello")]}
138135

139136

140-
class MockChatGeneratorWithRunAsync(ChatGenerator):
141-
__haystack_supports_async__ = True
142-
__haystack_input__ = MagicMock(_sockets_dict={})
143-
__haystack_output__ = MagicMock(_sockets_dict={})
144-
137+
@component
138+
class MockChatGenerator:
145139
def to_dict(self) -> dict[str, Any]:
146140
return {"type": "MockChatGeneratorWithoutRunAsync", "data": {}}
147141

148142
@classmethod
149143
def from_dict(cls, data: dict[str, Any]) -> "MockChatGeneratorWithoutRunAsync":
150144
return cls()
151145

146+
@component.output_types(replies=list[ChatMessage])
152147
def run(
153148
self, messages: list[ChatMessage], tools: Optional[Union[list[Tool], Toolset]] = None, **kwargs
154149
) -> dict[str, Any]:
155150
return {"replies": [ChatMessage.from_assistant("Hello")]}
156151

152+
@component.output_types(replies=list[ChatMessage])
157153
async def run_async(
158154
self, messages: list[ChatMessage], tools: Optional[Union[list[Tool], Toolset]] = None, **kwargs
159155
) -> dict[str, Any]:
@@ -818,7 +814,7 @@ async def test_run_async_falls_back_to_run_when_chat_generator_has_no_run_async(
818814

819815
@pytest.mark.asyncio
820816
async def test_run_async_uses_chat_generator_run_async_when_available(self, weather_tool):
821-
chat_generator = MockChatGeneratorWithRunAsync()
817+
chat_generator = MockChatGenerator()
822818
agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
823819
agent.warm_up()
824820

@@ -904,16 +900,16 @@ def test_agent_tracing_span_run(self, caplog, monkeypatch, weather_tool):
904900
"chat_generator",
905901
"MockChatGeneratorWithoutRunAsync",
906902
'{"messages": "list", "tools": "list"}',
907-
"{}",
908-
"{}",
903+
'{"messages": {"type": "list", "senders": []}, "tools": {"type": "typing.Union[list[haystack.tools.tool.Tool], haystack.tools.toolset.Toolset, NoneType]", "senders": []}}', # noqa: E501
904+
'{"replies": {"type": "list", "receivers": []}}',
909905
'{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "tools": [{"type": "haystack.tools.tool.Tool", "data": {"name": "weather_tool", "description": "Provides weather information for a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, "function": "test_agent.weather_function", "outputs_to_string": null, "inputs_from_state": null, "outputs_to_state": null}}]}', # noqa: E501
910906
1,
911907
'{"replies": [{"role": "assistant", "meta": {}, "name": null, "content": [{"text": "Hello"}]}]}',
912908
100,
913909
'[{"type": "haystack.tools.tool.Tool", "data": {"name": "weather_tool", "description": "Provides weather information for a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, "function": "test_agent.weather_function", "outputs_to_string": null, "inputs_from_state": null, "outputs_to_state": null}}]', # noqa: E501
914910
'["text"]',
915911
'{"messages": {"type": "list[haystack.dataclasses.chat_message.ChatMessage]", "handler": "haystack.components.agents.state.state_utils.merge_lists"}}', # noqa: E501
916-
'{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "streaming_callback": null}', # noqa: E501
912+
'{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "streaming_callback": null, "break_point": null, "snapshot": null}', # noqa: E501
917913
'{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}, {"role": "assistant", "meta": {}, "name": null, "content": [{"text": "Hello"}]}]}', # noqa: E501
918914
1,
919915
]
@@ -927,7 +923,7 @@ def test_agent_tracing_span_run(self, caplog, monkeypatch, weather_tool):
927923

928924
@pytest.mark.asyncio
929925
async def test_agent_tracing_span_async_run(self, caplog, monkeypatch, weather_tool):
930-
chat_generator = MockChatGeneratorWithRunAsync()
926+
chat_generator = MockChatGenerator()
931927
agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
932928

933929
tracing.tracer.is_content_tracing_enabled = True
@@ -962,18 +958,18 @@ async def test_agent_tracing_span_async_run(self, caplog, monkeypatch, weather_t
962958

963959
expected_tag_values = [
964960
"chat_generator",
965-
"MockChatGeneratorWithRunAsync",
961+
"MockChatGenerator",
966962
'{"messages": "list", "tools": "list"}',
967-
"{}",
968-
"{}",
963+
'{"messages": {"type": "list", "senders": []}, "tools": {"type": "typing.Union[list[haystack.tools.tool.Tool], haystack.tools.toolset.Toolset, NoneType]", "senders": []}}', # noqa: E501
964+
'{"replies": {"type": "list", "receivers": []}}',
969965
'{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "tools": [{"type": "haystack.tools.tool.Tool", "data": {"name": "weather_tool", "description": "Provides weather information for a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, "function": "test_agent.weather_function", "outputs_to_string": null, "inputs_from_state": null, "outputs_to_state": null}}]}', # noqa: E501
970966
1,
971967
'{"replies": [{"role": "assistant", "meta": {}, "name": null, "content": [{"text": "Hello from run_async"}]}]}', # noqa: E501
972968
100,
973969
'[{"type": "haystack.tools.tool.Tool", "data": {"name": "weather_tool", "description": "Provides weather information for a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, "function": "test_agent.weather_function", "outputs_to_string": null, "inputs_from_state": null, "outputs_to_state": null}}]', # noqa: E501
974970
'["text"]',
975971
'{"messages": {"type": "list[haystack.dataclasses.chat_message.ChatMessage]", "handler": "haystack.components.agents.state.state_utils.merge_lists"}}', # noqa: E501
976-
'{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "streaming_callback": null}', # noqa: E501
972+
'{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "streaming_callback": null, "break_point": null, "snapshot": null}', # noqa: E501
977973
'{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}, {"role": "assistant", "meta": {}, "name": null, "content": [{"text": "Hello from run_async"}]}]}', # noqa: E501
978974
1,
979975
]

0 commit comments

Comments
 (0)