Skip to content

Commit 9cb63eb

Browse files
committed
Restore proper DAG callback execution context
This ensures DAG callbacks receive the same rich context as task callbacks, improving consistency and providing access to template variables and macros similar to Airflow 2. This has been a blocker for few users similar to apache#53058 Fixes apache#52824 Fixes apache#51402 Closes apache#51949 Related to apache#53654 Related to apache#53618
1 parent 2cb6079 commit 9cb63eb

File tree

9 files changed

+857
-22
lines changed

9 files changed

+857
-22
lines changed

airflow-core/src/airflow/callbacks/callback_requests.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,19 @@ def is_failure_callback(self) -> bool:
7777
}
7878

7979

80+
class DagRunContext(BaseModel):
81+
"""Class to pass context info from the server to build a Execution context object."""
82+
83+
dag_run: ti_datamodel.DagRun | None = None
84+
last_ti: ti_datamodel.TaskInstance | None = None
85+
86+
8087
class DagCallbackRequest(BaseCallbackRequest):
8188
"""A Class with information about the success/failure DAG callback to be executed."""
8289

8390
dag_id: str
8491
run_id: str
92+
context_from_server: DagRunContext | None = None
8593
is_failure_callback: bool | None = True
8694
"""Flag to determine whether it is a Failure Callback or Success Callback"""
8795
type: Literal["DagCallbackRequest"] = "DagCallbackRequest"

airflow-core/src/airflow/dag_processing/processor.py

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,12 @@
4040
DeleteVariable,
4141
ErrorResponse,
4242
GetConnection,
43+
GetPreviousDagRun,
44+
GetPrevSuccessfulDagRun,
4345
GetVariable,
4446
OKResponse,
47+
PreviousDagRunResult,
48+
PrevSuccessfulDagRunResult,
4549
PutVariable,
4650
VariableResult,
4751
)
@@ -94,12 +98,24 @@ class DagFileParsingResult(BaseModel):
9498

9599

96100
ToManager = Annotated[
97-
DagFileParsingResult | GetConnection | GetVariable | PutVariable | DeleteVariable,
101+
DagFileParsingResult
102+
| GetConnection
103+
| GetVariable
104+
| PutVariable
105+
| DeleteVariable
106+
| GetPrevSuccessfulDagRun
107+
| GetPreviousDagRun,
98108
Field(discriminator="type"),
99109
]
100110

101111
ToDagProcessor = Annotated[
102-
DagFileParseRequest | ConnectionResult | VariableResult | ErrorResponse | OKResponse,
112+
DagFileParseRequest
113+
| ConnectionResult
114+
| VariableResult
115+
| PreviousDagRunResult
116+
| PrevSuccessfulDagRunResult
117+
| ErrorResponse
118+
| OKResponse,
103119
Field(discriminator="type"),
104120
]
105121

