Skip to content

Commit fb0fa28

Browse files
committed
Run Task failure callbacks on DAG Processor when task is externally killed
Until #44354 is implemented, users have no way of knowing when a task is killed externally or when the supervisor process dies unexpectedly. This has been a blocker for Airflow 3.0 adoption for some users: - #44354 - https://apache-airflow.slack.com/archives/C07813CNKA8/p1751057525231389 Because #44354 is more involved and we might not get to it for Airflow 3.1, this is a good interim fix until then, similar to how we run the Dag Run callback.
1 parent 39624dc commit fb0fa28

File tree

8 files changed

+464
-60
lines changed

8 files changed

+464
-60
lines changed

airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ class TIRunContext(BaseModel):
327327
xcom_keys_to_clear: Annotated[list[str], Field(default_factory=list)]
328328
"""List of Xcom keys that need to be cleared and purged on by the worker."""
329329

330-
should_retry: bool
330+
should_retry: bool = False
331331
"""If the ti encounters an error, whether it should enter retry or failed state."""
332332

333333

airflow-core/src/airflow/callbacks/callback_requests.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ class TaskCallbackRequest(BaseCallbackRequest):
6161
"""Simplified Task Instance representation"""
6262
task_callback_type: TaskInstanceState | None = None
6363
"""Whether on success, on failure, on retry"""
64+
context_from_server: ti_datamodel.TIRunContext | None = None
65+
"""Task execution context from the Server"""
6466
type: Literal["TaskCallbackRequest"] = "TaskCallbackRequest"
6567

6668
@property

airflow-core/src/airflow/dag_processing/processor.py

Lines changed: 65 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19+
import contextlib
1920
import importlib
2021
import os
2122
import sys
2223
import traceback
23-
from collections.abc import Callable
24+
from collections.abc import Callable, Sequence
2425
from pathlib import Path
2526
from typing import TYPE_CHECKING, Annotated, BinaryIO, ClassVar, Literal
2627

@@ -45,9 +46,11 @@
4546
VariableResult,
4647
)
4748
from airflow.sdk.execution_time.supervisor import WatchedSubprocess
49+
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
4850
from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
4951
from airflow.stats import Stats
5052
from airflow.utils.file import iter_airflow_imports
53+
from airflow.utils.state import TaskInstanceState
5154

