Skip to content

Commit ef80507

Browse files
authored
Restore proper DAG callback execution context (#53684)
This ensures DAG callbacks receive the same rich context as task callbacks, improving consistency and providing access to template variables and macros, as in Airflow 2. This has been a blocker for a few users, similar to #53058. Fixes #52824. Fixes #51402. Closes #51949. Related to #53654. Related to #53618.
1 parent 194da6a commit ef80507

File tree

10 files changed

+868
-33
lines changed

10 files changed

+868
-33
lines changed

airflow-core/src/airflow/callbacks/callback_requests.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,19 @@ def is_failure_callback(self) -> bool:
7777
}
7878

7979

80+
class DagRunContext(BaseModel):
81+
"""Class to pass context info from the server to build a Execution context object."""
82+
83+
dag_run: ti_datamodel.DagRun | None = None
84+
last_ti: ti_datamodel.TaskInstance | None = None
85+
86+
8087
class DagCallbackRequest(BaseCallbackRequest):
8188
"""A Class with information about the success/failure DAG callback to be executed."""
8289

8390
dag_id: str
8491
run_id: str
92+
context_from_server: DagRunContext | None = None
8593
is_failure_callback: bool | None = True
8694
"""Flag to determine whether it is a Failure Callback or Success Callback"""
8795
type: Literal["DagCallbackRequest"] = "DagCallbackRequest"

airflow-core/src/airflow/dag_processing/processor.py

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,12 @@
4040
DeleteVariable,
4141
ErrorResponse,
4242
GetConnection,
43+
GetPreviousDagRun,
44+
GetPrevSuccessfulDagRun,
4345
GetVariable,
4446
OKResponse,
47+
PreviousDagRunResult,
48+
PrevSuccessfulDagRunResult,
4549
PutVariable,
4650
VariableResult,
4751
)
@@ -94,12 +98,24 @@ class DagFileParsingResult(BaseModel):
9498

9599

96100
ToManager = Annotated[
97-
DagFileParsingResult | GetConnection | GetVariable | PutVariable | DeleteVariable,
101+
DagFileParsingResult
102+
| GetConnection
103+
| GetVariable
104+
| PutVariable
105+
| DeleteVariable
106+
| GetPrevSuccessfulDagRun
107+
| GetPreviousDagRun,
98108
Field(discriminator="type"),
99109
]
100110

101111
ToDagProcessor = Annotated[
102-
DagFileParseRequest | ConnectionResult | VariableResult | ErrorResponse | OKResponse,
112+
DagFileParseRequest
113+
| ConnectionResult
114+
| VariableResult
115+
| PreviousDagRunResult
116+
| PrevSuccessfulDagRunResult
117+
| ErrorResponse
118+
| OKResponse,
103119
Field(discriminator="type"),
104120
]
105121

