Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.trigger import Trigger
from airflow.models.xcom import XComModel
from airflow.sdk.definitions._internal.expandinput import NotFullyPopulated
from airflow.sdk.definitions.taskgroup import MappedTaskGroup
from airflow.utils import timezone
from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState
Expand Down Expand Up @@ -244,7 +245,9 @@ def ti_run(
)

if dag := dag_bag.get_dag(ti.dag_id):
upstream_map_indexes = dict(_get_upstream_map_indexes(dag.get_task(ti.task_id), ti.map_index))
upstream_map_indexes = dict(
_get_upstream_map_indexes(dag.get_task(ti.task_id), ti.map_index, ti.run_id, session)
)
else:
upstream_map_indexes = None

Expand Down Expand Up @@ -274,7 +277,7 @@ def ti_run(


def _get_upstream_map_indexes(
task: Operator, ti_map_index: int
task: Operator, ti_map_index: int, run_id: str, session: SessionDep
) -> Iterator[tuple[str, int | list[int] | None]]:
for upstream_task in task.upstream_list:
map_indexes: int | list[int] | None
Expand All @@ -287,8 +290,17 @@ def _get_upstream_map_indexes(
map_indexes = ti_map_index
else:
# tasks not in the same mapped task group
# the upstream mapped task group should combine the xcom as a list and return it
mapped_ti_count: int = upstream_task.task_group.get_parse_time_mapped_ti_count()
# the upstream mapped task group should combine the return xcom as a list and return it
mapped_ti_count: int
upstream_mapped_group = upstream_task.task_group
try:
# for cases that does not need to resolve xcom
mapped_ti_count = upstream_mapped_group.get_parse_time_mapped_ti_count()
except NotFullyPopulated:
# for cases that needs to resolve xcom to get the correct count
mapped_ti_count = upstream_mapped_group._expand_input.get_total_map_length(
run_id, session=session
)
map_indexes = list(range(mapped_ti_count)) if mapped_ti_count is not None else None

yield upstream_task.task_id, map_indexes
Expand Down
13 changes: 9 additions & 4 deletions airflow-core/src/airflow/models/expandinput.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import functools
import operator
from collections.abc import Iterable, Sized
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, ClassVar, Union

import attrs

Expand All @@ -32,7 +32,6 @@

from airflow.sdk.definitions._internal.expandinput import (
DictOfListsExpandInput,
ExpandInput,
ListOfDictsExpandInput,
MappedArgument,
NotFullyPopulated,
Expand Down Expand Up @@ -62,6 +61,8 @@ def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArg
class SchedulerDictOfListsExpandInput:
value: dict

EXPAND_INPUT_TYPE: ClassVar[str] = "dict-of-lists"

def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
"""Generate kwargs with values available on parse-time."""
return ((k, v) for k, v in self.value.items() if not _needs_run_time_resolution(v))
Expand Down Expand Up @@ -114,6 +115,8 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int:
class SchedulerListOfDictsExpandInput:
value: list

EXPAND_INPUT_TYPE: ClassVar[str] = "list-of-dicts"

def get_parse_time_mapped_ti_count(self) -> int:
if isinstance(self.value, Sized):
return len(self.value)
Expand All @@ -130,11 +133,13 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int:
return length


_EXPAND_INPUT_TYPES = {
_EXPAND_INPUT_TYPES: dict[str, type[SchedulerExpandInput]] = {
"dict-of-lists": SchedulerDictOfListsExpandInput,
"list-of-dicts": SchedulerListOfDictsExpandInput,
}

SchedulerExpandInput = Union[SchedulerDictOfListsExpandInput, SchedulerListOfDictsExpandInput]


def create_expand_input(kind: str, value: Any) -> ExpandInput:
def create_expand_input(kind: str, value: Any) -> SchedulerExpandInput:
return _EXPAND_INPUT_TYPES[kind](value)
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
from inspect import Parameter

from airflow.models import DagRun
from airflow.models.expandinput import ExpandInput
from airflow.models.expandinput import SchedulerExpandInput
from airflow.sdk import BaseOperatorLink
from airflow.sdk.definitions._internal.node import DAGNode
from airflow.sdk.types import Operator
Expand Down Expand Up @@ -577,7 +577,7 @@ def validate_expand_input_value(cls, value: _ExpandInputOriginalValue) -> None:
possible ExpandInput cases.
"""

def deref(self, dag: DAG) -> ExpandInput:
def deref(self, dag: DAG) -> SchedulerExpandInput:
"""
De-reference into a concrete ExpandInput object.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancehistory import TaskInstanceHistory
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.sdk import TaskGroup
from airflow.sdk import TaskGroup, task, task_group
from airflow.utils import timezone
from airflow.utils.state import State, TaskInstanceState, TerminalTIState

Expand Down Expand Up @@ -237,6 +237,128 @@ def test_ti_run_state_to_running(
)
assert response.status_code == 409

def test_dynamic_task_mapping_with_parse_time_value(self, client, dag_maker):
"""
Test that the Task Instance upstream_map_indexes is correctly fetched when to running the Task Instances
"""

with dag_maker("test_dynamic_task_mapping_with_parse_time_value", serialized=True):

@task_group
def task_group_1(arg1):
@task
def group1_task_1(arg1):
return {"a": arg1}

@task
def group1_task_2(arg2):
return arg2

group1_task_2(group1_task_1(arg1))

@task
def task2():
return None

task_group_1.expand(arg1=[0, 1]) >> task2()

dr = dag_maker.create_dagrun()
for ti in dr.get_task_instances():
ti.set_state(State.QUEUED)
dag_maker.session.flush()

# key: (task_id, map_index)
# value: result upstream_map_indexes ({task_id: map_indexes})
expected_upstream_map_indexes = {
# no upstream task for task_group_1.group_task_1
("task_group_1.group1_task_1", 0): {},
("task_group_1.group1_task_1", 1): {},
# the upstream task for task_group_1.group_task_2 is task_group_1.group_task_2
# since they are in the same task group, the upstream map index should be the same as the task
("task_group_1.group1_task_2", 0): {"task_group_1.group1_task_1": 0},
("task_group_1.group1_task_2", 1): {"task_group_1.group1_task_1": 1},
# the upstream task for task2 is the last tasks of task_group_1, which is
# task_group_1.group_task_2
# since they are not in the same task group, the upstream map index should include all the
# expanded tasks
("task2", -1): {"task_group_1.group1_task_2": [0, 1]},
}

for ti in dr.get_task_instances():
response = client.patch(
f"/execution/task-instances/{ti.id}/run",
json={
"state": "running",
"hostname": "random-hostname",
"unixname": "random-unixname",
"pid": 100,
"start_date": "2024-09-30T12:00:00Z",
},
)

assert response.status_code == 200
upstream_map_indexes = response.json()["upstream_map_indexes"]
assert upstream_map_indexes == expected_upstream_map_indexes[(ti.task_id, ti.map_index)]

def test_dynamic_task_mapping_with_xcom(self, client, dag_maker, create_task_instance, session, run_task):
"""
Test that the Task Instance upstream_map_indexes is correctly fetched when to running the Task Instances with xcom
"""
from airflow.models.taskmap import TaskMap

with dag_maker(session=session):

@task
def task_1():
return [0, 1]

@task_group
def tg(x, y):
@task
def task_2():
pass

task_2()

@task
def task_3():
pass

tg.expand(x=task_1(), y=[1, 2, 3]) >> task_3()

dr = dag_maker.create_dagrun()

decision = dr.task_instance_scheduling_decisions(session=session)

# Simulate task_1 execution to produce TaskMap.
(ti_1,) = decision.schedulable_tis
# ti_1 = dr.get_task_instance(task_id="task_1")
ti_1.state = TaskInstanceState.SUCCESS
session.add(TaskMap.from_task_instance_xcom(ti_1, [0, 1]))
session.flush()

# Now task_2 in mapped tagk group is expanded.
decision = dr.task_instance_scheduling_decisions(session=session)
for ti in decision.schedulable_tis:
ti.state = TaskInstanceState.SUCCESS
session.flush()

decision = dr.task_instance_scheduling_decisions(session=session)
(task_3_ti,) = decision.schedulable_tis
task_3_ti.set_state(State.QUEUED)

response = client.patch(
f"/execution/task-instances/{task_3_ti.id}/run",
json={
"state": "running",
"hostname": "random-hostname",
"unixname": "random-unixname",
"pid": 100,
"start_date": "2024-09-30T12:00:00Z",
},
)
assert response.json()["upstream_map_indexes"] == {"tg.task_2": [0, 1, 2, 3, 4, 5]}

def test_next_kwargs_still_encoded(self, client, session, create_task_instance, time_machine):
instant_str = "2024-09-30T12:00:00Z"
instant = timezone.parse(instant_str)
Expand Down
2 changes: 1 addition & 1 deletion task-sdk/src/airflow/sdk/definitions/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@
TaskStateChangeCallback,
)
from airflow.models.expandinput import (
ExpandInput,
OperatorExpandArgument,
OperatorExpandKwargsArgument,
)
from airflow.models.xcom_arg import XComArg
from airflow.sdk.bases.operator import BaseOperator
from airflow.sdk.bases.operatorlink import BaseOperatorLink
from airflow.sdk.definitions._internal.expandinput import ExpandInput
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.param import ParamsDict
from airflow.sdk.types import Operator
Expand Down
4 changes: 2 additions & 2 deletions task-sdk/src/airflow/sdk/definitions/taskgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from airflow.utils.trigger_rule import TriggerRule

if TYPE_CHECKING:
from airflow.models.expandinput import ExpandInput
from airflow.models.expandinput import SchedulerExpandInput
from airflow.sdk.bases.operator import BaseOperator
from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
from airflow.sdk.definitions._internal.mixins import DependencyMixin
Expand Down Expand Up @@ -613,7 +613,7 @@ class MappedTaskGroup(TaskGroup):
a ``@task_group`` function instead.
"""

def __init__(self, *, expand_input: ExpandInput, **kwargs: Any) -> None:
def __init__(self, *, expand_input: SchedulerExpandInput, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._expand_input = expand_input

Expand Down