5255
if TYPE_CHECKING:
5356
from structlog.typing import FilteringBoundLogger
@@ -201,10 +204,7 @@ def _execute_callbacks(
201204
for request in callback_requests:
202205
log.debug("Processing Callback Request", request=request.to_json())
203206
if isinstance(request, TaskCallbackRequest):
204-
raise NotImplementedError(
205-
"Haven't coded Task callback yet - https://github.com/apache/airflow/issues/44354!"
206-
)
207-
# _execute_task_callbacks(dagbag, request)
207+
_execute_task_callbacks(dagbag, request, log)
208208
if isinstance(request, DagCallbackRequest):
209209
_execute_dag_callbacks(dagbag, request, log)
210210

@@ -238,6 +238,66 @@ def _execute_dag_callbacks(dagbag: DagBag, request: DagCallbackRequest, log: Fil
238238
Stats.incr("dag.callback_exceptions", tags={"dag_id": request.dag_id})
239239

240240

241+
def _execute_task_callbacks(dagbag: DagBag, request: TaskCallbackRequest, log: FilteringBoundLogger) -> None:
242+
if not request.is_failure_callback:
243+
log.warning(
244+
"Task callback requested but is not a failure callback",
245+
dag_id=request.ti.dag_id,
246+
task_id=request.ti.task_id,
247+
run_id=request.ti.run_id,
248+
)
249+
return
250+
251+
dag = dagbag.dags[request.ti.dag_id]
252+
task = dag.get_task(request.ti.task_id)
253+
254+
if request.task_callback_type is TaskInstanceState.UP_FOR_RETRY:
255+
callbacks = task.on_retry_callback
256+
else:
257+
callbacks = task.on_failure_callback
258+
259+
if not callbacks:
260+
log.warning(
261+
"Callback requested but no callback found",
262+
dag_id=request.ti.dag_id,
263+
task_id=request.ti.task_id,
264+
run_id=request.ti.run_id,
265+
)
266+
return
267+
268+
callbacks = callbacks if isinstance(callbacks, Sequence) else [callbacks]
269+
ctx_from_server = request.context_from_server
270+
271+
if ctx_from_server is not None:
272+
rti = RuntimeTaskInstance.model_construct(
273+
**request.ti.model_dump(exclude_unset=True),
274+
task=task,
275+
_ti_context_from_server=ctx_from_server,
276+
max_tries=ctx_from_server.max_tries,
277+
)
278+
else:
279+
rti = RuntimeTaskInstance.model_construct(
280+
**request.ti.model_dump(exclude_unset=True),
281+
task=task,
282+
)
283+
context = rti.get_template_context()
284+
285+
def get_callback_representation(callback):
286+
with contextlib.suppress(AttributeError):
287+
return callback.__name__
288+
with contextlib.suppress(AttributeError):
289+
return callback.__class__.__name__
290+
return callback
291+
292+
for idx, callback in enumerate(callbacks):
293+
callback_repr = get_callback_representation(callback)
294+
log.info("Executing Task callback at index %d: %s", idx, callback_repr)
295+
try:
296+
callback(context)
297+
except Exception:
298+
log.exception("Error in callback at index %d: %s", idx, callback_repr)
299+
300+
241301
def in_process_api_server() -> InProcessExecutionAPI:
242302
from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI
243303

airflow-core/src/airflow/jobs/scheduler_job_runner.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from sqlalchemy.sql import expression
3939

4040
from airflow import settings
41+
from airflow.api_fastapi.execution_api.datamodels.taskinstance import TIRunContext
4142
from airflow.callbacks.callback_requests import DagCallbackRequest, TaskCallbackRequest
4243
from airflow.configuration import conf
4344
from airflow.dag_processing.bundles.base import BundleUsageTrackingManager
@@ -927,10 +928,13 @@ def process_executor_events(
927928
bundle_version=ti.dag_version.bundle_version,
928929
ti=ti,
929930
msg=msg,
931+
context_from_server=TIRunContext(
932+
dag_run=ti.dag_run,
933+
max_tries=ti.max_tries,
934+
),
930935
)
931936
executor.send_callback(request)
932-
else:
933-
ti.handle_failure(error=msg, session=session)
937+
ti.handle_failure(error=msg, session=session)
934938

935939
return len(event_buffer)
936940

@@ -2283,6 +2287,10 @@ def _purge_task_instances_without_heartbeats(
22832287
bundle_version=ti.dag_run.bundle_version,
22842288
ti=ti,
22852289
msg=str(task_instance_heartbeat_timeout_message_details),
2290+
context_from_server=TIRunContext(
2291+
dag_run=ti.dag_run,
2292+
max_tries=ti.max_tries,
2293+
),
22862294
)
22872295
session.add(
22882296
Log(

airflow-core/tests/unit/callbacks/test_callback_requests.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from airflow.models.taskinstance import TaskInstance
3030
from airflow.providers.standard.operators.bash import BashOperator
3131
from airflow.utils import timezone
32-
from airflow.utils.state import State
32+
from airflow.utils.state import State, TaskInstanceState
3333

3434
pytestmark = pytest.mark.db_test
3535

@@ -87,3 +87,30 @@ def test_taskcallback_to_json_with_start_date_and_end_date(self, session, create
8787
json_str = input.to_json()
8888
result = TaskCallbackRequest.from_json(json_str)
8989
assert input == result
90+
91+
@pytest.mark.parametrize(
92+
"task_callback_type,expected_is_failure",
93+
[
94+
(None, True),
95+
(TaskInstanceState.FAILED, True),
96+
(TaskInstanceState.UP_FOR_RETRY, True),
97+
(TaskInstanceState.UPSTREAM_FAILED, True),
98+
(TaskInstanceState.SUCCESS, False),
99+
(TaskInstanceState.RUNNING, False),
100+
],
101+
)
102+
def test_is_failure_callback_property(
103+
self, task_callback_type, expected_is_failure, create_task_instance
104+
):
105+
"""Test is_failure_callback property with different task callback types"""
106+
ti = create_task_instance()
107+
108+
request = TaskCallbackRequest(
109+
filepath="filepath",
110+
ti=ti,
111+
bundle_name="testing",
112+
bundle_version=None,
113+
task_callback_type=task_callback_type,
114+
)
115+
116+
assert request.is_failure_callback == expected_is_failure

0 commit comments

Comments
 (0)