Skip to content

Commit 4644785

Browse files
authored
refactor: Some refactoring and finish TODO for component tool (#9751)
* Some refactoring and finish TODO for component tool * Add docstrings
1 parent 6c7ae8f commit 4644785

File tree

4 files changed

+195
-171
lines changed

4 files changed

+195
-171
lines changed

haystack/tools/component_tool.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from haystack.tools.errors import SchemaGenerationError
1919
from haystack.tools.from_function import _remove_title_from_schema
2020
from haystack.tools.parameters_schema_utils import _get_component_param_descriptions, _resolve_type
21+
from haystack.tools.tool import _deserialize_outputs_to_state, _serialize_outputs_to_state
2122
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
2223

2324
logger = logging.getLogger(__name__)
@@ -46,6 +47,8 @@ class ComponentTool(Tool):
4647
You can create a ComponentTool from the component by passing the component to the ComponentTool constructor.
4748
Below is an example of creating a ComponentTool from an existing SerperDevWebSearch component.
4849
50+
## Usage Example:
51+
4952
```python
5053
from haystack import component, Pipeline
5154
from haystack.tools import ComponentTool
@@ -93,7 +96,7 @@ def __init__(
9396
outputs_to_string: Optional[dict[str, Union[str, Callable[[Any], str]]]] = None,
9497
inputs_from_state: Optional[dict[str, str]] = None,
9598
outputs_to_state: Optional[dict[str, dict[str, Union[str, Callable]]]] = None,
96-
):
99+
) -> None:
97100
"""
98101
Create a Tool instance from a Haystack component.
99102
@@ -213,10 +216,8 @@ def to_dict(self) -> dict[str, Any]:
213216
"""
214217
Serializes the ComponentTool to a dictionary.
215218
"""
216-
serialized_component = component_to_dict(obj=self._component, name=self.name)
217-
218219
serialized: dict[str, Any] = {
219-
"component": serialized_component,
220+
"component": component_to_dict(obj=self._component, name=self.name),
220221
"name": self.name,
221222
"description": self.description,
222223
"parameters": self._unresolved_parameters,
@@ -226,13 +227,7 @@ def to_dict(self) -> dict[str, Any]:
226227
}
227228

228229
if self.outputs_to_state is not None:
229-
serialized_outputs = {}
230-
for key, config in self.outputs_to_state.items():
231-
serialized_config = config.copy()
232-
if "handler" in config:
233-
serialized_config["handler"] = serialize_callable(config["handler"])
234-
serialized_outputs[key] = serialized_config
235-
serialized["outputs_to_state"] = serialized_outputs
230+
serialized["outputs_to_state"] = _serialize_outputs_to_state(self.outputs_to_state)
236231

237232
if self.outputs_to_string is not None and self.outputs_to_string.get("handler") is not None:
238233
# This is soft-copied as to not modify the attributes in place
@@ -253,13 +248,7 @@ def from_dict(cls, data: dict[str, Any]) -> "Tool":
253248
component = component_from_dict(cls=component_class, data=inner_data["component"], name=inner_data["name"])
254249

255250
if "outputs_to_state" in inner_data and inner_data["outputs_to_state"]:
256-
deserialized_outputs = {}
257-
for key, config in inner_data["outputs_to_state"].items():
258-
deserialized_config = config.copy()
259-
if "handler" in config:
260-
deserialized_config["handler"] = deserialize_callable(config["handler"])
261-
deserialized_outputs[key] = deserialized_config
262-
inner_data["outputs_to_state"] = deserialized_outputs
251+
inner_data["outputs_to_state"] = _deserialize_outputs_to_state(inner_data["outputs_to_state"])
263252

264253
if (
265254
inner_data.get("outputs_to_string") is not None

haystack/tools/tool.py

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from haystack.core.serialization import generate_qualified_class_name
1212
from haystack.tools.errors import ToolInvocationError
13+
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
1314

1415

1516
@dataclass
@@ -116,24 +117,18 @@ def to_dict(self) -> dict[str, Any]:
116117
:returns:
117118
Dictionary with serialized data.
118119
"""
119-
# Import here to avoid circular dependency with utils.callable_serialization
120-
from haystack.utils import serialize_callable
121-
122120
data = asdict(self)
123121
data["function"] = serialize_callable(self.function)
124122

125-
# Serialize output handlers if they exist
126-
if self.outputs_to_state:
127-
serialized_outputs = {}
128-
for key, config in self.outputs_to_state.items():
129-
serialized_config = config.copy()
130-
if "handler" in config:
131-
serialized_config["handler"] = serialize_callable(config["handler"])
132-
serialized_outputs[key] = serialized_config
133-
data["outputs_to_state"] = serialized_outputs
123+
if self.outputs_to_state is not None:
124+
data["outputs_to_state"] = _serialize_outputs_to_state(self.outputs_to_state)
134125

135126
if self.outputs_to_string is not None and self.outputs_to_string.get("handler") is not None:
127+
# This is soft-copied as to not modify the attributes in place
128+
data["outputs_to_string"] = self.outputs_to_string.copy()
136129
data["outputs_to_string"]["handler"] = serialize_callable(self.outputs_to_string["handler"])
130+
else:
131+
data["outputs_to_string"] = None
137132

138133
return {"type": generate_qualified_class_name(type(self)), "data": data}
139134

@@ -147,21 +142,10 @@ def from_dict(cls, data: dict[str, Any]) -> "Tool":
147142
:returns:
148143
Deserialized Tool.
149144
"""
150-
# Import here to avoid circular dependency with utils.callable_serialization
151-
from haystack.utils import deserialize_callable
152-
153145
init_parameters = data["data"]
154146
init_parameters["function"] = deserialize_callable(init_parameters["function"])
155-
156-
# Deserialize output handlers if they exist
157147
if "outputs_to_state" in init_parameters and init_parameters["outputs_to_state"]:
158-
deserialized_outputs = {}
159-
for key, config in init_parameters["outputs_to_state"].items():
160-
deserialized_config = config.copy()
161-
if "handler" in config:
162-
deserialized_config["handler"] = deserialize_callable(config["handler"])
163-
deserialized_outputs[key] = deserialized_config
164-
init_parameters["outputs_to_state"] = deserialized_outputs
148+
init_parameters["outputs_to_state"] = _deserialize_outputs_to_state(init_parameters["outputs_to_state"])
165149

166150
if (
167151
init_parameters.get("outputs_to_string") is not None
@@ -187,3 +171,35 @@ def _check_duplicate_tool_names(tools: Optional[list[Tool]]) -> None:
187171
duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1}
188172
if duplicate_tool_names:
189173
raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}")
174+
175+
176+
def _serialize_outputs_to_state(outputs_to_state: dict[str, dict[str, Any]]) -> dict[str, dict[str, Any]]:
177+
"""
178+
Serializes the outputs_to_state dictionary, converting any callable handlers to their string representation.
179+
180+
:param outputs_to_state: The outputs_to_state dictionary to serialize.
181+
:returns: The serialized outputs_to_state dictionary.
182+
"""
183+
serialized_outputs = {}
184+
for key, config in outputs_to_state.items():
185+
serialized_config = config.copy()
186+
if "handler" in config:
187+
serialized_config["handler"] = serialize_callable(config["handler"])
188+
serialized_outputs[key] = serialized_config
189+
return serialized_outputs
190+
191+
192+
def _deserialize_outputs_to_state(outputs_to_state: dict[str, dict[str, Any]]) -> dict[str, dict[str, Any]]:
193+
"""
194+
Deserializes the outputs_to_state dictionary, converting any string handlers back to callables.
195+
196+
:param outputs_to_state: The outputs_to_state dictionary to deserialize.
197+
:returns: The deserialized outputs_to_state dictionary.
198+
"""
199+
deserialized_outputs = {}
200+
for key, config in outputs_to_state.items():
201+
deserialized_config = config.copy()
202+
if "handler" in config:
203+
deserialized_config["handler"] = deserialize_callable(config["handler"])
204+
deserialized_outputs[key] = deserialized_config
205+
return deserialized_outputs

haystack/utils/callable_serialization.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from typing import Any, Callable
77

88
from haystack.core.errors import DeserializationError, SerializationError
9-
from haystack.tools.tool import Tool
109
from haystack.utils.type_serialization import thread_safe_import
1110

1211

@@ -51,6 +50,9 @@ def deserialize_callable(callable_handle: str) -> Callable:
5150
:return: The callable
5251
:raises DeserializationError: If the callable cannot be found
5352
"""
53+
# Import here to avoid circular imports
54+
from haystack.tools.tool import Tool
55+
5456
parts = callable_handle.split(".")
5557

5658
for i in range(len(parts), 0, -1):

0 commit comments

Comments
 (0)