Skip to content

Commit 16d5bc9

Browse files
github-actions[bot]jedcunninghamamoghrajesh
authored
[v3-0-test] Add bundle path to sys.path in task runner (#51318) (#51341)
(cherry picked from commit b1bbc82) Co-authored-by: Jed Cunningham <[email protected]> Co-authored-by: Amogh Desai <[email protected]>
1 parent 2bccaa3 commit 16d5bc9

File tree

2 files changed

+54
-0
lines changed

2 files changed

+54
-0
lines changed

task-sdk/src/airflow/sdk/execution_time/task_runner.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,11 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
562562
)
563563
bundle_instance.initialize()
564564

565+
# Put bundle root on sys.path if needed. This allows the dag bundle to add
566+
# code in util modules to be shared between files within the same bundle.
567+
if (bundle_root := os.fspath(bundle_instance.path)) not in sys.path:
568+
sys.path.append(bundle_root)
569+
565570
dag_absolute_path = os.fspath(Path(bundle_instance.path, what.dag_rel_path))
566571
bag = DagBag(
567572
dag_folder=dag_absolute_path,

task-sdk/tests/task_sdk/execution_time/test_task_runner.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import functools
2222
import json
2323
import os
24+
import textwrap
2425
import uuid
2526
from collections.abc import Iterable
2627
from datetime import datetime, timedelta
@@ -274,6 +275,54 @@ def test_parse_not_found(test_dags_dir: Path, make_ti_context, dag_id, task_id,
274275
log.error.assert_has_calls([expected_error])
275276

276277

278+
def test_parse_module_in_bundle_root(tmp_path: Path, make_ti_context):
279+
"""Check that the bundle path is added to sys.path, so Dags can import shared modules."""
280+
tmp_path.joinpath("util.py").write_text("NAME = 'dag_name'")
281+
282+
dag1_path = tmp_path.joinpath("path_test.py")
283+
dag1_code = """
284+
from util import NAME
285+
from airflow.sdk import DAG
286+
from airflow.sdk.bases.operator import BaseOperator
287+
with DAG(NAME):
288+
BaseOperator(task_id="a")
289+
"""
290+
dag1_path.write_text(textwrap.dedent(dag1_code))
291+
292+
what = StartupDetails(
293+
ti=TaskInstance(
294+
id=uuid7(),
295+
task_id="a",
296+
dag_id="dag_name",
297+
run_id="c",
298+
try_number=1,
299+
),
300+
dag_rel_path="path_test.py",
301+
bundle_info=BundleInfo(name="my-bundle", version=None),
302+
requests_fd=0,
303+
ti_context=make_ti_context(),
304+
start_date=timezone.utcnow(),
305+
)
306+
307+
with patch.dict(
308+
os.environ,
309+
{
310+
"AIRFLOW__DAG_PROCESSOR__DAG_BUNDLE_CONFIG_LIST": json.dumps(
311+
[
312+
{
313+
"name": "my-bundle",
314+
"classpath": "airflow.dag_processing.bundles.local.LocalDagBundle",
315+
"kwargs": {"path": str(tmp_path), "refresh_interval": 1},
316+
}
317+
]
318+
),
319+
},
320+
):
321+
ti = parse(what, mock.Mock())
322+
323+
assert ti.task.dag.dag_id == "dag_name"
324+
325+
277326
def test_run_deferred_basic(time_machine, create_runtime_ti, mock_supervisor_comms):
278327
"""Test that a task can transition to a deferred state."""
279328
from airflow.providers.standard.sensors.date_time import DateTimeSensorAsync

0 commit comments

Comments
 (0)