Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
from __future__ import annotations

import contextlib
from functools import cache
from operator import methodcaller
from typing import Callable
from uuid import UUID

import structlog
Expand All @@ -38,7 +35,6 @@
from airflow.api_fastapi.core_api.datamodels.ui.structure import (
StructureDataResponse,
)
from airflow.configuration import conf
from airflow.models.baseoperator import BaseOperator as DBBaseOperator
from airflow.models.dag_version import DagVersion
from airflow.models.taskmap import TaskMap
Expand All @@ -49,20 +45,11 @@
from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.utils.state import TaskInstanceState
from airflow.utils.task_group import task_group_to_dict
from airflow.utils.task_group import get_task_group_children_getter, task_group_to_dict

log = structlog.get_logger(logger_name=__name__)


@cache
def get_task_group_children_getter() -> Callable:
"""Get the Task Group Children Getter for the DAG."""
sort_order = conf.get("webserver", "grid_view_sorting_order")
if sort_order == "topological":
return methodcaller("topological_sort")
return methodcaller("hierarchical_alphabetical_sort")


def get_task_group_map(dag: DAG) -> dict[str, dict[str, Any]]:
"""
Get the Task Group Map for the DAG.
Expand Down Expand Up @@ -262,7 +249,7 @@ def fill_task_instance_summaries(

def get_structure_from_dag(dag: DAG) -> StructureDataResponse:
"""If we do not have TIs, we just get the structure from the DAG."""
nodes = [task_group_to_dict(child) for child in dag.task_group.topological_sort()]
nodes = [task_group_to_dict(child) for child in get_task_group_children_getter()(dag.task_group)]
return StructureDataResponse(nodes=nodes, edges=[])


Expand Down Expand Up @@ -299,7 +286,7 @@ def get_combined_structure(task_instances, session):
if serdag:
dags.append(serdag.dag)
for dag in dags:
nodes = [task_group_to_dict(child) for child in dag.task_group.topological_sort()]
nodes = [task_group_to_dict(child) for child in get_task_group_children_getter()(dag.task_group)]
_merge_node_dicts(merged_nodes, nodes)

return StructureDataResponse(nodes=merged_nodes, edges=[])
Expand Down
16 changes: 14 additions & 2 deletions airflow-core/src/airflow/utils/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from functools import cache
from operator import methodcaller
from typing import TYPE_CHECKING, Callable

import airflow.sdk.definitions.taskgroup
from airflow.configuration import conf

if TYPE_CHECKING:
from airflow.typing_compat import TypeAlias
Expand All @@ -30,6 +33,15 @@
MappedTaskGroup: TypeAlias = airflow.sdk.definitions.taskgroup.MappedTaskGroup


@cache
def get_task_group_children_getter() -> Callable:
"""Get the Task Group Children Getter for the DAG."""
sort_order = conf.get("webserver", "grid_view_sorting_order")
if sort_order == "topological":
return methodcaller("topological_sort")
return methodcaller("hierarchical_alphabetical_sort")


def task_group_to_dict(task_item_or_group, parent_group_is_mapped=False):
"""Create a nested dict representation of this TaskGroup and its children used to construct the Graph."""
from airflow.sdk.bases.operator import BaseOperator
Expand Down Expand Up @@ -63,7 +75,7 @@ def task_group_to_dict(task_item_or_group, parent_group_is_mapped=False):
is_mapped = isinstance(task_group, MappedTaskGroup)
children = [
task_group_to_dict(child, parent_group_is_mapped=parent_group_is_mapped or is_mapped)
for child in sorted(task_group.children.values(), key=lambda t: t.label)
for child in get_task_group_children_getter()(task_group)
]

if task_group.upstream_group_ids or task_group.upstream_task_ids:
Expand Down
72 changes: 36 additions & 36 deletions airflow-core/tests/unit/utils/test_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,16 @@ def my_task():
"tooltip": "",
},
"children": [
{
"id": "task1",
"value": {
"label": "task1",
"labelStyle": "fill:#000;",
"style": "fill:#e8f7e4;",
"rx": 5,
"ry": 5,
},
},
{
"id": "group234",
"value": {
Expand All @@ -78,6 +88,16 @@ def my_task():
"isMapped": False,
},
"children": [
{
"id": "group234.task2",
"value": {
"label": "task2",
"labelStyle": "fill:#000;",
"style": "fill:#e8f7e4;",
"rx": 5,
"ry": 5,
},
},
{
"id": "group234.group34",
"value": {
Expand Down Expand Up @@ -122,16 +142,6 @@ def my_task():
},
],
},
{
"id": "group234.task2",
"value": {
"label": "task2",
"labelStyle": "fill:#000;",
"style": "fill:#e8f7e4;",
"rx": 5,
"ry": 5,
},
},
{
"id": "group234.upstream_join_id",
"value": {
Expand All @@ -143,16 +153,6 @@ def my_task():
},
],
},
{
"id": "task1",
"value": {
"label": "task1",
"labelStyle": "fill:#000;",
"style": "fill:#e8f7e4;",
"rx": 5,
"ry": 5,
},
},
{
"id": "task5",
"value": {
Expand All @@ -172,12 +172,14 @@ def my_task():
"tooltip": "",
"is_mapped": False,
"children": [
{"id": "task1", "label": "task1", "operator": "EmptyOperator", "type": "task"},
{
"id": "group234",
"label": "group234",
"tooltip": "",
"is_mapped": False,
"children": [
{"id": "group234.task2", "label": "task2", "operator": "EmptyOperator", "type": "task"},
{
"id": "group234.group34",
"label": "group34",
Expand All @@ -200,12 +202,10 @@ def my_task():
],
"type": "task",
},
{"id": "group234.task2", "label": "task2", "operator": "EmptyOperator", "type": "task"},
{"id": "group234.upstream_join_id", "label": "", "type": "join"},
],
"type": "task",
},
{"id": "task1", "label": "task1", "operator": "EmptyOperator", "type": "task"},
{"id": "task5", "label": "task5", "operator": "EmptyOperator", "type": "task"},
],
"type": "task",
Expand Down Expand Up @@ -314,28 +314,28 @@ def test_build_task_group_with_prefix():
"id": None,
"label": None,
"children": [
{"id": "task1", "label": "task1"},
{
"id": "group234",
"label": "group234",
"children": [
{"id": "task2", "label": "task2"},
{
"id": "group34",
"label": "group34",
"children": [
{"id": "group34.task3", "label": "task3"},
{
"id": "group34.group4",
"label": "group4",
"children": [{"id": "task4", "label": "task4"}],
},
{"id": "group34.task3", "label": "task3"},
{"id": "group34.downstream_join_id", "label": ""},
],
},
{"id": "task2", "label": "task2"},
{"id": "group234.upstream_join_id", "label": ""},
],
},
{"id": "task1", "label": "task1"},
{"id": "task5", "label": "task5"},
],
}
Expand Down Expand Up @@ -389,6 +389,7 @@ def task_5():
expected_node_id = {
"id": None,
"children": [
{"id": "task_1"},
{
"id": "group234",
"children": [
Expand All @@ -399,7 +400,6 @@ def task_5():
{"id": "group234.downstream_join_id"},
],
},
{"id": "task_1"},
{"id": "task_5"},
],
}
Expand Down Expand Up @@ -448,6 +448,7 @@ def test_sub_dag_task_group():
expected_node_id = {
"id": None,
"children": [
{"id": "task1"},
{
"id": "group234",
"children": [
Expand All @@ -462,7 +463,6 @@ def test_sub_dag_task_group():
{"id": "group234.upstream_join_id"},
],
},
{"id": "task1"},
{"id": "task5"},
],
}
Expand Down Expand Up @@ -540,6 +540,7 @@ def test_dag_edges():
expected_node_id = {
"id": None,
"children": [
{"id": "task1"},
{
"id": "group_a",
"children": [
Expand Down Expand Up @@ -567,6 +568,8 @@ def test_dag_edges():
{"id": "group_c.downstream_join_id"},
],
},
{"id": "task9"},
{"id": "task10"},
{
"id": "group_d",
"children": [
Expand All @@ -575,9 +578,6 @@ def test_dag_edges():
{"id": "group_d.upstream_join_id"},
],
},
{"id": "task1"},
{"id": "task10"},
{"id": "task9"},
],
}

Expand Down Expand Up @@ -818,22 +818,22 @@ def section_2(value2):
node_ids = {
"id": None,
"children": [
{"id": "task_start"},
{
"id": "section_1",
"children": [
{"id": "section_1.task_1"},
{"id": "section_1.task_2"},
{
"id": "section_1.section_2",
"children": [
{"id": "section_1.section_2.task_3"},
{"id": "section_1.section_2.task_4"},
],
},
{"id": "section_1.task_1"},
{"id": "section_1.task_2"},
],
},
{"id": "task_end"},
{"id": "task_start"},
],
}

Expand Down Expand Up @@ -992,6 +992,7 @@ def section_2(value):
node_ids = {
"id": None,
"children": [
{"id": "task_start"},
{
"id": "section_1",
"children": [
Expand All @@ -1011,7 +1012,6 @@ def section_2(value):
],
},
{"id": "task_end"},
{"id": "task_start"},
],
}

Expand Down Expand Up @@ -1153,17 +1153,17 @@ def task_group1(name: str):
{
"id": "task_group1",
"children": [
{"id": "task_group1.end_task"},
{"id": "task_group1.start_task"},
{"id": "task_group1.task"},
{"id": "task_group1.end_task"},
],
},
{
"id": "task_group1__1",
"children": [
{"id": "task_group1__1.end_task"},
{"id": "task_group1__1.start_task"},
{"id": "task_group1__1.task"},
{"id": "task_group1__1.end_task"},
],
},
],
Expand Down