
Commit ac6a43f

davidsbatista and sjrl authored
feat: raise components inputs/outputs during execution if an Exception occurs (#9742)
* initial PoC idea running
* removing test code
* cleaning up
* wip
* cleaning up demos
* adding more pipelines to test persistence saving
* wip
* wip
* working example for logging components inputs in run time
* reverting to a simpler solution for intermediate results
* cleaning up
* testing that in a crash components outputs/inputs up to the crash point are returned
* adding tests for state persistence in a RAG pipeline
* updating tests for state persistence in a RAG pipeline
* removing use cases of agent tests
* adding LICENSE header
* adding LICENSE header
* adding release notes
* updating tests for mocked components only
* updating release notes
* adapting PipelineRuntimeError
* cleaning up tests
* fixing test pipeline crash components inputs/outputs are saved
* fixing tests for state persistence
* isolating changes
* cleaning
* updating release notes
* adding test for regular pipeline
* small improvements and updating release notes
* cleaning imports
* removing code
* improvements/fixes based on PR comments
* raising pipeline_outputs on async version of Pipeline
* fixing async versions + updating tests
* simplifying tests
* Suggested changes pipeline crash (#9744)
  * Suggested changes
  * Some cleanup
  * Small changes

---------

Co-authored-by: Sebastian Husch Lee <[email protected]>
1 parent 4275fed commit ac6a43f

File tree

6 files changed · +296 −26 lines changed

haystack/core/errors.py

Lines changed: 8 additions & 1 deletion

@@ -10,9 +10,16 @@ class PipelineError(Exception):


 class PipelineRuntimeError(Exception):
-    def __init__(self, component_name: Optional[str], component_type: Optional[type], message: str) -> None:
+    def __init__(
+        self,
+        component_name: Optional[str],
+        component_type: Optional[type],
+        message: str,
+        pipeline_outputs: Optional[Any] = None,
+    ) -> None:
         self.component_name = component_name
         self.component_type = component_type
+        self.pipeline_outputs = pipeline_outputs
         super().__init__(message)

     @classmethod
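
A minimal sketch of the extended constructor, exercised directly. The component name, type, and outputs below are made-up values for illustration, not taken from the commit:

from haystack.core.errors import PipelineRuntimeError

# Illustrative values only: construct the error directly to show the new field.
error = PipelineRuntimeError(
    component_name="embedding_retriever",
    component_type=object,
    message="Component 'embedding_retriever' failed",
    pipeline_outputs={"text_embedder": {"embedding": [0.1, 0.2, 0.3]}},
)

# The partial outputs travel with the exception; the field defaults to None if not attached.
assert error.pipeline_outputs["text_embedder"]["embedding"] == [0.1, 0.2, 0.3]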

haystack/core/pipeline/async_pipeline.py

Lines changed: 32 additions & 16 deletions

@@ -71,7 +71,12 @@ async def _run_component_async(
         # Important: contextvars (e.g. active tracing Span) don't propagate to running loop's ThreadPoolExecutor
         # We use ctx.run(...) to preserve context like the active tracing span
         ctx = contextvars.copy_context()
-        outputs = await loop.run_in_executor(None, lambda: ctx.run(lambda: instance.run(**component_inputs)))
+        try:
+            outputs = await loop.run_in_executor(
+                None, lambda: ctx.run(lambda: instance.run(**component_inputs))
+            )
+        except Exception as error:
+            raise PipelineRuntimeError.from_exception(component_name, instance.__class__, error) from error

         component_visits[component_name] += 1

@@ -256,13 +261,19 @@ async def _run_highest_in_isolation(component_name: str) -> AsyncIterator[dict[s
             )
             component_inputs = self._consume_component_inputs(component_name, comp_dict, inputs_state)
             component_inputs = self._add_missing_input_defaults(component_inputs, comp_dict["input_sockets"])
-            component_pipeline_outputs = await self._run_component_async(
-                component_name=component_name,
-                component=comp_dict,
-                component_inputs=component_inputs,
-                component_visits=component_visits,
-                parent_span=parent_span,
-            )
+
+            try:
+                component_pipeline_outputs = await self._run_component_async(
+                    component_name=component_name,
+                    component=comp_dict,
+                    component_inputs=component_inputs,
+                    component_visits=component_visits,
+                    parent_span=parent_span,
+                )
+            except PipelineRuntimeError as error:
+                # Attach partial pipeline outputs to the error before re-raising
+                error.pipeline_outputs = pipeline_outputs
+                raise error

             # Distribute outputs to downstream inputs; also prune outputs based on `include_outputs_from`
             pruned = self._write_component_outputs(

@@ -300,14 +311,19 @@ async def _schedule_task(component_name: str) -> None:
             component_inputs = self._add_missing_input_defaults(component_inputs, comp_dict["input_sockets"])

             async def _runner():
-                async with ready_sem:
-                    component_pipeline_outputs = await self._run_component_async(
-                        component_name=component_name,
-                        component=comp_dict,
-                        component_inputs=component_inputs,
-                        component_visits=component_visits,
-                        parent_span=parent_span,
-                    )
+                try:
+                    async with ready_sem:
+                        component_pipeline_outputs = await self._run_component_async(
+                            component_name=component_name,
+                            component=comp_dict,
+                            component_inputs=component_inputs,
+                            component_visits=component_visits,
+                            parent_span=parent_span,
+                        )
+                except PipelineRuntimeError as error:
+                    # Attach partial pipeline outputs to the error before re-raising
+                    error.pipeline_outputs = pipeline_outputs
+                    raise error

                 # Distribute outputs to downstream inputs; also prune outputs based on `include_outputs_from`
                 pruned = self._write_component_outputs(
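
The same attach-and-re-raise pattern appears in both async code paths above. Below is a self-contained sketch of the idea using stand-in names (not Haystack's actual internals): outputs collected so far ride along on the exception when a later component fails.

import asyncio


class PipelineRuntimeError(Exception):
    """Stand-in for haystack.core.errors.PipelineRuntimeError."""

    def __init__(self, message: str, pipeline_outputs=None):
        super().__init__(message)
        self.pipeline_outputs = pipeline_outputs


async def run_component(name: str) -> dict:
    # The second component fails, mimicking a crash mid-pipeline.
    if name == "embedding_retriever":
        raise PipelineRuntimeError(f"Component '{name}' failed")
    return {"result": f"{name} ok"}


async def run_pipeline(component_names: list[str]) -> dict:
    pipeline_outputs: dict = {}
    for name in component_names:
        try:
            pipeline_outputs[name] = await run_component(name)
        except PipelineRuntimeError as error:
            # Attach partial pipeline outputs to the error before re-raising
            error.pipeline_outputs = pipeline_outputs
            raise error
    return pipeline_outputs


try:
    asyncio.run(run_pipeline(["text_embedder", "embedding_retriever", "llm"]))
except PipelineRuntimeError as error:
    print(error.pipeline_outputs)  # {'text_embedder': {'result': 'text_embedder ok'}}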

haystack/core/pipeline/breakpoint.py

Lines changed: 4 additions & 2 deletions

@@ -163,10 +163,12 @@ def _save_pipeline_snapshot_to_file(
     if isinstance(pipeline_snapshot.break_point, AgentBreakpoint):
         agent_name = pipeline_snapshot.break_point.agent_name
         component_name = pipeline_snapshot.break_point.break_point.component_name
-        file_name = f"{agent_name}_{component_name}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json"
+        visit_nr = pipeline_snapshot.pipeline_state.component_visits.get(component_name, 0)
+        file_name = f"{agent_name}_{component_name}_{visit_nr}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json"
     else:
         component_name = pipeline_snapshot.break_point.component_name
-        file_name = f"{component_name}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json"
+        visit_nr = pipeline_snapshot.pipeline_state.component_visits.get(component_name, 0)
+        file_name = f"{component_name}_{visit_nr}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json"

     try:
         with open(snapshot_file_path / file_name, "w") as f_out:
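
The only behavioural change in this file is that snapshot file names now include the component's visit count. A small sketch of the resulting name, using made-up values (in the real code they come from the pipeline snapshot):

from datetime import datetime

component_name = "generator"  # illustrative component name
visit_nr = 2                  # illustrative visit count
dt = datetime(2025, 1, 1, 12, 0, 0)

file_name = f"{component_name}_{visit_nr}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json"
print(file_name)  # generator_2_2025_01_01_12_00_00.json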

haystack/core/pipeline/pipeline.py

Lines changed: 12 additions & 7 deletions

@@ -379,13 +379,18 @@ def run( # noqa: PLR0915, PLR0912, C901, pylint: disable=too-many-branches
                         pipeline_snapshot=new_pipeline_snapshot, pipeline_outputs=pipeline_outputs
                     )

-                    component_outputs = self._run_component(
-                        component_name=component_name,
-                        component=component,
-                        inputs=component_inputs,  # the inputs to the current component
-                        component_visits=component_visits,
-                        parent_span=span,
-                    )
+                    try:
+                        component_outputs = self._run_component(
+                            component_name=component_name,
+                            component=component,
+                            inputs=component_inputs,  # the inputs to the current component
+                            component_visits=component_visits,
+                            parent_span=span,
+                        )
+                    except PipelineRuntimeError as error:
+                        # Attach partial pipeline outputs to the error before re-raising
+                        error.pipeline_outputs = pipeline_outputs
+                        raise error

                     # Updates global input state with component outputs and returns outputs that should go to
                     # pipeline outputs.

Release note (new file)
Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
---
features:
  - |
    If an error occurs during the execution of a pipeline, the pipeline will raise a PipelineRuntimeError exception
    containing an error message and the components' outputs up to the point of failure. This allows you to inspect
    and debug the pipeline up to the point of failure.
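
A minimal sketch of the behaviour described in the release note, assuming Haystack's `@component` decorator and a two-component pipeline in which the second component always fails (component names and outputs are illustrative):

from haystack import Pipeline, component
from haystack.core.errors import PipelineRuntimeError


@component
class Upper:
    @component.output_types(text=str)
    def run(self, text: str):
        return {"text": text.upper()}


@component
class AlwaysFails:
    @component.output_types(text=str)
    def run(self, text: str):
        raise ValueError("boom")


pipeline = Pipeline()
pipeline.add_component("upper", Upper())
pipeline.add_component("fails", AlwaysFails())
pipeline.connect("upper.text", "fails.text")

try:
    pipeline.run(data={"upper": {"text": "hello"}}, include_outputs_from={"upper"})
except PipelineRuntimeError as error:
    # Outputs produced before the crash are available for inspection,
    # e.g. {'upper': {'text': 'HELLO'}}
    print(error.pipeline_outputs)
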
New test file
Lines changed: 234 additions & 0 deletions

@@ -0,0 +1,234 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from haystack import AsyncPipeline, Document, Pipeline
from haystack.components.builders import ChatPromptBuilder
from haystack.components.builders.answer_builder import AnswerBuilder
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.joiners import DocumentJoiner
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.components.writers import DocumentWriter
from haystack.core.component import component
from haystack.core.errors import PipelineRuntimeError
from haystack.dataclasses import ChatMessage
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.document_stores.types import DuplicatePolicy
from haystack.utils.auth import Secret


def setup_document_store():
    """Create and populate a document store with test documents."""
    documents = [
        Document(content="My name is Jean and I live in Paris.", embedding=[0.1, 0.3, 0.6]),
        Document(content="My name is Mark and I live in Berlin.", embedding=[0.2, 0.4, 0.7]),
        Document(content="My name is Giorgio and I live in Rome.", embedding=[0.3, 0.5, 0.8]),
    ]

    document_store = InMemoryDocumentStore()
    doc_writer = DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP)
    doc_writer.run(documents=documents)

    return document_store


# Create a mock component that returns invalid output (int instead of documents list)
@component
class InvalidOutputEmbeddingRetriever:
    @component.output_types(documents=list[Document])
    def run(self, query_embedding: list[float]):
        # Return an int instead of the expected documents list
        # This will cause the pipeline to crash when trying to pass it to the next component
        return 42


template = [
    ChatMessage.from_system(
        "You are a helpful AI assistant. Answer the following question based on the given context information "
        "only. If the context is empty or just a '\n' answer with None, example: 'None'."
    ),
    ChatMessage.from_user(
        """
    Context:
    {% for document in documents %}
        {{ document.content }}
    {% endfor %}

    Question: {{question}}
    """
    ),
]


class TestPipelineOutputsRaisedInException:
    @pytest.fixture
    def mock_sentence_transformers_text_embedder(self):
        with patch(
            "haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
        ) as mock_text_embedder:
            mock_model = MagicMock()
            mock_text_embedder.return_value = mock_model

            def mock_encode(
                texts, batch_size=None, show_progress_bar=None, normalize_embeddings=None, precision=None, **kwargs
            ):  # noqa E501
                return [np.ones(384).tolist() for _ in texts]

            mock_model.encode = mock_encode
            embedder = SentenceTransformersTextEmbedder(model="mock-model", progress_bar=False)

            def mock_run(text):
                if not isinstance(text, str):
                    raise TypeError(
                        "SentenceTransformersTextEmbedder expects a string as input."
                        "In case you want to embed a list of Documents, please use the "
                        "SentenceTransformersDocumentEmbedder."
                    )

                embedding = np.ones(384).tolist()
                return {"embedding": embedding}

            embedder.run = mock_run
            embedder.warm_up()
            return embedder

    def test_hybrid_rag_pipeline_crash_on_embedding_retriever(
        self, mock_sentence_transformers_text_embedder, monkeypatch
    ):
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")

        document_store = setup_document_store()
        text_embedder = mock_sentence_transformers_text_embedder
        invalid_embedding_retriever = InvalidOutputEmbeddingRetriever()
        bm25_retriever = InMemoryBM25Retriever(document_store)
        document_joiner = DocumentJoiner(join_mode="concatenate")

        pipeline = Pipeline()
        pipeline.add_component("text_embedder", text_embedder)
        pipeline.add_component("embedding_retriever", invalid_embedding_retriever)
        pipeline.add_component("bm25_retriever", bm25_retriever)
        pipeline.add_component("document_joiner", document_joiner)
        pipeline.add_component(
            "prompt_builder", ChatPromptBuilder(template=template, required_variables=["question", "documents"])
        )
        pipeline.add_component("llm", OpenAIChatGenerator(api_key=Secret.from_env_var("OPENAI_API_KEY")))
        pipeline.add_component("answer_builder", AnswerBuilder())

        pipeline.connect("text_embedder", "embedding_retriever")
        pipeline.connect("bm25_retriever", "document_joiner")
        pipeline.connect("embedding_retriever", "document_joiner")
        pipeline.connect("document_joiner.documents", "prompt_builder.documents")
        pipeline.connect("prompt_builder", "llm")
        pipeline.connect("llm.replies", "answer_builder.replies")

        question = "Where does Mark live?"
        test_data = {
            "text_embedder": {"text": question},
            "bm25_retriever": {"query": question},
            "prompt_builder": {"question": question},
            "answer_builder": {"query": question},
        }

        # run pipeline and expect it to crash due to invalid output type
        with pytest.raises(PipelineRuntimeError) as exc_info:
            pipeline.run(
                data=test_data,
                include_outputs_from={
                    "text_embedder",
                    "embedding_retriever",
                    "bm25_retriever",
                    "document_joiner",
                    "prompt_builder",
                    "llm",
                    "answer_builder",
                },
            )

        pipeline_outputs = exc_info.value.pipeline_outputs

        assert pipeline_outputs is not None, "Pipeline outputs should be captured in the exception"

        # verify that bm25_retriever and text_embedder ran successfully before the crash
        assert "bm25_retriever" in pipeline_outputs, "BM25 retriever output not captured"
        assert "documents" in pipeline_outputs["bm25_retriever"], "BM25 retriever should have produced documents"
        assert "text_embedder" in pipeline_outputs, "Text embedder output not captured"
        assert "embedding" in pipeline_outputs["text_embedder"], "Text embedder should have produced embeddings"

        # components after the crash point are not in the outputs
        assert "document_joiner" not in pipeline_outputs, "Document joiner should not have run due to crash"
        assert "prompt_builder" not in pipeline_outputs, "Prompt builder should not have run due to crash"
        assert "llm" not in pipeline_outputs, "LLM should not have run due to crash"
        assert "answer_builder" not in pipeline_outputs, "Answer builder should not have run due to crash"

    @pytest.mark.asyncio
    async def test_async_hybrid_rag_pipeline_crash_on_embedding_retriever(
        self, mock_sentence_transformers_text_embedder, monkeypatch
    ):
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")

        document_store = setup_document_store()
        text_embedder = mock_sentence_transformers_text_embedder
        invalid_embedding_retriever = InvalidOutputEmbeddingRetriever()
        bm25_retriever = InMemoryBM25Retriever(document_store)
        document_joiner = DocumentJoiner(join_mode="concatenate")

        pipeline = AsyncPipeline()
        pipeline.add_component("text_embedder", text_embedder)
        pipeline.add_component("embedding_retriever", invalid_embedding_retriever)
        pipeline.add_component("bm25_retriever", bm25_retriever)
        pipeline.add_component("document_joiner", document_joiner)
        pipeline.add_component(
            "prompt_builder", ChatPromptBuilder(template=template, required_variables=["question", "documents"])
        )
        pipeline.add_component("llm", OpenAIChatGenerator(api_key=Secret.from_env_var("OPENAI_API_KEY")))
        pipeline.add_component("answer_builder", AnswerBuilder())

        pipeline.connect("text_embedder", "embedding_retriever")
        pipeline.connect("bm25_retriever", "document_joiner")
        pipeline.connect("embedding_retriever", "document_joiner")
        pipeline.connect("document_joiner.documents", "prompt_builder.documents")
        pipeline.connect("prompt_builder", "llm")
        pipeline.connect("llm.replies", "answer_builder.replies")

        question = "Where does Mark live?"
        test_data = {
            "text_embedder": {"text": question},
            "bm25_retriever": {"query": question},
            "prompt_builder": {"question": question},
            "answer_builder": {"query": question},
        }

        with pytest.raises(PipelineRuntimeError) as exc_info:
            await pipeline.run_async(
                data=test_data,
                include_outputs_from={
                    "text_embedder",
                    "embedding_retriever",
                    "bm25_retriever",
                    "document_joiner",
                    "prompt_builder",
                    "llm",
                    "answer_builder",
                },
            )

        pipeline_outputs = exc_info.value.pipeline_outputs
        assert pipeline_outputs is not None, "Pipeline outputs should be captured in the exception"

        # verify that bm25_retriever and text_embedder ran successfully before the crash
        assert "bm25_retriever" in pipeline_outputs, "BM25 retriever output not captured"
        assert "documents" in pipeline_outputs["bm25_retriever"], "BM25 retriever should have produced documents"
        assert "text_embedder" in pipeline_outputs, "Text embedder output not captured"
        assert "embedding" in pipeline_outputs["text_embedder"], "Text embedder should have produced embeddings"

        # components after the crash point are not in the outputs
        assert "document_joiner" not in pipeline_outputs, "Document joiner should not have run due to crash"
        assert "prompt_builder" not in pipeline_outputs, "Prompt builder should not have run due to crash"
        assert "llm" not in pipeline_outputs, "LLM should not have run due to crash"
        assert "answer_builder" not in pipeline_outputs, "Answer builder should not have run due to crash"
