Skip to content

Commit 7f19372

Browse files
committed
fix more tests
1 parent bcaa368 commit 7f19372

File tree

2 files changed

+51
-41
lines changed

2 files changed

+51
-41
lines changed

providers/standard/tests/unit/standard/operators/test_python.py

Lines changed: 37 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1915,7 +1915,7 @@ class TestShortCircuitWithTeardown:
19151915
def test_short_circuit_with_teardowns(
19161916
self, dag_maker, ignore_downstream_trigger_rules, should_skip, with_teardown, expected
19171917
):
1918-
with dag_maker() as dag:
1918+
with dag_maker(serialized=True):
19191919
op1 = ShortCircuitOperator(
19201920
task_id="op1",
19211921
python_callable=lambda: not should_skip,
@@ -1928,21 +1928,20 @@ def test_short_circuit_with_teardowns(
19281928
op4.as_teardown()
19291929
op1 >> op2 >> op3 >> op4
19301930
op1.skip = MagicMock()
1931-
dagrun = dag_maker.create_dagrun()
1932-
tis = dagrun.get_task_instances()
1933-
ti: TaskInstance = next(x for x in tis if x.task_id == "op1")
1934-
ti._run_raw_task()
1935-
expected_tasks = {dag.task_dict[x] for x in expected}
1931+
dagrun = dag_maker.create_dagrun()
1932+
tis = dagrun.get_task_instances()
1933+
ti: TaskInstance = next(x for x in tis if x.task_id == "op1")
1934+
ti._run_raw_task()
19361935
if should_skip:
19371936
# we can't use assert_called_with because it's a set and therefore not ordered
1938-
actual_skipped = set(op1.skip.call_args.kwargs["tasks"])
1939-
assert actual_skipped == expected_tasks
1937+
actual_skipped = set(x.task_id for x in op1.skip.call_args.kwargs["tasks"])
1938+
assert actual_skipped == set(expected)
19401939
else:
19411940
op1.skip.assert_not_called()
19421941

19431942
@pytest.mark.parametrize("config", ["sequence", "parallel"])
19441943
def test_short_circuit_with_teardowns_complicated(self, dag_maker, config):
1945-
with dag_maker():
1944+
with dag_maker(serialized=True):
19461945
s1 = PythonOperator(task_id="s1", python_callable=print).as_setup()
19471946
s2 = PythonOperator(task_id="s2", python_callable=print).as_setup()
19481947
op1 = ShortCircuitOperator(
@@ -1959,16 +1958,16 @@ def test_short_circuit_with_teardowns_complicated(self, dag_maker, config):
19591958
else:
19601959
raise ValueError("unexpected")
19611960
op1.skip = MagicMock()
1962-
dagrun = dag_maker.create_dagrun()
1963-
tis = dagrun.get_task_instances()
1964-
ti: TaskInstance = next(x for x in tis if x.task_id == "op1")
1965-
ti._run_raw_task()
1966-
# we can't use assert_called_with because it's a set and therefore not ordered
1967-
actual_skipped = set(op1.skip.call_args.kwargs["tasks"])
1968-
assert actual_skipped == {s2, op2}
1961+
dagrun = dag_maker.create_dagrun()
1962+
tis = dagrun.get_task_instances()
1963+
ti: TaskInstance = next(x for x in tis if x.task_id == "op1")
1964+
ti._run_raw_task()
1965+
# we can't use assert_called_with because it's a set and therefore not ordered
1966+
actual_skipped = set(op1.skip.call_args.kwargs["tasks"])
1967+
assert actual_skipped == {s2, op2}
19691968

19701969
def test_short_circuit_with_teardowns_complicated_2(self, dag_maker):
1971-
with dag_maker():
1970+
with dag_maker(serialized=True):
19721971
s1 = PythonOperator(task_id="s1", python_callable=print).as_setup()
19731972
s2 = PythonOperator(task_id="s2", python_callable=print).as_setup()
19741973
op1 = ShortCircuitOperator(
@@ -1986,22 +1985,22 @@ def test_short_circuit_with_teardowns_complicated_2(self, dag_maker):
19861985
# in this case we don't want to skip t2 since it should run
19871986
op1 >> t2
19881987
op1.skip = MagicMock()
1989-
dagrun = dag_maker.create_dagrun()
1990-
tis = dagrun.get_task_instances()
1991-
ti: TaskInstance = next(x for x in tis if x.task_id == "op1")
1992-
ti._run_raw_task()
1993-
# we can't use assert_called_with because it's a set and therefore not ordered
1994-
actual_kwargs = op1.skip.call_args.kwargs
1995-
actual_skipped = set(actual_kwargs["tasks"])
1996-
assert actual_skipped == {op3}
1988+
dagrun = dag_maker.create_dagrun()
1989+
tis = dagrun.get_task_instances()
1990+
ti: TaskInstance = next(x for x in tis if x.task_id == "op1")
1991+
ti._run_raw_task()
1992+
# we can't use assert_called_with because it's a set and therefore not ordered
1993+
actual_kwargs = op1.skip.call_args.kwargs
1994+
actual_skipped = set(actual_kwargs["tasks"])
1995+
assert actual_skipped == {op3}
19971996

19981997
@pytest.mark.parametrize("level", [logging.DEBUG, logging.INFO])
19991998
def test_short_circuit_with_teardowns_debug_level(self, dag_maker, level, clear_db):
20001999
"""
20012000
When logging is debug we convert to a list to log the tasks skipped
20022001
before passing them to the skip method.
20032002
"""
2004-
with dag_maker():
2003+
with dag_maker(serialized=True):
20052004
s1 = PythonOperator(task_id="s1", python_callable=print).as_setup()
20062005
s2 = PythonOperator(task_id="s2", python_callable=print).as_setup()
20072006
op1 = ShortCircuitOperator(
@@ -2020,18 +2019,18 @@ def test_short_circuit_with_teardowns_debug_level(self, dag_maker, level, clear_
20202019
# in this case we don't want to skip t2 since it should run
20212020
op1 >> t2
20222021
op1.skip = MagicMock()
2023-
dagrun = dag_maker.create_dagrun()
2024-
tis = dagrun.get_task_instances()
2025-
ti: TaskInstance = next(x for x in tis if x.task_id == "op1")
2026-
ti._run_raw_task()
2027-
# we can't use assert_called_with because it's a set and therefore not ordered
2028-
actual_kwargs = op1.skip.call_args.kwargs
2029-
actual_skipped = actual_kwargs["tasks"]
2030-
if level <= logging.DEBUG:
2031-
assert isinstance(actual_skipped, list)
2032-
else:
2033-
assert isinstance(actual_skipped, Generator)
2034-
assert set(actual_skipped) == {op3}
2022+
dagrun = dag_maker.create_dagrun()
2023+
tis = dagrun.get_task_instances()
2024+
ti: TaskInstance = next(x for x in tis if x.task_id == "op1")
2025+
ti._run_raw_task()
2026+
# we can't use assert_called_with because it's a set and therefore not ordered
2027+
actual_kwargs = op1.skip.call_args.kwargs
2028+
actual_skipped = actual_kwargs["tasks"]
2029+
if level <= logging.DEBUG:
2030+
assert isinstance(actual_skipped, list)
2031+
else:
2032+
assert isinstance(actual_skipped, Generator)
2033+
assert set(actual_skipped) == {op3}
20352034

20362035

20372036
@pytest.mark.parametrize(

providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from airflow.models import DagBag, DagRun, TaskInstance
3131
from airflow.models.baseoperator import BaseOperator
3232
from airflow.models.dag import DAG
33+
from airflow.models.dagbundle import DagBundleModel
3334
from airflow.models.serialized_dag import SerializedDagModel
3435
from airflow.models.xcom_arg import XComArg
3536
from airflow.providers.standard.operators.bash import BashOperator
@@ -1771,14 +1772,13 @@ def test_external_task_marker_cyclic_shallow(dag_bag_cyclic):
17711772

17721773

17731774
@pytest.fixture
1774-
def dag_bag_multiple():
1775+
def dag_bag_multiple(session):
17751776
"""
17761777
Create a DagBag containing two DAGs, linked by multiple ExternalTaskMarker.
17771778
"""
17781779
dag_bag = DagBag(dag_folder=DEV_NULL, include_examples=False)
17791780
daily_dag = DAG("daily_dag", start_date=DEFAULT_DATE, schedule="@daily")
17801781
agg_dag = DAG("agg_dag", start_date=DEFAULT_DATE, schedule="@daily")
1781-
17821782
if AIRFLOW_V_3_0_PLUS:
17831783
dag_bag.bag_dag(dag=daily_dag)
17841784
dag_bag.bag_dag(dag=agg_dag)
@@ -1798,6 +1798,12 @@ def dag_bag_multiple():
17981798
dag=agg_dag,
17991799
)
18001800
begin >> task
1801+
bundle_name = "abcbunhdlerch3rc"
1802+
session.merge(DagBundleModel(name=bundle_name))
1803+
session.commit()
1804+
DAG.bulk_write_to_db(bundle_name=bundle_name, dags=[daily_dag, agg_dag], bundle_version=None)
1805+
SerializedDagModel.write_dag(dag=daily_dag, bundle_name=bundle_name)
1806+
SerializedDagModel.write_dag(dag=agg_dag, bundle_name=bundle_name)
18011807

18021808
return dag_bag
18031809

@@ -1819,7 +1825,7 @@ def test_clear_multiple_external_task_marker(dag_bag_multiple):
18191825

18201826

18211827
@pytest.fixture
1822-
def dag_bag_head_tail():
1828+
def dag_bag_head_tail(session):
18231829
"""
18241830
Create a DagBag containing one DAG, with task "head" depending on task "tail" of the
18251831
previous logical_date.
@@ -1858,6 +1864,11 @@ def dag_bag_head_tail():
18581864
dag_bag.bag_dag(dag=dag)
18591865
else:
18601866
dag_bag.bag_dag(dag=dag, root_dag=dag)
1867+
bundle_name = "9e8uh9odhu9c"
1868+
session.merge(DagBundleModel(name=bundle_name))
1869+
session.commit()
1870+
DAG.bulk_write_to_db(bundle_name=bundle_name, dags=[dag], bundle_version=None)
1871+
SerializedDagModel.write_dag(dag=dag, bundle_name=bundle_name)
18611872

18621873
return dag_bag
18631874

0 commit comments

Comments
 (0)