@@ -209,6 +225,8 @@ def _execute_callbacks(
209225

210226

211227
def _execute_dag_callbacks(dagbag: DagBag, request: DagCallbackRequest, log: FilteringBoundLogger) -> None:
228+
from airflow.sdk.api.datamodels._generated import TIRunContext
229+
212230
dag = dagbag.dags[request.dag_id]
213231

214232
callbacks = dag.on_failure_callback if request.is_failure_callback else dag.on_success_callback
@@ -217,12 +235,27 @@ def _execute_dag_callbacks(dagbag: DagBag, request: DagCallbackRequest, log: Fil
217235
return
218236

219237
callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
220-
# TODO:We need a proper context object!
221-
context: Context = {
222-
"dag": dag,
223-
"run_id": request.run_id,
224-
"reason": request.msg,
225-
}
238+
ctx_from_server = request.context_from_server
239+
240+
if ctx_from_server is not None and ctx_from_server.last_ti is not None:
241+
task = dag.get_task(ctx_from_server.last_ti.task_id)
242+
243+
runtime_ti = RuntimeTaskInstance.model_construct(
244+
**ctx_from_server.last_ti.model_dump(exclude_unset=True),
245+
task=task,
246+
_ti_context_from_server=TIRunContext.model_construct(
247+
dag_run=ctx_from_server.dag_run,
248+
max_tries=task.retries,
249+
),
250+
)
251+
context = runtime_ti.get_template_context()
252+
context["reason"] = request.msg
253+
else:
254+
context: Context = { # type: ignore[no-redef]
255+
"dag": dag,
256+
"run_id": request.run_id,
257+
"reason": request.msg,
258+
}
226259

227260
for callback in callbacks:
228261
log.info(
@@ -383,6 +416,17 @@ def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int
383416
self.client.variables.set(msg.key, msg.value, msg.description)
384417
elif isinstance(msg, DeleteVariable):
385418
resp = self.client.variables.delete(msg.key)
419+
elif isinstance(msg, GetPreviousDagRun):
420+
resp = self.client.dag_runs.get_previous(
421+
dag_id=msg.dag_id,
422+
logical_date=msg.logical_date,
423+
state=msg.state,
424+
)
425+
elif isinstance(msg, GetPrevSuccessfulDagRun):
426+
dagrun_resp = self.client.task_instances.get_previous_successful_dagrun(self.id)
427+
dagrun_result = PrevSuccessfulDagRunResult.from_dagrun_response(dagrun_resp)
428+
resp = dagrun_result
429+
dump_opts = {"exclude_unset": True}
386430
else:
387431
log.error("Unhandled request", msg=msg)
388432
self.send_msg(

airflow-core/src/airflow/jobs/scheduler_job_runner.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from airflow import settings
4141
from airflow._shared.timezones import timezone
4242
from airflow.api_fastapi.execution_api.datamodels.taskinstance import TIRunContext
43-
from airflow.callbacks.callback_requests import DagCallbackRequest, TaskCallbackRequest
43+
from airflow.callbacks.callback_requests import DagCallbackRequest, DagRunContext, TaskCallbackRequest
4444
from airflow.configuration import conf
4545
from airflow.dag_processing.bundles.base import BundleUsageTrackingManager
4646
from airflow.executors import workloads
@@ -1854,6 +1854,10 @@ def _schedule_dag_run(
18541854
run_id=dag_run.run_id,
18551855
bundle_name=dag_model.bundle_name,
18561856
bundle_version=dag_run.bundle_version,
1857+
context_from_server=DagRunContext(
1858+
dag_run=dag_run,
1859+
last_ti=dag_run.get_last_ti(dag=dag, session=session),
1860+
),
18571861
is_failure_callback=True,
18581862
msg="timed_out",
18591863
)

airflow-core/src/airflow/models/dagrun.py

Lines changed: 68 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
from sqlalchemy_utils import UUIDType
6262

6363
from airflow._shared.timezones import timezone
64-
from airflow.callbacks.callback_requests import DagCallbackRequest
64+
from airflow.callbacks.callback_requests import DagCallbackRequest, DagRunContext
6565
from airflow.configuration import conf as airflow_conf
6666
from airflow.exceptions import AirflowException, TaskNotFound
6767
from airflow.listeners.listener import get_listener_manager
@@ -102,7 +102,7 @@
102102
from airflow.models.dag import DAG
103103
from airflow.models.dag_version import DagVersion
104104
from airflow.models.taskinstancekey import TaskInstanceKey
105-
from airflow.sdk import DAG as SDKDAG, Context
105+
from airflow.sdk import DAG as SDKDAG
106106
from airflow.sdk.types import Operator
107107
from airflow.serialization.serialized_objects import SerializedBaseOperator as BaseOperator
108108
from airflow.utils.types import ArgNotSet
@@ -1186,6 +1186,10 @@ def recalculate(self) -> _UnfinishedStates:
11861186
run_id=self.run_id,
11871187
bundle_name=self.dag_model.bundle_name,
11881188
bundle_version=self.bundle_version,
1189+
context_from_server=DagRunContext(
1190+
dag_run=self,
1191+
last_ti=self.get_last_ti(dag=dag, session=session),
1192+
),
11891193
is_failure_callback=True,
11901194
msg="task_failure",
11911195
)
@@ -1215,6 +1219,10 @@ def recalculate(self) -> _UnfinishedStates:
12151219
run_id=self.run_id,
12161220
bundle_name=self.dag_model.bundle_name,
12171221
bundle_version=self.bundle_version,
1222+
context_from_server=DagRunContext(
1223+
dag_run=self,
1224+
last_ti=self.get_last_ti(dag=dag, session=session),
1225+
),
12181226
is_failure_callback=False,
12191227
msg="success",
12201228
)
@@ -1238,6 +1246,10 @@ def recalculate(self) -> _UnfinishedStates:
12381246
run_id=self.run_id,
12391247
bundle_name=self.dag_model.bundle_name,
12401248
bundle_version=self.bundle_version,
1249+
context_from_server=DagRunContext(
1250+
dag_run=self,
1251+
last_ti=self.get_last_ti(dag=dag, session=session),
1252+
),
12411253
is_failure_callback=True,
12421254
msg="all_tasks_deadlocked",
12431255
)
@@ -1350,13 +1362,62 @@ def notify_dagrun_state_changed(self, msg: str = ""):
13501362
# we can't get all the state changes on SchedulerJob,
13511363
# or LocalTaskJob, so we don't want to "falsely advertise" we notify about that
13521364

1365+
@provide_session
1366+
def get_last_ti(self, dag: DAG, session: Session = NEW_SESSION) -> TI:
1367+
"""Get Last TI from the dagrun to build and pass Execution context object from server to then run callbacks."""
1368+
tis = self.get_task_instances(session=session)
1369+
# tis from a dagrun may not be a part of dag.partial_subset,
1370+
# since dag.partial_subset is a subset of the dag.
1371+
# This ensures that we will only use the accessible TI
1372+
# context for the callback.
1373+
if dag.partial:
1374+
tis = [ti for ti in tis if not ti.state == State.NONE]
1375+
# filter out removed tasks
1376+
tis = [ti for ti in tis if ti.state != TaskInstanceState.REMOVED]
1377+
ti = tis[-1] # get last TaskInstance of DagRun
1378+
return ti
1379+
13531380
def handle_dag_callback(self, dag: SDKDAG, success: bool = True, reason: str = "success"):
13541381
"""Only needed for `dag.test` where `execute_callbacks=True` is passed to `update_state`."""
1355-
context: Context = {
1356-
"dag": dag,
1357-
"run_id": str(self.run_id),
1358-
"reason": reason,
1359-
}
1382+
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
1383+
DagRun as DRDataModel,
1384+
TaskInstance as TIDataModel,
1385+
TIRunContext,
1386+
)
1387+
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
1388+
1389+
first_ti = TIDataModel.model_validate(self.get_last_ti(dag), from_attributes=True) # type: ignore[arg-type]
1390+
task = dag.get_task(first_ti.task_id)
1391+
1392+
dag_run_data = DRDataModel(
1393+
dag_id=self.dag_id,
1394+
run_id=self.run_id,
1395+
logical_date=self.logical_date,
1396+
data_interval_start=self.data_interval_start,
1397+
data_interval_end=self.data_interval_end,
1398+
run_after=self.run_after,
1399+
start_date=self.start_date,
1400+
end_date=self.end_date,
1401+
run_type=self.run_type,
1402+
state=self.state,
1403+
conf=self.conf,
1404+
consumed_asset_events=[],
1405+
)
1406+
1407+
runtime_ti = RuntimeTaskInstance.model_construct(
1408+
**first_ti.model_dump(exclude_unset=True),
1409+
task=task,
1410+
_ti_context_from_server=TIRunContext(
1411+
dag_run=dag_run_data,
1412+
max_tries=task.retries,
1413+
variables=[],
1414+
connections=[],
1415+
xcom_keys_to_clear=[],
1416+
),
1417+
max_tries=task.retries or 0,
1418+
)
1419+
context = runtime_ti.get_template_context()
1420+
context["reason"] = reason
13601421

13611422
callbacks = dag.on_success_callback if success else dag.on_failure_callback
13621423
if not callbacks:

0 commit comments

Comments
 (0)