Skip to content

Commit 6c7ae8f

Browse files
davidsbatistasjrl
andauthored
feat: save pipeline checkpoints on crash to allow resuming from a snapshot (#9743)
* initial PoC idea running * removing test code * cleaning up * wip * cleaning up demos * adding more pipelines to test persistence saving * wip * wip * working example for logging components inputs in run time * reverting to a simpler solution for intermediate results * cleaning up * testing that in a crash components outputs/inputs up to the crash point are returned * adding tests for state persistence in a RAG pipeline * updating tests for state persistence in a RAG pipeline * removing use cases of agent tests * adding LICENSE header * adding LICENSE header * adding release notes * updating tests for mocked components only * updating release notes * adapting PipelineRuntimeError * cleaning up tests * fixing test pipeline crash components inputs/outputs are saved * fixing tests for state persistence * removing code * removing code * removing code * updating release notes * validating parameters * cleaning * wip: debugging * removing persistence tests * formatting * formatting * cleaning up code * updating release notes * adding missing docstrings * typo in release notes * Update haystack/core/pipeline/pipeline.py Co-authored-by: Sebastian Husch Lee <[email protected]> * PR comments * handling potential issues with saving the snapshot file * updating tests * Update haystack/core/pipeline/pipeline.py Co-authored-by: Sebastian Husch Lee <[email protected]> * Update haystack/core/pipeline/pipeline.py Co-authored-by: Sebastian Husch Lee <[email protected]> * some more improvements * fixing exception * fixing exception error name conflict --------- Co-authored-by: Sebastian Husch Lee <[email protected]>
1 parent 477188d commit 6c7ae8f

File tree

4 files changed

+79
-0
lines changed

4 files changed

+79
-0
lines changed

haystack/core/pipeline/pipeline.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
)
1818
from haystack.core.pipeline.breakpoint import (
1919
_create_pipeline_snapshot,
20+
_save_pipeline_snapshot,
2021
_trigger_break_point,
2122
_validate_break_point_against_pipeline,
2223
_validate_pipeline_snapshot_against_pipeline,
@@ -25,6 +26,7 @@
2526
from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint, PipelineSnapshot
2627
from haystack.telemetry import pipeline_running
2728
from haystack.utils import _deserialize_value_with_schema
29+
from haystack.utils.misc import _get_output_dir
2830

2931
logger = logging.getLogger(__name__)
3032

@@ -390,6 +392,36 @@ def run( # noqa: PLR0915, PLR0912, C901, pylint: disable=too-many-branches
390392
except PipelineRuntimeError as error:
391393
# Attach partial pipeline outputs to the error before re-raising
392394
error.pipeline_outputs = pipeline_outputs
395+
396+
# Create a snapshot of the last good state of the pipeline before the error occurred.
397+
pipeline_snapshot_inputs_serialised = deepcopy(inputs)
398+
pipeline_snapshot_inputs_serialised[component_name] = deepcopy(component_inputs)
399+
out_dir = _get_output_dir("pipeline_snapshot")
400+
break_point = Breakpoint(
401+
component_name=component_name,
402+
visit_count=component_visits[component_name],
403+
snapshot_file_path=out_dir,
404+
)
405+
last_good_state_snapshot = _create_pipeline_snapshot(
406+
inputs=pipeline_snapshot_inputs_serialised,
407+
break_point=break_point,
408+
component_visits=component_visits,
409+
original_input_data=data,
410+
ordered_component_names=ordered_component_names,
411+
include_outputs_from=include_outputs_from,
412+
pipeline_outputs=pipeline_outputs,
413+
)
414+
try:
415+
_save_pipeline_snapshot(pipeline_snapshot=last_good_state_snapshot)
416+
logger.info(
417+
"Saved a snapshot of the pipeline's last valid state to '{out_path}'. "
418+
"Review this snapshot to debug the error and resume the pipeline from here.",
419+
out_path=out_dir,
420+
)
421+
except Exception as save_error:
422+
logger.error(
423+
"Failed to save a snapshot of the pipeline's last valid state with error: {e}", e=save_error
424+
)
393425
raise error
394426

395427
# Updates global input state with component outputs and returns outputs that should go to

haystack/utils/misc.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import mimetypes
6+
import tempfile
67
from pathlib import Path
78
from typing import Any, Optional, Union, overload
89

@@ -81,3 +82,36 @@ def _guess_mime_type(path: Path) -> Optional[str]:
8182
mime_type = mimetypes.guess_type(path.as_posix())[0]
8283
# lookup custom mappings if the mime type is not found
8384
return CUSTOM_MIMETYPES.get(extension, mime_type)
85+
86+
87+
def _get_output_dir(out_dir: str) -> str:
    """
    Find or create a writable directory for saving status files.

    Tries in the following order:

    1. ~/.haystack/{out_dir}
    2. {tempdir}/haystack/{out_dir}
    3. ./.haystack/{out_dir}

    Each location is resolved lazily, so a location that cannot even be determined (for example, no home
    directory, or a current working directory that has been deleted) is skipped instead of aborting the search.
    A location is only returned if it exists after creation and the current process is allowed to write to it.

    :param out_dir: Name of the sub-directory to find or create inside the base location.
    :raises RuntimeError: If no directory could be created.
    :returns:
        The path to the created directory.
    """
    import os  # local import: only needed for the writability check below

    # Factories instead of ready-made paths: `Path.home()` may raise RuntimeError and `Path.cwd()` may raise
    # OSError, and such a failure must only disqualify that one candidate.
    candidate_factories = (
        lambda: Path.home() / ".haystack" / out_dir,
        lambda: Path(tempfile.gettempdir()) / "haystack" / out_dir,
        lambda: Path.cwd() / ".haystack" / out_dir,
    )

    attempted = []  # resolved locations that were tried, reported in the error message
    for make_candidate in candidate_factories:
        try:
            candidate = make_candidate()
            attempted.append(candidate)
            candidate.mkdir(parents=True, exist_ok=True)
        except Exception:  # best-effort: any failure simply moves on to the next location
            continue

        # `mkdir(exist_ok=True)` also succeeds for a pre-existing read-only directory, so check explicitly
        # that files can actually be created inside it.
        if os.access(candidate, os.W_OK | os.X_OK):
            return str(candidate)

    raise RuntimeError(
        f"Could not create a writable directory for output files in any of the following locations: {attempted}"
    )
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
features:
3+
- |
4+
A snapshot of the last successful step is now saved when an error occurs during a `Pipeline` run. This allows you
5+
to inspect the snapshot, potentially identify and fix the error, and later resume the pipeline from that point
6+
onwards, avoiding the need to re-run the entire pipeline from the start. Currently, this is only available in `Pipeline` and not yet
7+
in `AsyncPipeline`.

test/core/pipeline/test_pipeline_crash_regular_pipeline_outputs_raised.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
import os
56
from unittest.mock import MagicMock, patch
67

78
import numpy as np
@@ -21,6 +22,7 @@
2122
from haystack.document_stores.in_memory import InMemoryDocumentStore
2223
from haystack.document_stores.types import DuplicatePolicy
2324
from haystack.utils.auth import Secret
25+
from haystack.utils.misc import _get_output_dir
2426

2527

2628
def setup_document_store():
@@ -232,3 +234,7 @@ async def test_async_hybrid_rag_pipeline_crash_on_embedding_retriever(
232234
assert "prompt_builder" not in pipeline_outputs, "Prompt builder should not have run due to crash"
233235
assert "llm" not in pipeline_outputs, "LLM should not have run due to crash"
234236
assert "answer_builder" not in pipeline_outputs, "Answer builder should not have run due to crash"
237+
238+
# check that a pipeline snapshot file was created in the "pipeline_snapshot" directory
239+
snapshot_files = os.listdir(_get_output_dir("pipeline_snapshot"))
240+
assert any(f.endswith(".json") for f in snapshot_files), "No pipeline snapshot file found in debug directory"

0 commit comments

Comments
 (0)