Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ class TIRunContext(BaseModel):
dag_run: DagRun
"""DAG run information for the task instance."""

task_reschedule_count: Annotated[int, Field(default=0)]
task_reschedule_count: int = 0
"""How many times the task has been rescheduled."""

max_tries: int
Expand All @@ -327,7 +327,7 @@ class TIRunContext(BaseModel):
xcom_keys_to_clear: Annotated[list[str], Field(default_factory=list)]
"""List of Xcom keys that need to be cleared and purged on by the worker."""

should_retry: bool
should_retry: bool = False
"""If the ti encounters an error, whether it should enter retry or failed state."""


Expand Down
2 changes: 2 additions & 0 deletions airflow-core/src/airflow/callbacks/callback_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ class TaskCallbackRequest(BaseCallbackRequest):
"""Simplified Task Instance representation"""
task_callback_type: TaskInstanceState | None = None
"""Whether on success, on failure, on retry"""
context_from_server: ti_datamodel.TIRunContext | None = None
"""Task execution context from the Server"""
type: Literal["TaskCallbackRequest"] = "TaskCallbackRequest"

@property
Expand Down
71 changes: 66 additions & 5 deletions airflow-core/src/airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
# under the License.
from __future__ import annotations

import contextlib
import importlib
import os
import sys
import traceback
from collections.abc import Callable
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, BinaryIO, ClassVar, Literal

Expand All @@ -45,9 +46,11 @@
VariableResult,
)
from airflow.sdk.execution_time.supervisor import WatchedSubprocess
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
from airflow.stats import Stats
from airflow.utils.file import iter_airflow_imports
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger
Expand Down Expand Up @@ -201,10 +204,7 @@ def _execute_callbacks(
for request in callback_requests:
log.debug("Processing Callback Request", request=request.to_json())
if isinstance(request, TaskCallbackRequest):
raise NotImplementedError(
"Haven't coded Task callback yet - https://github.com/apache/airflow/issues/44354!"
)
# _execute_task_callbacks(dagbag, request)
_execute_task_callbacks(dagbag, request, log)
if isinstance(request, DagCallbackRequest):
_execute_dag_callbacks(dagbag, request, log)

Expand Down Expand Up @@ -238,6 +238,67 @@ def _execute_dag_callbacks(dagbag: DagBag, request: DagCallbackRequest, log: Fil
Stats.incr("dag.callback_exceptions", tags={"dag_id": request.dag_id})


def _execute_task_callbacks(dagbag: DagBag, request: TaskCallbackRequest, log: FilteringBoundLogger) -> None:
    """Run a task's failure/retry callbacks for a failure callback request.

    Non-failure requests are ignored with a warning. An exception raised by one
    callback is logged and does not prevent the remaining callbacks from running.
    """
    ti = request.ti
    if not request.is_failure_callback:
        log.warning(
            "Task callback requested but is not a failure callback",
            dag_id=ti.dag_id,
            task_id=ti.task_id,
            run_id=ti.run_id,
        )
        return

    task = dagbag.dags[ti.dag_id].get_task(ti.task_id)

    # Retry requests run on_retry_callback; every other failure type runs on_failure_callback.
    if request.task_callback_type is TaskInstanceState.UP_FOR_RETRY:
        callbacks = task.on_retry_callback
    else:
        callbacks = task.on_failure_callback

    if not callbacks:
        log.warning(
            "Callback requested but no callback found",
            dag_id=ti.dag_id,
            task_id=ti.task_id,
            run_id=ti.run_id,
            ti_id=ti.id,
        )
        return

    if not isinstance(callbacks, Sequence):
        callbacks = [callbacks]

    # Build a runtime task instance so callbacks receive the usual template context.
    extra_fields: dict = {"task": task}
    server_ctx = request.context_from_server
    if server_ctx is not None:
        extra_fields["_ti_context_from_server"] = server_ctx
        extra_fields["max_tries"] = server_ctx.max_tries
    runtime_ti = RuntimeTaskInstance.model_construct(
        **ti.model_dump(exclude_unset=True),
        **extra_fields,
    )
    context = runtime_ti.get_template_context()

    def _describe(cb):
        # Prefer a function's name, then the class name, else fall back to the object itself.
        try:
            return cb.__name__
        except AttributeError:
            pass
        try:
            return cb.__class__.__name__
        except AttributeError:
            return cb

    for position, cb in enumerate(callbacks):
        cb_repr = _describe(cb)
        log.info("Executing Task callback at index %d: %s", position, cb_repr)
        try:
            cb(context)
        except Exception:
            log.exception("Error in callback at index %d: %s", position, cb_repr)


def in_process_api_server() -> InProcessExecutionAPI:
from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI

Expand Down
18 changes: 16 additions & 2 deletions airflow-core/src/airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from sqlalchemy.sql import expression

from airflow import settings
from airflow.api_fastapi.execution_api.datamodels.taskinstance import TIRunContext
from airflow.callbacks.callback_requests import DagCallbackRequest, TaskCallbackRequest
from airflow.configuration import conf
from airflow.dag_processing.bundles.base import BundleUsageTrackingManager
Expand Down Expand Up @@ -927,10 +928,16 @@ def process_executor_events(
bundle_version=ti.dag_version.bundle_version,
ti=ti,
msg=msg,
context_from_server=TIRunContext(
dag_run=ti.dag_run,
max_tries=ti.max_tries,
variables=[],
connections=[],
xcom_keys_to_clear=[],
),
)
executor.send_callback(request)
else:
ti.handle_failure(error=msg, session=session)
ti.handle_failure(error=msg, session=session)

return len(event_buffer)

Expand Down Expand Up @@ -2283,6 +2290,13 @@ def _purge_task_instances_without_heartbeats(
bundle_version=ti.dag_run.bundle_version,
ti=ti,
msg=str(task_instance_heartbeat_timeout_message_details),
context_from_server=TIRunContext(
dag_run=ti.dag_run,
max_tries=ti.max_tries,
variables=[],
connections=[],
xcom_keys_to_clear=[],
),
)
session.add(
Log(
Expand Down
29 changes: 28 additions & 1 deletion airflow-core/tests/unit/callbacks/test_callback_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from airflow.models.taskinstance import TaskInstance
from airflow.providers.standard.operators.bash import BashOperator
from airflow.utils import timezone
from airflow.utils.state import State
from airflow.utils.state import State, TaskInstanceState

pytestmark = pytest.mark.db_test

Expand Down Expand Up @@ -87,3 +87,30 @@ def test_taskcallback_to_json_with_start_date_and_end_date(self, session, create
json_str = input.to_json()
result = TaskCallbackRequest.from_json(json_str)
assert input == result

@pytest.mark.parametrize(
"task_callback_type,expected_is_failure",
[
(None, True),
(TaskInstanceState.FAILED, True),
(TaskInstanceState.UP_FOR_RETRY, True),
(TaskInstanceState.UPSTREAM_FAILED, True),
(TaskInstanceState.SUCCESS, False),
(TaskInstanceState.RUNNING, False),
],
)
def test_is_failure_callback_property(
self, task_callback_type, expected_is_failure, create_task_instance
):
"""Test is_failure_callback property with different task callback types"""
ti = create_task_instance()

request = TaskCallbackRequest(
filepath="filepath",
ti=ti,
bundle_name="testing",
bundle_version=None,
task_callback_type=task_callback_type,
)

assert request.is_failure_callback == expected_is_failure
Loading