94
94
from airflow .listeners .listener import get_listener_manager
95
95
from airflow .models .asset import AssetActive , AssetEvent , AssetModel
96
96
from airflow .models .base import Base , StringID , TaskInstanceDependencies
97
- from airflow .models .dagbag import DagBag
98
97
from airflow .models .log import Log
99
98
from airflow .models .renderedtifields import get_serialized_template_fields
100
99
from airflow .models .taskinstancekey import TaskInstanceKey
@@ -255,7 +254,6 @@ def _stop_remaining_tasks(*, task_instance: TaskInstance, task_teardown_map=None
255
254
def clear_task_instances (
256
255
tis : list [TaskInstance ],
257
256
session : Session ,
258
- dag : DAG | None = None ,
259
257
dag_run_state : DagRunState | Literal [False ] = DagRunState .QUEUED ,
260
258
) -> None :
261
259
"""
@@ -271,11 +269,13 @@ def clear_task_instances(
271
269
:param session: current session
272
270
:param dag_run_state: state to set finished DagRuns to.
273
271
If set to False, DagRuns state will not be changed.
274
- :param dag: DAG object
272
+
273
+ :meta private:
275
274
"""
276
- # taskinstance uuids:
277
275
task_instance_ids : list [str ] = []
278
- dag_bag = DagBag (read_dags_from_db = True )
276
+ from airflow .jobs .scheduler_job_runner import SchedulerDagBag
277
+
278
+ scheduler_dagbag = SchedulerDagBag ()
279
279
280
280
for ti in tis :
281
281
task_instance_ids .append (ti .id )
@@ -285,7 +285,12 @@ def clear_task_instances(
285
285
# the task is terminated and becomes eligible for retry.
286
286
ti .state = TaskInstanceState .RESTARTING
287
287
else :
288
- ti_dag = dag if dag and dag .dag_id == ti .dag_id else dag_bag .get_dag (ti .dag_id , session = session )
288
+ dr = ti .dag_run
289
+ ti_dag = scheduler_dagbag .get_dag (dag_run = dr , session = session )
290
+ if not ti_dag :
291
+ raise AirflowException (
292
+ f"Serialized dag not found for dag run. dag_id={ dr .dag_id } run_id={ dr .run_id } "
293
+ )
289
294
task_id = ti .task_id
290
295
if ti_dag and ti_dag .has_task (task_id ):
291
296
task = ti_dag .get_task (task_id )
@@ -327,15 +332,18 @@ def clear_task_instances(
327
332
if dr .state in State .finished_dr_states :
328
333
dr .state = dag_run_state
329
334
dr .start_date = timezone .utcnow ()
330
- if TYPE_CHECKING :
331
- assert dag # todo: change signature so this is required
332
- if not dag .disable_bundle_versioning :
335
+ dr_dag = scheduler_dagbag .get_dag (dag_run = dr , session = session )
336
+ if not dr_dag :
337
+ raise AirflowException (
338
+ f"Serialized dag not found for dag run. dag_id={ dr .dag_id } run_id={ dr .run_id } "
339
+ )
340
+ if not dr_dag .disable_bundle_versioning :
333
341
if dr .dag_model :
334
342
bundle_version = dr .dag_model .bundle_version
335
343
else :
336
344
bundle_version = session .scalar (
337
345
select (DagModel .bundle_version ).where (
338
- DagModel .dag_id == dag .dag_id ,
346
+ DagModel .dag_id == dr_dag .dag_id ,
339
347
)
340
348
)
341
349
if bundle_version is not None :
0 commit comments