Skip to content

Commit 09353bf

Browse files
[v3-0-test] Structure endpoint task level resolved aliases (#51481) (#51579)
* Structure endpoint task level resolved aliases * Address PR comments (cherry picked from commit acf1e77) Co-authored-by: Pierre Jeambrun <[email protected]>
1 parent 43df0ac commit 09353bf

File tree

3 files changed

+190
-31
lines changed

3 files changed

+190
-31
lines changed

airflow-core/src/airflow/api_fastapi/core_api/routes/ui/structure.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,15 @@ def structure_data(
122122
elif (
123123
dependency.target == dependency.dependency_type or dependency.source == dag_id
124124
) and exit_node_ref:
125-
end_edges.append({"source_id": exit_node_ref["id"], "target_id": dependency.node_id})
125+
end_edges.append(
126+
{
127+
"source_id": exit_node_ref["id"],
128+
"target_id": dependency.node_id,
129+
"resolved_from_alias": dependency.source.replace("asset-alias:", "", 1)
130+
if dependency.source.startswith("asset-alias:")
131+
else None,
132+
}
133+
)
126134

127135
# Add nodes
128136
nodes.append(
@@ -142,6 +150,6 @@ def structure_data(
142150

143151
data["edges"] += start_edges + end_edges
144152

145-
bind_output_assets_to_tasks(data["edges"], serialized_dag)
153+
bind_output_assets_to_tasks(data["edges"], serialized_dag, version_number, session)
146154

147155
return StructureDataResponse(**data)

airflow-core/src/airflow/api_fastapi/core_api/services/ui/structure.py

Lines changed: 60 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@
2323

2424
from __future__ import annotations
2525

26+
from collections import defaultdict
27+
28+
from sqlalchemy import select
29+
from sqlalchemy.orm import Session
30+
31+
from airflow.models.asset import AssetAliasModel, AssetEvent
32+
from airflow.models.dag_version import DagVersion
33+
from airflow.models.dagrun import DagRun
2634
from airflow.models.serialized_dag import SerializedDagModel
2735

2836

@@ -116,30 +124,62 @@ def get_upstream_assets(
116124
return nodes, edges
117125

118126

119-
def bind_output_assets_to_tasks(edges: list[dict], serialized_dag: SerializedDagModel) -> None:
127+
def bind_output_assets_to_tasks(
128+
edges: list[dict], serialized_dag: SerializedDagModel, version_number: int, session: Session
129+
) -> None:
120130
"""
121131
Try to bind the downstream assets to the relevant task that produces them.
122132
123133
This function will mutate the `edges` in place.
124134
"""
135+
# bind normal assets present in the `task_outlet_asset_references`
125136
outlet_asset_references = serialized_dag.dag_model.task_outlet_asset_references
126137

127-
downstream_asset_related_edges = [edge for edge in edges if edge["target_id"].startswith("asset:")]
128-
129-
for edge in downstream_asset_related_edges:
130-
asset_id = int(edge["target_id"].strip("asset:"))
131-
try:
132-
# Try to attach the outlet asset to the relevant task
133-
outlet_asset_reference = next(
134-
outlet_asset_reference
135-
for outlet_asset_reference in outlet_asset_references
136-
if outlet_asset_reference.asset_id == asset_id
137-
)
138-
edge["source_id"] = outlet_asset_reference.task_id
139-
continue
140-
except StopIteration:
141-
# If no asset reference found, fallback to using the exit node reference
142-
# This can happen because asset aliases are not yet handled, they do no populate
143-
# the `outlet_asset_references` when resolved. Extra lookup is needed. Same for asset-name-ref and
144-
# asset-uri-ref.
145-
pass
138+
downstream_asset_edges = [
139+
edge
140+
for edge in edges
141+
if edge["target_id"].startswith("asset:") and not edge.get("resolved_from_alias")
142+
]
143+
144+
for edge in downstream_asset_edges:
145+
# Try to attach the outlet assets to the relevant tasks
146+
asset_id = int(edge["target_id"].replace("asset:", "", 1))
147+
outlet_asset_reference = next(
148+
outlet_asset_reference
149+
for outlet_asset_reference in outlet_asset_references
150+
if outlet_asset_reference.asset_id == asset_id
151+
)
152+
edge["source_id"] = outlet_asset_reference.task_id
153+
154+
# bind assets resolved from aliases, they do not populate the `outlet_asset_references`
155+
downstream_alias_resolved_edges = [
156+
edge for edge in edges if edge["target_id"].startswith("asset:") and edge.get("resolved_from_alias")
157+
]
158+
159+
aliases_names = {edges["resolved_from_alias"] for edges in downstream_alias_resolved_edges}
160+
161+
result = session.scalars(
162+
select(AssetEvent)
163+
.join(AssetEvent.source_aliases)
164+
.join(AssetEvent.source_dag_run)
165+
# That's a simplification, instead doing `version_number` in `DagRun.dag_versions`.
166+
.join(DagRun.created_dag_version)
167+
.where(AssetEvent.source_aliases.any(AssetAliasModel.name.in_(aliases_names)))
168+
.where(AssetEvent.source_dag_run.has(DagRun.dag_id == serialized_dag.dag_model.dag_id))
169+
.where(DagVersion.version_number == version_number)
170+
).unique()
171+
172+
asset_id_to_task_ids = defaultdict(set)
173+
for asset_event in result:
174+
asset_id_to_task_ids[asset_event.asset_id].add(asset_event.source_task_id)
175+
176+
for edge in downstream_alias_resolved_edges:
177+
asset_id = int(edge["target_id"].replace("asset:", "", 1))
178+
task_ids = asset_id_to_task_ids.get(asset_id, set())
179+
180+
for index, task_id in enumerate(task_ids):
181+
if index == 0:
182+
edge["source_id"] = task_id
183+
continue
184+
edge_copy = {**edge, "source_id": task_id}
185+
edges.append(edge_copy)

airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_structure.py

Lines changed: 120 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,21 @@
2525
from sqlalchemy.orm import Session
2626

2727
from airflow.models import DagBag
28-
from airflow.models.asset import AssetModel
28+
from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel
2929
from airflow.providers.standard.operators.empty import EmptyOperator
3030
from airflow.providers.standard.operators.trigger_dagrun import TriggerDagRunOperator
3131
from airflow.providers.standard.sensors.external_task import ExternalTaskSensor
32+
from airflow.sdk import Metadata, task
3233
from airflow.sdk.definitions.asset import Asset, AssetAlias, Dataset
34+
from airflow.utils import timezone
3335

34-
from tests_common.test_utils.db import clear_db_runs
36+
from tests_common.test_utils.db import clear_db_assets, clear_db_runs
3537

3638
pytestmark = pytest.mark.db_test
3739

3840
DAG_ID = "dag_with_multiple_versions"
3941
DAG_ID_EXTERNAL_TRIGGER = "external_trigger"
42+
DAG_ID_RESOLVED_ASSET_ALIAS = "dag_with_resolved_asset_alias"
4043
LATEST_VERSION_DAG_RESPONSE: dict = {
4144
"edges": [],
4245
"nodes": [
@@ -95,8 +98,10 @@ def examples_dag_bag() -> DagBag:
9598
@pytest.fixture(autouse=True)
9699
def clean():
97100
clear_db_runs()
101+
clear_db_assets()
98102
yield
99103
clear_db_runs()
104+
clear_db_assets()
100105

101106

102107
@pytest.fixture
@@ -115,15 +120,14 @@ def asset3() -> Dataset:
115120

116121

117122
@pytest.fixture
118-
def make_dag(dag_maker, session, time_machine, asset1: Asset, asset2: Asset, asset3: Dataset) -> None:
123+
def make_dags(dag_maker, session, time_machine, asset1: Asset, asset2: Asset, asset3: Dataset) -> None:
119124
with dag_maker(
120125
dag_id=DAG_ID_EXTERNAL_TRIGGER,
121126
serialized=True,
122127
session=session,
123128
start_date=pendulum.DateTime(2023, 2, 1, 0, 0, 0, tzinfo=pendulum.UTC),
124129
):
125130
TriggerDagRunOperator(task_id="trigger_dag_run_operator", trigger_dag_id=DAG_ID)
126-
127131
dag_maker.sync_dagbag_to_db()
128132

129133
with dag_maker(
@@ -138,7 +142,45 @@ def make_dag(dag_maker, session, time_machine, asset1: Asset, asset2: Asset, ass
138142
>> ExternalTaskSensor(task_id="external_task_sensor", external_dag_id=DAG_ID)
139143
>> EmptyOperator(task_id="task_2")
140144
)
145+
dag_maker.sync_dagbag_to_db()
146+
147+
with dag_maker(
148+
dag_id=DAG_ID_RESOLVED_ASSET_ALIAS,
149+
serialized=True,
150+
session=session,
151+
start_date=pendulum.DateTime(2023, 2, 1, 0, 0, 0, tzinfo=pendulum.UTC),
152+
):
153+
154+
@task(outlets=[AssetAlias("example-alias-resolved")])
155+
def task_1(**context):
156+
yield Metadata(
157+
asset=Asset("resolved_example_asset_alias"),
158+
extra={"k": "v"}, # extra has to be provided, can be {}
159+
alias=AssetAlias("example-alias-resolved"),
160+
)
141161

162+
task_1() >> EmptyOperator(task_id="task_2")
163+
164+
dr = dag_maker.create_dagrun()
165+
asset_alias = session.scalar(
166+
select(AssetAliasModel).where(AssetAliasModel.name == "example-alias-resolved")
167+
)
168+
asset_model = AssetModel(name="resolved_example_asset_alias")
169+
session.add(asset_model)
170+
session.flush()
171+
asset_alias.assets.append(asset_model)
172+
asset_alias.asset_events.append(
173+
AssetEvent(
174+
id=1,
175+
timestamp=timezone.parse("2021-01-01T00:00:00"),
176+
asset_id=asset_model.id,
177+
source_dag_id=DAG_ID_RESOLVED_ASSET_ALIAS,
178+
source_task_id="task_1",
179+
source_run_id=dr.run_id,
180+
source_map_index=-1,
181+
)
182+
)
183+
session.commit()
142184
dag_maker.sync_dagbag_to_db()
143185

144186

@@ -151,17 +193,17 @@ def _fetch_asset_id(asset: Asset, session: Session) -> str:
151193

152194

153195
@pytest.fixture
154-
def asset1_id(make_dag, asset1, session: Session) -> str:
196+
def asset1_id(make_dags, asset1, session: Session) -> str:
155197
return _fetch_asset_id(asset1, session)
156198

157199

158200
@pytest.fixture
159-
def asset2_id(make_dag, asset2, session) -> str:
201+
def asset2_id(make_dags, asset2, session) -> str:
160202
return _fetch_asset_id(asset2, session)
161203

162204

163205
@pytest.fixture
164-
def asset3_id(make_dag, asset3, session) -> str:
206+
def asset3_id(make_dags, asset3, session) -> str:
165207
return _fetch_asset_id(asset3, session)
166208

167209

@@ -296,13 +338,13 @@ class TestStructureDataEndpoint:
296338
),
297339
],
298340
)
299-
@pytest.mark.usefixtures("make_dag")
341+
@pytest.mark.usefixtures("make_dags")
300342
def test_should_return_200(self, test_client, params, expected):
301343
response = test_client.get("/structure/structure_data", params=params)
302344
assert response.status_code == 200
303345
assert response.json() == expected
304346

305-
@pytest.mark.usefixtures("make_dag")
347+
@pytest.mark.usefixtures("make_dags")
306348
def test_should_return_200_with_asset(self, test_client, asset1_id, asset2_id, asset3_id):
307349
params = {
308350
"dag_id": DAG_ID,
@@ -492,6 +534,75 @@ def test_should_return_200_with_asset(self, test_client, asset1_id, asset2_id, a
492534
assert response.status_code == 200
493535
assert response.json() == expected
494536

537+
@pytest.mark.usefixtures("make_dags")
538+
def test_should_return_200_with_resolved_asset_alias_attached_to_the_corrrect_producing_task(
539+
self, test_client, session
540+
):
541+
resolved_asset = session.scalar(
542+
session.query(AssetModel).filter_by(name="resolved_example_asset_alias")
543+
)
544+
params = {
545+
"dag_id": DAG_ID_RESOLVED_ASSET_ALIAS,
546+
"external_dependencies": True,
547+
}
548+
expected = {
549+
"edges": [
550+
{
551+
"source_id": "task_1",
552+
"target_id": "task_2",
553+
"is_setup_teardown": None,
554+
"label": None,
555+
"is_source_asset": None,
556+
},
557+
{
558+
"source_id": "task_1",
559+
"target_id": f"asset:{resolved_asset.id}",
560+
"is_setup_teardown": None,
561+
"label": None,
562+
"is_source_asset": None,
563+
},
564+
],
565+
"nodes": [
566+
{
567+
"id": "task_1",
568+
"label": "task_1",
569+
"type": "task",
570+
"children": None,
571+
"is_mapped": None,
572+
"tooltip": None,
573+
"setup_teardown_type": None,
574+
"operator": "@task",
575+
"asset_condition_type": None,
576+
},
577+
{
578+
"id": "task_2",
579+
"label": "task_2",
580+
"type": "task",
581+
"children": None,
582+
"is_mapped": None,
583+
"tooltip": None,
584+
"setup_teardown_type": None,
585+
"operator": "EmptyOperator",
586+
"asset_condition_type": None,
587+
},
588+
{
589+
"id": f"asset:{resolved_asset.id}",
590+
"label": "resolved_example_asset_alias",
591+
"type": "asset",
592+
"children": None,
593+
"is_mapped": None,
594+
"tooltip": None,
595+
"setup_teardown_type": None,
596+
"operator": None,
597+
"asset_condition_type": None,
598+
},
599+
],
600+
}
601+
602+
response = test_client.get("/structure/structure_data", params=params)
603+
assert response.status_code == 200
604+
assert response.json() == expected
605+
495606
@pytest.mark.parametrize(
496607
"params, expected",
497608
[

0 commit comments

Comments
 (0)