Skip to content

Commit 2e5cfa9

Browse files
committed
fix tests
1 parent ea45411 commit 2e5cfa9

File tree

3 files changed

+49
-58
lines changed

3 files changed

+49
-58
lines changed

airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -719,7 +719,6 @@ def post_clear_task_instances(
719719
clear_task_instances(
720720
task_instances,
721721
session,
722-
dag,
723722
DagRunState.QUEUED if reset_dag_runs else False,
724723
)
725724

airflow-core/tests/unit/models/test_dag.py

Lines changed: 40 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -180,12 +180,6 @@ def teardown_method(self) -> None:
180180
clear_db_dags()
181181
clear_db_assets()
182182

183-
@staticmethod
184-
def _clean_up(dag_id: str):
185-
with create_session() as session:
186-
session.query(DagRun).filter(DagRun.dag_id == dag_id).delete(synchronize_session=False)
187-
session.query(TI).filter(TI.dag_id == dag_id).delete(synchronize_session=False)
188-
189183
@staticmethod
190184
def _occur_before(a, b, list_):
191185
"""
@@ -977,8 +971,6 @@ def add_failed_dag_run(dag, id, logical_date):
977971
)
978972
add_failed_dag_run(dag, "2", TEST_DATE + timedelta(days=1))
979973
assert dag.get_is_paused()
980-
dag.clear()
981-
self._clean_up(dag_id)
982974

983975
def test_dag_is_deactivated_upon_dagfile_deletion(self, dag_maker):
984976
dag_id = "old_existing_dag"
@@ -1038,8 +1030,6 @@ def test_schedule_dag_no_previous_runs(self):
10381030
)
10391031
assert dag_run.state == State.RUNNING
10401032
assert dag_run.run_type != DagRunType.MANUAL
1041-
dag.clear()
1042-
self._clean_up(dag_id)
10431033

10441034
@patch("airflow.models.dag.Stats")
10451035
def test_dag_handle_callback_crash(self, mock_stats):
@@ -1080,9 +1070,6 @@ def test_dag_handle_callback_crash(self, mock_stats):
10801070
tags={"dag_id": "test_dag_callback_crash"},
10811071
)
10821072

1083-
dag.clear()
1084-
self._clean_up(dag_id)
1085-
10861073
def test_dag_handle_callback_with_removed_task(self, dag_maker, session):
10871074
"""
10881075
Tests avoid crashes when a removed task is the last one in the list of task instance
@@ -1118,9 +1105,6 @@ def test_dag_handle_callback_with_removed_task(self, dag_maker, session):
11181105
dag.handle_callback(dag_run, success=True)
11191106
dag.handle_callback(dag_run, success=False)
11201107

1121-
dag.clear()
1122-
self._clean_up(dag_id)
1123-
11241108
@pytest.mark.parametrize("catchup,expected_next_dagrun", [(True, DEFAULT_DATE), (False, None)])
11251109
def test_next_dagrun_after_fake_scheduled_previous(self, catchup, expected_next_dagrun):
11261110
"""
@@ -1158,8 +1142,6 @@ def test_next_dagrun_after_fake_scheduled_previous(self, catchup, expected_next_
11581142
assert model.next_dagrun == expected_next_dagrun
11591143
assert model.next_dagrun_create_after == expected_next_dagrun + delta
11601144

1161-
self._clean_up(dag_id)
1162-
11631145
def test_schedule_dag_once(self):
11641146
"""
11651147
Tests scheduling a dag scheduled for @once - should be scheduled the first time
@@ -1188,7 +1170,6 @@ def test_schedule_dag_once(self):
11881170

11891171
assert model.next_dagrun is None
11901172
assert model.next_dagrun_create_after is None
1191-
self._clean_up(dag_id)
11921173

11931174
def test_fractional_seconds(self):
11941175
"""
@@ -1213,7 +1194,6 @@ def test_fractional_seconds(self):
12131194

12141195
assert start_date == run.logical_date, "dag run logical_date loses precision"
12151196
assert start_date == run.start_date, "dag run start_date loses precision "
1216-
self._clean_up(dag_id)
12171197

12181198
def test_rich_comparison_ops(self):
12191199
test_dag_id = "test_rich_comparison_ops"
@@ -1397,46 +1377,39 @@ def test_dag_add_task_sets_default_task_group(self):
13971377
assert dag.get_task("task_group.task_with_task_group") == task_with_task_group
13981378

13991379
@pytest.mark.parametrize("dag_run_state", [DagRunState.QUEUED, DagRunState.RUNNING])
1400-
def test_clear_set_dagrun_state(self, dag_run_state):
1380+
@pytest.mark.need_serialized_dag
1381+
def test_clear_set_dagrun_state(self, dag_run_state, dag_maker, session):
14011382
dag_id = "test_clear_set_dagrun_state"
1402-
self._clean_up(dag_id)
1403-
task_id = "t1"
1404-
dag = DAG(dag_id, schedule=None, start_date=DEFAULT_DATE, max_active_runs=1)
1405-
t_1 = EmptyOperator(task_id=task_id, dag=dag)
14061383

1407-
session = settings.Session()
1408-
dagrun_1 = _create_dagrun(
1409-
dag,
1384+
with dag_maker(dag_id, start_date=DEFAULT_DATE, max_active_runs=1) as dag:
1385+
task_id = "t1"
1386+
EmptyOperator(task_id=task_id)
1387+
1388+
dr = dag_maker.create_dagrun(
14101389
run_type=DagRunType.BACKFILL_JOB,
14111390
state=State.FAILED,
14121391
start_date=DEFAULT_DATE,
14131392
logical_date=DEFAULT_DATE,
1414-
data_interval=(DEFAULT_DATE, DEFAULT_DATE),
1393+
session=session,
14151394
)
1416-
session.merge(dagrun_1)
1417-
1418-
task_instance_1 = TI(t_1, run_id=dagrun_1.run_id, state=State.RUNNING)
1419-
task_instance_1.refresh_from_db()
1420-
session.merge(task_instance_1)
14211395
session.commit()
1396+
session.refresh(dr)
1397+
assert dr.state == "failed"
14221398

14231399
dag.clear(
14241400
start_date=DEFAULT_DATE,
14251401
end_date=DEFAULT_DATE + datetime.timedelta(days=1),
14261402
dag_run_state=dag_run_state,
14271403
session=session,
14281404
)
1429-
dagruns = session.query(DagRun).filter(DagRun.dag_id == dag_id).all()
1430-
1431-
assert len(dagruns) == 1
1432-
dagrun: DagRun = dagruns[0]
1433-
assert dagrun.state == dag_run_state
1405+
session.refresh(dr)
1406+
assert dr.state == dag_run_state
14341407

14351408
@pytest.mark.parametrize("dag_run_state", [DagRunState.QUEUED, DagRunState.RUNNING])
14361409
@pytest.mark.need_serialized_dag
14371410
def test_clear_set_dagrun_state_for_mapped_task(self, dag_maker, dag_run_state):
14381411
dag_id = "test_clear_set_dagrun_state"
1439-
self._clean_up(dag_id)
1412+
14401413
task_id = "t1"
14411414

14421415
with dag_maker(dag_id, schedule=None, start_date=DEFAULT_DATE, max_active_runs=1) as dag:
@@ -1611,32 +1584,37 @@ def test_clear_dag(
16111584
self,
16121585
ti_state_begin: TaskInstanceState | None,
16131586
ti_state_end: TaskInstanceState | None,
1587+
dag_maker,
1588+
session,
16141589
):
16151590
dag_id = "test_clear_dag"
1616-
self._clean_up(dag_id)
1591+
16171592
task_id = "t1"
1618-
dag = DAG(dag_id, schedule=None, start_date=DEFAULT_DATE, max_active_runs=1)
1619-
_ = EmptyOperator(task_id=task_id, dag=dag)
1593+
with dag_maker(
1594+
dag_id,
1595+
schedule=None,
1596+
start_date=DEFAULT_DATE,
1597+
max_active_runs=1,
1598+
serialized=True,
1599+
) as dag:
1600+
EmptyOperator(task_id=task_id)
16201601

16211602
session = settings.Session() # type: ignore
1622-
dagrun_1 = dag.create_dagrun(
1603+
dagrun_1 = dag_maker.create_dagrun(
16231604
run_id="backfill",
16241605
run_type=DagRunType.BACKFILL_JOB,
16251606
state=DagRunState.RUNNING,
16261607
start_date=DEFAULT_DATE,
16271608
logical_date=DEFAULT_DATE,
1628-
data_interval=(DEFAULT_DATE, DEFAULT_DATE),
1629-
run_after=DEFAULT_DATE,
1630-
triggered_by=DagRunTriggeredByType.TEST,
1609+
# triggered_by=DagRunTriggeredByType.TEST,
1610+
session=session,
16311611
)
1632-
session.merge(dagrun_1)
16331612

1634-
task_instance_1 = dagrun_1.get_task_instance(task_id)
1613+
task_instance_1 = dagrun_1.get_task_instance(task_id, session=session)
16351614
if TYPE_CHECKING:
16361615
assert task_instance_1
16371616
task_instance_1.state = ti_state_begin
16381617
task_instance_1.job_id = 123
1639-
session.merge(task_instance_1)
16401618
session.commit()
16411619

16421620
dag.clear(
@@ -1650,7 +1628,6 @@ def test_clear_dag(
16501628
assert len(task_instances) == 1
16511629
task_instance: TI = task_instances[0]
16521630
assert task_instance.state == ti_state_end
1653-
self._clean_up(dag_id)
16541631

16551632
def test_next_dagrun_info_once(self):
16561633
dag = DAG("test_scheduler_dagrun_once", start_date=timezone.datetime(2015, 1, 1), schedule="@once")
@@ -2508,7 +2485,12 @@ def test_count_number_queries(self, tasks_count):
25082485
def test_set_task_instance_state(run_id, session, dag_maker):
25092486
"""Test that set_task_instance_state updates the TaskInstance state and clear downstream failed"""
25102487
start_date = datetime_tz(2020, 1, 1)
2511-
with dag_maker("test_set_task_instance_state", start_date=start_date, session=session) as dag:
2488+
with dag_maker(
2489+
"test_set_task_instance_state",
2490+
start_date=start_date,
2491+
session=session,
2492+
serialized=True,
2493+
) as dag:
25122494
task_1 = EmptyOperator(task_id="task_1")
25132495
task_2 = EmptyOperator(task_id="task_2")
25142496
task_3 = EmptyOperator(task_id="task_3")
@@ -2646,7 +2628,12 @@ def consumer(value):
26462628
def test_set_task_group_state(session, dag_maker):
26472629
"""Test that set_task_group_state updates the TaskGroup state and clear downstream failed"""
26482630
start_date = datetime_tz(2020, 1, 1)
2649-
with dag_maker("test_set_task_group_state", start_date=start_date, session=session) as dag:
2631+
with dag_maker(
2632+
"test_set_task_group_state",
2633+
start_date=start_date,
2634+
session=session,
2635+
serialized=True,
2636+
) as dag:
26502637
start = EmptyOperator(task_id="start")
26512638

26522639
with TaskGroup("section_1", tooltip="Tasks for section_1") as section_1:

airflow-core/tests/unit/models/test_mappedoperator.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -397,19 +397,24 @@ def test_expand_mapped_task_instance_with_named_index(
397397
expected_rendered_names,
398398
) -> None:
399399
"""Test that the correct number of downstream tasks are generated when mapping with an XComArg"""
400-
with dag_maker("test-dag", session=session, start_date=DEFAULT_DATE):
400+
dag_id = "test_dag_12345"
401+
with dag_maker(
402+
dag_id=dag_id,
403+
start_date=DEFAULT_DATE,
404+
serialized=True,
405+
):
401406
create_mapped_task(task_id="task1", map_names=["a", "b"], template=template)
402407

403-
dr = dag_maker.create_dagrun()
404-
tis = dr.get_task_instances()
408+
dr = dag_maker.create_dagrun(session=session)
409+
tis = dr.get_task_instances(session=session)
405410
for ti in tis:
406411
ti.run()
407412
session.flush()
408413

409414
indices = session.scalars(
410415
select(TaskInstance.rendered_map_index)
411416
.where(
412-
TaskInstance.dag_id == "test-dag",
417+
TaskInstance.dag_id == dag_id,
413418
TaskInstance.task_id == "task1",
414419
TaskInstance.run_id == dr.run_id,
415420
)

0 commit comments

Comments
 (0)