@@ -122,7 +122,15 @@ def structure_data(
elif (
dependency.target == dependency.dependency_type or dependency.source == dag_id
) and exit_node_ref:
end_edges.append({"source_id": exit_node_ref["id"], "target_id": dependency.node_id})
end_edges.append(
{
"source_id": exit_node_ref["id"],
"target_id": dependency.node_id,
"resolved_from_alias": dependency.source.replace("asset-alias:", "", 1)
if dependency.source.startswith("asset-alias:")
else None,
}
)

# Add nodes
nodes.append(
@@ -142,6 +150,6 @@ def structure_data(

data["edges"] += start_edges + end_edges

bind_output_assets_to_tasks(data["edges"], serialized_dag)
bind_output_assets_to_tasks(data["edges"], serialized_dag, version_number, session)

return StructureDataResponse(**data)
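
For context on the new field: `resolved_from_alias` records the alias name by stripping the `asset-alias:` prefix from the dependency source. A minimal standalone sketch of that derivation (the source and id values below are made up for illustration):

# Hypothetical values, for illustration only.
source = "asset-alias:example-alias"

edge = {
    "source_id": "exit-node-id",
    "target_id": "asset:7",
    # Same derivation as in structure_data above: strip the first
    # "asset-alias:" prefix, or leave None for a non-alias source.
    "resolved_from_alias": source.replace("asset-alias:", "", 1)
    if source.startswith("asset-alias:")
    else None,
}
assert edge["resolved_from_alias"] == "example-alias"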
@@ -23,6 +23,14 @@

from __future__ import annotations

from collections import defaultdict

from sqlalchemy import select
from sqlalchemy.orm import Session

from airflow.models.asset import AssetAliasModel, AssetEvent
from airflow.models.dag_version import DagVersion
from airflow.models.dagrun import DagRun
from airflow.models.serialized_dag import SerializedDagModel


@@ -116,30 +124,62 @@ def get_upstream_assets(
return nodes, edges


def bind_output_assets_to_tasks(edges: list[dict], serialized_dag: SerializedDagModel) -> None:
def bind_output_assets_to_tasks(
edges: list[dict], serialized_dag: SerializedDagModel, version_number: int, session: Session
) -> None:
"""
Try to bind the downstream assets to the relevant tasks that produce them.

This function will mutate the `edges` in place.
"""
# bind normal assets present in the `task_outlet_asset_references`
outlet_asset_references = serialized_dag.dag_model.task_outlet_asset_references

downstream_asset_related_edges = [edge for edge in edges if edge["target_id"].startswith("asset:")]

for edge in downstream_asset_related_edges:
asset_id = int(edge["target_id"].strip("asset:"))
try:
# Try to attach the outlet asset to the relevant task
outlet_asset_reference = next(
outlet_asset_reference
for outlet_asset_reference in outlet_asset_references
if outlet_asset_reference.asset_id == asset_id
)
edge["source_id"] = outlet_asset_reference.task_id
continue
except StopIteration:
# If no asset reference is found, fall back to using the exit node reference.
# This can happen because asset aliases are not yet handled: they do not populate
# the `outlet_asset_references` when resolved, so an extra lookup is needed. The same
# applies to asset-name-ref and asset-uri-ref.
pass
downstream_asset_edges = [
edge
for edge in edges
if edge["target_id"].startswith("asset:") and not edge.get("resolved_from_alias")
]

for edge in downstream_asset_edges:
# Try to attach the outlet assets to the relevant tasks
asset_id = int(edge["target_id"].replace("asset:", "", 1))
outlet_asset_reference = next(
outlet_asset_reference
for outlet_asset_reference in outlet_asset_references
if outlet_asset_reference.asset_id == asset_id
)
edge["source_id"] = outlet_asset_reference.task_id

# bind assets resolved from aliases; they do not populate the `outlet_asset_references`
downstream_alias_resolved_edges = [
edge for edge in edges if edge["target_id"].startswith("asset:") and edge.get("resolved_from_alias")
]

aliases_names = {edge["resolved_from_alias"] for edge in downstream_alias_resolved_edges}

result = session.scalars(
select(AssetEvent)
.join(AssetEvent.source_aliases)
.join(AssetEvent.source_dag_run)
# This is a simplification; strictly, `version_number` should be checked against `DagRun.dag_versions`.
.join(DagRun.created_dag_version)
.where(AssetEvent.source_aliases.any(AssetAliasModel.name.in_(aliases_names)))
.where(AssetEvent.source_dag_run.has(DagRun.dag_id == serialized_dag.dag_model.dag_id))
.where(DagVersion.version_number == version_number)
).unique()

asset_id_to_task_ids = defaultdict(set)
for asset_event in result:
asset_id_to_task_ids[asset_event.asset_id].add(asset_event.source_task_id)

for edge in downstream_alias_resolved_edges:
asset_id = int(edge["target_id"].replace("asset:", "", 1))
task_ids = asset_id_to_task_ids.get(asset_id, set())

for index, task_id in enumerate(task_ids):
if index == 0:
edge["source_id"] = task_id
continue
edge_copy = {**edge, "source_id": task_id}
edges.append(edge_copy)
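
To follow the fan-out at the end of `bind_output_assets_to_tasks`: when several tasks emitted events for the same alias-resolved asset, the first task rewrites the existing edge in place and each further task gets an appended copy. A minimal, DB-free sketch of that step (the asset id, alias name, and task ids are made up; the real mapping comes from the `AssetEvent` query above):

from collections import defaultdict

edges = [
    {"source_id": "exit", "target_id": "asset:7", "resolved_from_alias": "example-alias"},
]

# Assumed mapping; bind_output_assets_to_tasks builds this from AssetEvent rows.
asset_id_to_task_ids = defaultdict(set)
asset_id_to_task_ids[7] = {"task_1", "task_3"}

alias_resolved = [
    edge for edge in edges if edge["target_id"].startswith("asset:") and edge.get("resolved_from_alias")
]
for edge in alias_resolved:
    asset_id = int(edge["target_id"].replace("asset:", "", 1))
    # sorted() only for a deterministic sketch; the real code iterates the set directly.
    for index, task_id in enumerate(sorted(asset_id_to_task_ids.get(asset_id, set()))):
        if index == 0:
            edge["source_id"] = task_id  # rewrite the original edge in place
        else:
            edges.append({**edge, "source_id": task_id})  # fan out: one edge per producer

# edges now holds one "asset:7" edge per producing task: task_1 and task_3.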
@@ -25,18 +25,21 @@
from sqlalchemy.orm import Session

from airflow.models import DagBag
from airflow.models.asset import AssetModel
from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.operators.trigger_dagrun import TriggerDagRunOperator
from airflow.providers.standard.sensors.external_task import ExternalTaskSensor
from airflow.sdk import Metadata, task
from airflow.sdk.definitions.asset import Asset, AssetAlias, Dataset
from airflow.utils import timezone

from tests_common.test_utils.db import clear_db_runs
from tests_common.test_utils.db import clear_db_assets, clear_db_runs

pytestmark = pytest.mark.db_test

DAG_ID = "dag_with_multiple_versions"
DAG_ID_EXTERNAL_TRIGGER = "external_trigger"
DAG_ID_RESOLVED_ASSET_ALIAS = "dag_with_resolved_asset_alias"
LATEST_VERSION_DAG_RESPONSE: dict = {
"edges": [],
"nodes": [
@@ -95,8 +98,10 @@ def examples_dag_bag() -> DagBag:
@pytest.fixture(autouse=True)
def clean():
clear_db_runs()
clear_db_assets()
yield
clear_db_runs()
clear_db_assets()


@pytest.fixture
@@ -115,15 +120,14 @@ def asset3() -> Dataset:


@pytest.fixture
def make_dag(dag_maker, session, time_machine, asset1: Asset, asset2: Asset, asset3: Dataset) -> None:
def make_dags(dag_maker, session, time_machine, asset1: Asset, asset2: Asset, asset3: Dataset) -> None:
with dag_maker(
dag_id=DAG_ID_EXTERNAL_TRIGGER,
serialized=True,
session=session,
start_date=pendulum.DateTime(2023, 2, 1, 0, 0, 0, tzinfo=pendulum.UTC),
):
TriggerDagRunOperator(task_id="trigger_dag_run_operator", trigger_dag_id=DAG_ID)

dag_maker.sync_dagbag_to_db()

with dag_maker(
@@ -138,7 +142,45 @@ def make_dag(dag_maker, session, time_machine, asset1: Asset, asset2: Asset, ass
>> ExternalTaskSensor(task_id="external_task_sensor", external_dag_id=DAG_ID)
>> EmptyOperator(task_id="task_2")
)
dag_maker.sync_dagbag_to_db()

with dag_maker(
dag_id=DAG_ID_RESOLVED_ASSET_ALIAS,
serialized=True,
session=session,
start_date=pendulum.DateTime(2023, 2, 1, 0, 0, 0, tzinfo=pendulum.UTC),
):

@task(outlets=[AssetAlias("example-alias-resolved")])
def task_1(**context):
yield Metadata(
asset=Asset("resolved_example_asset_alias"),
extra={"k": "v"}, # extra has to be provided, can be {}
alias=AssetAlias("example-alias-resolved"),
)

task_1() >> EmptyOperator(task_id="task_2")

dr = dag_maker.create_dagrun()
asset_alias = session.scalar(
select(AssetAliasModel).where(AssetAliasModel.name == "example-alias-resolved")
)
asset_model = AssetModel(name="resolved_example_asset_alias")
session.add(asset_model)
session.flush()
asset_alias.assets.append(asset_model)
asset_alias.asset_events.append(
AssetEvent(
id=1,
timestamp=timezone.parse("2021-01-01T00:00:00"),
asset_id=asset_model.id,
source_dag_id=DAG_ID_RESOLVED_ASSET_ALIAS,
source_task_id="task_1",
source_run_id=dr.run_id,
source_map_index=-1,
)
)
session.commit()
dag_maker.sync_dagbag_to_db()
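
Since no task instance actually runs in this fixture, the `AssetEvent` that alias resolution would normally record is wired up by hand above. A quick sanity check of the resulting linkage, as a sketch reusing the fixture's session and names:

alias = session.scalar(
    select(AssetAliasModel).where(AssetAliasModel.name == "example-alias-resolved")
)
# The alias resolves to the asset, and the event points back at the producing task.
assert [asset.name for asset in alias.assets] == ["resolved_example_asset_alias"]
assert alias.asset_events[0].source_task_id == "task_1"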


@@ -151,17 +193,17 @@ def _fetch_asset_id(asset: Asset, session: Session) -> str:


@pytest.fixture
def asset1_id(make_dag, asset1, session: Session) -> str:
def asset1_id(make_dags, asset1, session: Session) -> str:
return _fetch_asset_id(asset1, session)


@pytest.fixture
def asset2_id(make_dag, asset2, session) -> str:
def asset2_id(make_dags, asset2, session) -> str:
return _fetch_asset_id(asset2, session)


@pytest.fixture
def asset3_id(make_dag, asset3, session) -> str:
def asset3_id(make_dags, asset3, session) -> str:
return _fetch_asset_id(asset3, session)


@@ -296,13 +338,13 @@ class TestStructureDataEndpoint:
),
],
)
@pytest.mark.usefixtures("make_dag")
@pytest.mark.usefixtures("make_dags")
def test_should_return_200(self, test_client, params, expected):
response = test_client.get("/structure/structure_data", params=params)
assert response.status_code == 200
assert response.json() == expected

@pytest.mark.usefixtures("make_dag")
@pytest.mark.usefixtures("make_dags")
def test_should_return_200_with_asset(self, test_client, asset1_id, asset2_id, asset3_id):
params = {
"dag_id": DAG_ID,
@@ -492,6 +534,75 @@ def test_should_return_200_with_asset(self, test_client, asset1_id, asset2_id, a
assert response.status_code == 200
assert response.json() == expected

@pytest.mark.usefixtures("make_dags")
def test_should_return_200_with_resolved_asset_alias_attached_to_the_correct_producing_task(
self, test_client, session
):
resolved_asset = session.scalar(
select(AssetModel).where(AssetModel.name == "resolved_example_asset_alias")
)
params = {
"dag_id": DAG_ID_RESOLVED_ASSET_ALIAS,
"external_dependencies": True,
}
expected = {
"edges": [
{
"source_id": "task_1",
"target_id": "task_2",
"is_setup_teardown": None,
"label": None,
"is_source_asset": None,
},
{
"source_id": "task_1",
"target_id": f"asset:{resolved_asset.id}",
"is_setup_teardown": None,
"label": None,
"is_source_asset": None,
},
],
"nodes": [
{
"id": "task_1",
"label": "task_1",
"type": "task",
"children": None,
"is_mapped": None,
"tooltip": None,
"setup_teardown_type": None,
"operator": "@task",
"asset_condition_type": None,
},
{
"id": "task_2",
"label": "task_2",
"type": "task",
"children": None,
"is_mapped": None,
"tooltip": None,
"setup_teardown_type": None,
"operator": "EmptyOperator",
"asset_condition_type": None,
},
{
"id": f"asset:{resolved_asset.id}",
"label": "resolved_example_asset_alias",
"type": "asset",
"children": None,
"is_mapped": None,
"tooltip": None,
"setup_teardown_type": None,
"operator": None,
"asset_condition_type": None,
},
],
}

response = test_client.get("/structure/structure_data", params=params)
assert response.status_code == 200
assert response.json() == expected
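
The essential behavior under test, distilled: the edge into the alias-resolved asset is attached to the producing task rather than the DAG's exit node. For the response above, that reduces to:

asset_edge = next(
    edge for edge in response.json()["edges"] if edge["target_id"].startswith("asset:")
)
assert asset_edge["source_id"] == "task_1"  # the producer, not task_2 or an exit node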

@pytest.mark.parametrize(
"params, expected",
[