|
21 | 21 | import functools
|
22 | 22 | import json
|
23 | 23 | import os
|
| 24 | +import textwrap |
24 | 25 | import uuid
|
25 | 26 | from collections.abc import Iterable
|
26 | 27 | from datetime import datetime, timedelta
|
@@ -274,6 +275,54 @@ def test_parse_not_found(test_dags_dir: Path, make_ti_context, dag_id, task_id,
|
274 | 275 | log.error.assert_has_calls([expected_error])
|
275 | 276 |
|
276 | 277 |
|
| 278 | +def test_parse_module_in_bundle_root(tmp_path: Path, make_ti_context): |
| 279 | + """Check that the bundle path is added to sys.path, so Dags can import shared modules.""" |
| 280 | + tmp_path.joinpath("util.py").write_text("NAME = 'dag_name'") |
| 281 | + |
| 282 | + dag1_path = tmp_path.joinpath("path_test.py") |
| 283 | + dag1_code = """ |
| 284 | + from util import NAME |
| 285 | + from airflow.sdk import DAG |
| 286 | + from airflow.sdk.bases.operator import BaseOperator |
| 287 | + with DAG(NAME): |
| 288 | + BaseOperator(task_id="a") |
| 289 | + """ |
| 290 | + dag1_path.write_text(textwrap.dedent(dag1_code)) |
| 291 | + |
| 292 | + what = StartupDetails( |
| 293 | + ti=TaskInstance( |
| 294 | + id=uuid7(), |
| 295 | + task_id="a", |
| 296 | + dag_id="dag_name", |
| 297 | + run_id="c", |
| 298 | + try_number=1, |
| 299 | + ), |
| 300 | + dag_rel_path="path_test.py", |
| 301 | + bundle_info=BundleInfo(name="my-bundle", version=None), |
| 302 | + requests_fd=0, |
| 303 | + ti_context=make_ti_context(), |
| 304 | + start_date=timezone.utcnow(), |
| 305 | + ) |
| 306 | + |
| 307 | + with patch.dict( |
| 308 | + os.environ, |
| 309 | + { |
| 310 | + "AIRFLOW__DAG_PROCESSOR__DAG_BUNDLE_CONFIG_LIST": json.dumps( |
| 311 | + [ |
| 312 | + { |
| 313 | + "name": "my-bundle", |
| 314 | + "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle", |
| 315 | + "kwargs": {"path": str(tmp_path), "refresh_interval": 1}, |
| 316 | + } |
| 317 | + ] |
| 318 | + ), |
| 319 | + }, |
| 320 | + ): |
| 321 | + ti = parse(what, mock.Mock()) |
| 322 | + |
| 323 | + assert ti.task.dag.dag_id == "dag_name" |
| 324 | + |
| 325 | + |
277 | 326 | def test_run_deferred_basic(time_machine, create_runtime_ti, mock_supervisor_comms):
|
278 | 327 | """Test that a task can transition to a deferred state."""
|
279 | 328 | from airflow.providers.standard.sensors.date_time import DateTimeSensorAsync
|
|
0 commit comments