61
61
from sqlalchemy_utils import UUIDType
62
62
63
63
from airflow ._shared .timezones import timezone
64
- from airflow .callbacks .callback_requests import DagCallbackRequest
64
+ from airflow .callbacks .callback_requests import DagCallbackRequest , DagRunContext
65
65
from airflow .configuration import conf as airflow_conf
66
66
from airflow .exceptions import AirflowException , TaskNotFound
67
67
from airflow .listeners .listener import get_listener_manager
102
102
from airflow .models .dag import DAG
103
103
from airflow .models .dag_version import DagVersion
104
104
from airflow .models .taskinstancekey import TaskInstanceKey
105
- from airflow .sdk import DAG as SDKDAG , Context
105
+ from airflow .sdk import DAG as SDKDAG
106
106
from airflow .sdk .types import Operator
107
107
from airflow .serialization .serialized_objects import SerializedBaseOperator as BaseOperator
108
108
from airflow .utils .types import ArgNotSet
@@ -1186,6 +1186,10 @@ def recalculate(self) -> _UnfinishedStates:
1186
1186
run_id = self .run_id ,
1187
1187
bundle_name = self .dag_model .bundle_name ,
1188
1188
bundle_version = self .bundle_version ,
1189
+ context_from_server = DagRunContext (
1190
+ dag_run = self ,
1191
+ first_ti = self .get_first_ti (dag = dag , session = session ),
1192
+ ),
1189
1193
is_failure_callback = True ,
1190
1194
msg = "task_failure" ,
1191
1195
)
@@ -1215,6 +1219,10 @@ def recalculate(self) -> _UnfinishedStates:
1215
1219
run_id = self .run_id ,
1216
1220
bundle_name = self .dag_model .bundle_name ,
1217
1221
bundle_version = self .bundle_version ,
1222
+ context_from_server = DagRunContext (
1223
+ dag_run = self ,
1224
+ first_ti = self .get_first_ti (dag = dag , session = session ),
1225
+ ),
1218
1226
is_failure_callback = False ,
1219
1227
msg = "success" ,
1220
1228
)
@@ -1238,6 +1246,10 @@ def recalculate(self) -> _UnfinishedStates:
1238
1246
run_id = self .run_id ,
1239
1247
bundle_name = self .dag_model .bundle_name ,
1240
1248
bundle_version = self .bundle_version ,
1249
+ context_from_server = DagRunContext (
1250
+ dag_run = self ,
1251
+ first_ti = self .get_first_ti (dag = dag , session = session ),
1252
+ ),
1241
1253
is_failure_callback = True ,
1242
1254
msg = "all_tasks_deadlocked" ,
1243
1255
)
@@ -1350,13 +1362,62 @@ def notify_dagrun_state_changed(self, msg: str = ""):
1350
1362
# we can't get all the state changes on SchedulerJob,
1351
1363
# or LocalTaskJob, so we don't want to "falsely advertise" we notify about that
1352
1364
1365
+ @provide_session
1366
+ def get_first_ti (self , dag : DAG , session : Session = NEW_SESSION ) -> TI :
1367
+ """Get First TI from the dagrun to build and pass Execution context object from server to then run callbacks."""
1368
+ tis = self .get_task_instances (session = session )
1369
+ # tis from a dagrun may not be a part of dag.partial_subset,
1370
+ # since dag.partial_subset is a subset of the dag.
1371
+ # This ensures that we will only use the accessible TI
1372
+ # context for the callback.
1373
+ if dag .partial :
1374
+ tis = [ti for ti in tis if not ti .state == State .NONE ]
1375
+ # filter out removed tasks
1376
+ tis = [ti for ti in tis if ti .state != TaskInstanceState .REMOVED ]
1377
+ ti = tis [- 1 ] # get first TaskInstance of DagRun
1378
+ return ti
1379
+
1353
1380
def handle_dag_callback (self , dag : SDKDAG , success : bool = True , reason : str = "success" ):
1354
1381
"""Only needed for `dag.test` where `execute_callbacks=True` is passed to `update_state`."""
1355
- context : Context = {
1356
- "dag" : dag ,
1357
- "run_id" : str (self .run_id ),
1358
- "reason" : reason ,
1359
- }
1382
+ from airflow .api_fastapi .execution_api .datamodels .taskinstance import (
1383
+ DagRun as DRDataModel ,
1384
+ TaskInstance as TIDataModel ,
1385
+ TIRunContext ,
1386
+ )
1387
+ from airflow .sdk .execution_time .task_runner import RuntimeTaskInstance
1388
+
1389
+ first_ti = TIDataModel .model_validate (self .get_first_ti (dag ), from_attributes = True ) # type: ignore[arg-type]
1390
+ task = dag .get_task (first_ti .task_id )
1391
+
1392
+ dag_run_data = DRDataModel (
1393
+ dag_id = self .dag_id ,
1394
+ run_id = self .run_id ,
1395
+ logical_date = self .logical_date ,
1396
+ data_interval_start = self .data_interval_start ,
1397
+ data_interval_end = self .data_interval_end ,
1398
+ run_after = self .run_after ,
1399
+ start_date = self .start_date ,
1400
+ end_date = self .end_date ,
1401
+ run_type = self .run_type ,
1402
+ state = self .state ,
1403
+ conf = self .conf ,
1404
+ consumed_asset_events = [],
1405
+ )
1406
+
1407
+ runtime_ti = RuntimeTaskInstance .model_construct (
1408
+ ** first_ti .model_dump (exclude_unset = True ),
1409
+ task = task ,
1410
+ _ti_context_from_server = TIRunContext (
1411
+ dag_run = dag_run_data ,
1412
+ max_tries = task .retries ,
1413
+ variables = [],
1414
+ connections = [],
1415
+ xcom_keys_to_clear = [],
1416
+ ),
1417
+ max_tries = task .retries or 0 ,
1418
+ )
1419
+ context = runtime_ti .get_template_context ()
1420
+ context ["reason" ] = reason
1360
1421
1361
1422
callbacks = dag .on_success_callback if success else dag .on_failure_callback
1362
1423
if not callbacks :
0 commit comments