Skip to content

Commit 794951e

Browse files
jedcunninghamsanederchik
authored andcommitted
Add bundle path to sys.path in task runner (apache#51318)
1 parent 849ad13 commit 794951e

File tree

2 files changed

+54
-0
lines changed

2 files changed

+54
-0
lines changed

task-sdk/src/airflow/sdk/execution_time/task_runner.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,11 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
582582
)
583583
bundle_instance.initialize()
584584

585+
# Put bundle root on sys.path if needed. This allows the dag bundle to add
586+
# code in util modules to be shared between files within the same bundle.
587+
if (bundle_root := os.fspath(bundle_instance.path)) not in sys.path:
588+
sys.path.append(bundle_root)
589+
585590
dag_absolute_path = os.fspath(Path(bundle_instance.path, what.dag_rel_path))
586591
bag = DagBag(
587592
dag_folder=dag_absolute_path,

task-sdk/tests/task_sdk/execution_time/test_task_runner.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import functools
2222
import json
2323
import os
24+
import textwrap
2425
import uuid
2526
from collections.abc import Iterable
2627
from datetime import datetime, timedelta
@@ -275,6 +276,54 @@ def test_parse_not_found(test_dags_dir: Path, make_ti_context, dag_id, task_id,
275276
log.error.assert_has_calls([expected_error])
276277

277278

279+
def test_parse_module_in_bundle_root(tmp_path: Path, make_ti_context):
280+
"""Check that the bundle path is added to sys.path, so Dags can import shared modules."""
281+
tmp_path.joinpath("util.py").write_text("NAME = 'dag_name'")
282+
283+
dag1_path = tmp_path.joinpath("path_test.py")
284+
dag1_code = """
285+
from util import NAME
286+
from airflow.sdk import DAG
287+
from airflow.sdk.bases.operator import BaseOperator
288+
with DAG(NAME):
289+
BaseOperator(task_id="a")
290+
"""
291+
dag1_path.write_text(textwrap.dedent(dag1_code))
292+
293+
what = StartupDetails(
294+
ti=TaskInstance(
295+
id=uuid7(),
296+
task_id="a",
297+
dag_id="dag_name",
298+
run_id="c",
299+
try_number=1,
300+
),
301+
dag_rel_path="path_test.py",
302+
bundle_info=BundleInfo(name="my-bundle", version=None),
303+
requests_fd=0,
304+
ti_context=make_ti_context(),
305+
start_date=timezone.utcnow(),
306+
)
307+
308+
with patch.dict(
309+
os.environ,
310+
{
311+
"AIRFLOW__DAG_PROCESSOR__DAG_BUNDLE_CONFIG_LIST": json.dumps(
312+
[
313+
{
314+
"name": "my-bundle",
315+
"classpath": "airflow.dag_processing.bundles.local.LocalDagBundle",
316+
"kwargs": {"path": str(tmp_path), "refresh_interval": 1},
317+
}
318+
]
319+
),
320+
},
321+
):
322+
ti = parse(what, mock.Mock())
323+
324+
assert ti.task.dag.dag_id == "dag_name"
325+
326+
278327
def test_run_deferred_basic(time_machine, create_runtime_ti, mock_supervisor_comms):
279328
"""Test that a task can transition to a deferred state."""
280329
from airflow.providers.standard.sensors.date_time import DateTimeSensorAsync

0 commit comments

Comments
 (0)