@@ -209,6 +225,8 @@ def _execute_callbacks(
209225

210226

211227
def _execute_dag_callbacks(dagbag: DagBag, request: DagCallbackRequest, log: FilteringBoundLogger) -> None:
228+
from airflow.sdk.api.datamodels._generated import TIRunContext
229+
212230
dag = dagbag.dags[request.dag_id]
213231

214232
callbacks = dag.on_failure_callback if request.is_failure_callback else dag.on_success_callback
@@ -217,12 +235,27 @@ def _execute_dag_callbacks(dagbag: DagBag, request: DagCallbackRequest, log: Fil
217235
return
218236

219237
callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
220-
# TODO:We need a proper context object!
221-
context: Context = {
222-
"dag": dag,
223-
"run_id": request.run_id,
224-
"reason": request.msg,
225-
}
238+
ctx_from_server = request.context_from_server
239+
240+
if ctx_from_server is not None and ctx_from_server.last_ti is not None:
241+
task = dag.get_task(ctx_from_server.last_ti.task_id)
242+
243+
runtime_ti = RuntimeTaskInstance.model_construct(
244+
**ctx_from_server.last_ti.model_dump(exclude_unset=True),
245+
task=task,
246+
_ti_context_from_server=TIRunContext.model_construct(
247+
dag_run=ctx_from_server.dag_run,
248+
max_tries=task.retries,
249+
),
250+
)
251+
context = runtime_ti.get_template_context()
252+
context["reason"] = request.msg
253+
else:
254+
context: Context = { # type: ignore[no-redef]
255+
"dag": dag,
256+
"run_id": request.run_id,
257+
"reason": request.msg,
258+
}
226259

227260
for callback in callbacks:
228261
log.info(
@@ -383,6 +416,17 @@ def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int
383416
self.client.variables.set(msg.key, msg.value, msg.description)
384417
elif isinstance(msg, DeleteVariable):
385418
resp = self.client.variables.delete(msg.key)
419+
elif isinstance(msg, GetPreviousDagRun):
420+
resp = self.client.dag_runs.get_previous(
421+
dag_id=msg.dag_id,
422+
logical_date=msg.logical_date,
423+
state=msg.state,
424+
)
425+
elif isinstance(msg, GetPrevSuccessfulDagRun):
426+
dagrun_resp = self.client.task_instances.get_previous_successful_dagrun(self.id)
427+
dagrun_result = PrevSuccessfulDagRunResult.from_dagrun_response(dagrun_resp)
428+
resp = dagrun_result
429+
dump_opts = {"exclude_unset": True}
386430
else:
387431
log.error("Unhandled request", msg=msg)
388432
self.send_msg(

airflow-core/src/airflow/jobs/scheduler_job_runner.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from airflow import settings
4141
from airflow._shared.timezones import timezone
4242
from airflow.api_fastapi.execution_api.datamodels.taskinstance import TIRunContext
43-
from airflow.callbacks.callback_requests import DagCallbackRequest, TaskCallbackRequest
43+
from airflow.callbacks.callback_requests import DagCallbackRequest, DagRunContext, TaskCallbackRequest
4444
from airflow.configuration import conf
4545
from airflow.dag_processing.bundles.base import BundleUsageTrackingManager
4646
from airflow.executors import workloads
@@ -1854,6 +1854,10 @@ def _schedule_dag_run(
18541854
run_id=dag_run.run_id,
18551855
bundle_name=dag_model.bundle_name,
18561856
bundle_version=dag_run.bundle_version,
1857+
context_from_server=DagRunContext(
1858+
dag_run=dag_run,
1859+
last_ti=dag_run.get_last_ti(dag=dag, session=session),
1860+
),
18571861
is_failure_callback=True,
18581862
msg="timed_out",
18591863
)

airflow-core/src/airflow/models/dagrun.py

Lines changed: 78 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
from sqlalchemy_utils import UUIDType
6262

6363
from airflow._shared.timezones import timezone
64-
from airflow.callbacks.callback_requests import DagCallbackRequest
64+
from airflow.callbacks.callback_requests import DagCallbackRequest, DagRunContext
6565
from airflow.configuration import conf as airflow_conf
6666
from airflow.exceptions import AirflowException, TaskNotFound
6767
from airflow.listeners.listener import get_listener_manager
@@ -102,7 +102,7 @@
102102
from airflow.models.dag import DAG
103103
from airflow.models.dag_version import DagVersion
104104
from airflow.models.taskinstancekey import TaskInstanceKey
105-
from airflow.sdk import DAG as SDKDAG, Context
105+
from airflow.sdk import DAG as SDKDAG
106106
from airflow.sdk.types import Operator
107107
from airflow.serialization.serialized_objects import SerializedBaseOperator as BaseOperator
108108
from airflow.utils.types import ArgNotSet
@@ -1186,6 +1186,10 @@ def recalculate(self) -> _UnfinishedStates:
11861186
run_id=self.run_id,
11871187
bundle_name=self.dag_model.bundle_name,
11881188
bundle_version=self.bundle_version,
1189+
context_from_server=DagRunContext(
1190+
dag_run=self,
1191+
last_ti=self.get_last_ti(dag=dag, session=session),
1192+
),
11891193
is_failure_callback=True,
11901194
msg="task_failure",
11911195
)
@@ -1215,6 +1219,10 @@ def recalculate(self) -> _UnfinishedStates:
12151219
run_id=self.run_id,
12161220
bundle_name=self.dag_model.bundle_name,
12171221
bundle_version=self.bundle_version,
1222+
context_from_server=DagRunContext(
1223+
dag_run=self,
1224+
last_ti=self.get_last_ti(dag=dag, session=session),
1225+
),
12181226
is_failure_callback=False,
12191227
msg="success",
12201228
)
@@ -1238,6 +1246,10 @@ def recalculate(self) -> _UnfinishedStates:
12381246
run_id=self.run_id,
12391247
bundle_name=self.dag_model.bundle_name,
12401248
bundle_version=self.bundle_version,
1249+
context_from_server=DagRunContext(
1250+
dag_run=self,
1251+
last_ti=self.get_last_ti(dag=dag, session=session),
1252+
),
12411253
is_failure_callback=True,
12421254
msg="all_tasks_deadlocked",
12431255
)
@@ -1350,13 +1362,72 @@ def notify_dagrun_state_changed(self, msg: str = ""):
13501362
# we can't get all the state changes on SchedulerJob,
13511363
# or LocalTaskJob, so we don't want to "falsely advertise" we notify about that
13521364

1365+
@provide_session
1366+
def get_last_ti(self, dag: DAG, session: Session = NEW_SESSION) -> TI | None:
1367+
"""Get Last TI from the dagrun to build and pass Execution context object from server to then run callbacks."""
1368+
tis = self.get_task_instances(session=session)
1369+
# tis from a dagrun may not be a part of dag.partial_subset,
1370+
# since dag.partial_subset is a subset of the dag.
1371+
# This ensures that we will only use the accessible TI
1372+
# context for the callback.
1373+
if dag.partial:
1374+
tis = [ti for ti in tis if not ti.state == State.NONE]
1375+
# filter out removed tasks
1376+
tis = [ti for ti in tis if ti.state != TaskInstanceState.REMOVED]
1377+
if not tis:
1378+
return None
1379+
ti = tis[-1] # get last TaskInstance of DagRun
1380+
return ti
1381+
13531382
def handle_dag_callback(self, dag: SDKDAG, success: bool = True, reason: str = "success"):
13541383
"""Only needed for `dag.test` where `execute_callbacks=True` is passed to `update_state`."""
1355-
context: Context = {
1356-
"dag": dag,
1357-
"run_id": str(self.run_id),
1358-
"reason": reason,
1359-
}
1384+
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
1385+
DagRun as DRDataModel,
1386+
TaskInstance as TIDataModel,
1387+
TIRunContext,
1388+
)
1389+
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
1390+
1391+
last_ti = self.get_last_ti(dag) # type: ignore[arg-type]
1392+
if last_ti:
1393+
last_ti_model = TIDataModel.model_validate(last_ti, from_attributes=True)
1394+
task = dag.get_task(last_ti.task_id)
1395+
1396+
dag_run_data = DRDataModel(
1397+
dag_id=self.dag_id,
1398+
run_id=self.run_id,
1399+
logical_date=self.logical_date,
1400+
data_interval_start=self.data_interval_start,
1401+
data_interval_end=self.data_interval_end,
1402+
run_after=self.run_after,
1403+
start_date=self.start_date,
1404+
end_date=self.end_date,
1405+
run_type=self.run_type,
1406+
state=self.state,
1407+
conf=self.conf,
1408+
consumed_asset_events=[],
1409+
)
1410+
1411+
runtime_ti = RuntimeTaskInstance.model_construct(
1412+
**last_ti_model.model_dump(exclude_unset=True),
1413+
task=task,
1414+
_ti_context_from_server=TIRunContext(
1415+
dag_run=dag_run_data,
1416+
max_tries=last_ti.max_tries,
1417+
variables=[],
1418+
connections=[],
1419+
xcom_keys_to_clear=[],
1420+
),
1421+
max_tries=last_ti.max_tries,
1422+
)
1423+
context = runtime_ti.get_template_context()
1424+
else:
1425+
context = {
1426+
"dag": dag,
1427+
"run_id": self.run_id,
1428+
}
1429+
1430+
context["reason"] = reason
13601431

13611432
callbacks = dag.on_success_callback if success else dag.on_failure_callback
13621433
if not callbacks:

0 commit comments

Comments
 (0)