Skip to content

Commit c7cb173

Browse files
github-actions[bot], Lee-W
authored and committed
[v3-0-test] fix(task_instances): handle upstream_mapped_index when xcom access is needed (#50641) (#50950)
* fix(task_instances): handle upstream_mapped_index when xcom access is needed
* style(expand_input): fix expand_input and SchedulerExpandInput types
* test(task_instances): add test_dynamic_task_mapping_with_parse_time_value
* test(task_instance): add test_dynamic_task_mapping_with_xcom
* style: import typing
* style: move the SchedulerExpandInput into type checking block
* Revert "style: move the SchedulerExpandInput into type checking block"
  This reverts commit c2c87ca.
(cherry picked from commit 5458e7e)
Co-authored-by: Wei Lee <[email protected]>
1 parent 039d1d2 commit c7cb173

File tree

6 files changed

+153
-14
lines changed

6 files changed

+153
-14
lines changed

airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
from airflow.models.taskreschedule import TaskReschedule
5858
from airflow.models.trigger import Trigger
5959
from airflow.models.xcom import XComModel
60+
from airflow.sdk.definitions._internal.expandinput import NotFullyPopulated
6061
from airflow.sdk.definitions.taskgroup import MappedTaskGroup
6162
from airflow.utils import timezone
6263
from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState
@@ -244,7 +245,9 @@ def ti_run(
244245
)
245246

246247
if dag := dag_bag.get_dag(ti.dag_id):
247-
upstream_map_indexes = dict(_get_upstream_map_indexes(dag.get_task(ti.task_id), ti.map_index))
248+
upstream_map_indexes = dict(
249+
_get_upstream_map_indexes(dag.get_task(ti.task_id), ti.map_index, ti.run_id, session)
250+
)
248251
else:
249252
upstream_map_indexes = None
250253

@@ -274,7 +277,7 @@ def ti_run(
274277

275278

276279
def _get_upstream_map_indexes(
277-
task: Operator, ti_map_index: int
280+
task: Operator, ti_map_index: int, run_id: str, session: SessionDep
278281
) -> Iterator[tuple[str, int | list[int] | None]]:
279282
for upstream_task in task.upstream_list:
280283
map_indexes: int | list[int] | None
@@ -287,8 +290,17 @@ def _get_upstream_map_indexes(
287290
map_indexes = ti_map_index
288291
else:
289292
# tasks not in the same mapped task group
290-
# the upstream mapped task group should combine the xcom as a list and return it
291-
mapped_ti_count: int = upstream_task.task_group.get_parse_time_mapped_ti_count()
293+
# the upstream mapped task group should combine the return xcom as a list and return it
294+
mapped_ti_count: int
295+
upstream_mapped_group = upstream_task.task_group
296+
try:
297+
# for cases that do not need to resolve xcom
298+
mapped_ti_count = upstream_mapped_group.get_parse_time_mapped_ti_count()
299+
except NotFullyPopulated:
300+
# for cases that need to resolve xcom to get the correct count
301+
mapped_ti_count = upstream_mapped_group._expand_input.get_total_map_length(
302+
run_id, session=session
303+
)
292304
map_indexes = list(range(mapped_ti_count)) if mapped_ti_count is not None else None
293305

294306
yield upstream_task.task_id, map_indexes

airflow-core/src/airflow/models/expandinput.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import functools
2121
import operator
2222
from collections.abc import Iterable, Sized
23-
from typing import TYPE_CHECKING, Any
23+
from typing import TYPE_CHECKING, Any, ClassVar, Union
2424

2525
import attrs
2626

@@ -32,7 +32,6 @@
3232

3333
from airflow.sdk.definitions._internal.expandinput import (
3434
DictOfListsExpandInput,
35-
ExpandInput,
3635
ListOfDictsExpandInput,
3736
MappedArgument,
3837
NotFullyPopulated,
@@ -62,6 +61,8 @@ def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArg
6261
class SchedulerDictOfListsExpandInput:
6362
value: dict
6463

64+
EXPAND_INPUT_TYPE: ClassVar[str] = "dict-of-lists"
65+
6566
def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
6667
"""Generate kwargs with values available on parse-time."""
6768
return ((k, v) for k, v in self.value.items() if not _needs_run_time_resolution(v))
@@ -114,6 +115,8 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int:
114115
class SchedulerListOfDictsExpandInput:
115116
value: list
116117

118+
EXPAND_INPUT_TYPE: ClassVar[str] = "list-of-dicts"
119+
117120
def get_parse_time_mapped_ti_count(self) -> int:
118121
if isinstance(self.value, Sized):
119122
return len(self.value)
@@ -130,11 +133,13 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int:
130133
return length
131134

132135

133-
_EXPAND_INPUT_TYPES = {
136+
_EXPAND_INPUT_TYPES: dict[str, type[SchedulerExpandInput]] = {
134137
"dict-of-lists": SchedulerDictOfListsExpandInput,
135138
"list-of-dicts": SchedulerListOfDictsExpandInput,
136139
}
137140

141+
SchedulerExpandInput = Union[SchedulerDictOfListsExpandInput, SchedulerListOfDictsExpandInput]
142+
138143

139-
def create_expand_input(kind: str, value: Any) -> ExpandInput:
144+
def create_expand_input(kind: str, value: Any) -> SchedulerExpandInput:
140145
return _EXPAND_INPUT_TYPES[kind](value)

airflow-core/src/airflow/serialization/serialized_objects.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@
100100
from inspect import Parameter
101101

102102
from airflow.models import DagRun
103-
from airflow.models.expandinput import ExpandInput
103+
from airflow.models.expandinput import SchedulerExpandInput
104104
from airflow.sdk import BaseOperatorLink
105105
from airflow.sdk.definitions._internal.node import DAGNode
106106
from airflow.sdk.types import Operator
@@ -577,7 +577,7 @@ def validate_expand_input_value(cls, value: _ExpandInputOriginalValue) -> None:
577577
possible ExpandInput cases.
578578
"""
579579

580-
def deref(self, dag: DAG) -> ExpandInput:
580+
def deref(self, dag: DAG) -> SchedulerExpandInput:
581581
"""
582582
De-reference into a concrete ExpandInput object.
583583

airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from airflow.models.taskinstance import TaskInstance
3434
from airflow.models.taskinstancehistory import TaskInstanceHistory
3535
from airflow.providers.standard.operators.empty import EmptyOperator
36-
from airflow.sdk import TaskGroup
36+
from airflow.sdk import TaskGroup, task, task_group
3737
from airflow.utils import timezone
3838
from airflow.utils.state import State, TaskInstanceState, TerminalTIState
3939

@@ -236,6 +236,128 @@ def test_ti_run_state_to_running(
236236
)
237237
assert response.status_code == 409
238238

239+
def test_dynamic_task_mapping_with_parse_time_value(self, client, dag_maker):
240+
"""
241+
Test that the Task Instance upstream_map_indexes is correctly fetched when running the Task Instances
242+
"""
243+
244+
with dag_maker("test_dynamic_task_mapping_with_parse_time_value", serialized=True):
245+
246+
@task_group
247+
def task_group_1(arg1):
248+
@task
249+
def group1_task_1(arg1):
250+
return {"a": arg1}
251+
252+
@task
253+
def group1_task_2(arg2):
254+
return arg2
255+
256+
group1_task_2(group1_task_1(arg1))
257+
258+
@task
259+
def task2():
260+
return None
261+
262+
task_group_1.expand(arg1=[0, 1]) >> task2()
263+
264+
dr = dag_maker.create_dagrun()
265+
for ti in dr.get_task_instances():
266+
ti.set_state(State.QUEUED)
267+
dag_maker.session.flush()
268+
269+
# key: (task_id, map_index)
270+
# value: result upstream_map_indexes ({task_id: map_indexes})
271+
expected_upstream_map_indexes = {
272+
# no upstream task for task_group_1.group_task_1
273+
("task_group_1.group1_task_1", 0): {},
274+
("task_group_1.group1_task_1", 1): {},
275+
# the upstream task for task_group_1.group1_task_2 is task_group_1.group1_task_1
276+
# since they are in the same task group, the upstream map index should be the same as the task
277+
("task_group_1.group1_task_2", 0): {"task_group_1.group1_task_1": 0},
278+
("task_group_1.group1_task_2", 1): {"task_group_1.group1_task_1": 1},
279+
# the upstream task for task2 is the last task of task_group_1, which is
280+
# task_group_1.group1_task_2
281+
# since they are not in the same task group, the upstream map index should include all the
282+
# expanded tasks
283+
("task2", -1): {"task_group_1.group1_task_2": [0, 1]},
284+
}
285+
286+
for ti in dr.get_task_instances():
287+
response = client.patch(
288+
f"/execution/task-instances/{ti.id}/run",
289+
json={
290+
"state": "running",
291+
"hostname": "random-hostname",
292+
"unixname": "random-unixname",
293+
"pid": 100,
294+
"start_date": "2024-09-30T12:00:00Z",
295+
},
296+
)
297+
298+
assert response.status_code == 200
299+
upstream_map_indexes = response.json()["upstream_map_indexes"]
300+
assert upstream_map_indexes == expected_upstream_map_indexes[(ti.task_id, ti.map_index)]
301+
302+
def test_dynamic_task_mapping_with_xcom(self, client, dag_maker, create_task_instance, session, run_task):
303+
"""
304+
Test that the Task Instance upstream_map_indexes is correctly fetched when running the Task Instances with xcom
305+
"""
306+
from airflow.models.taskmap import TaskMap
307+
308+
with dag_maker(session=session):
309+
310+
@task
311+
def task_1():
312+
return [0, 1]
313+
314+
@task_group
315+
def tg(x, y):
316+
@task
317+
def task_2():
318+
pass
319+
320+
task_2()
321+
322+
@task
323+
def task_3():
324+
pass
325+
326+
tg.expand(x=task_1(), y=[1, 2, 3]) >> task_3()
327+
328+
dr = dag_maker.create_dagrun()
329+
330+
decision = dr.task_instance_scheduling_decisions(session=session)
331+
332+
# Simulate task_1 execution to produce TaskMap.
333+
(ti_1,) = decision.schedulable_tis
334+
# ti_1 = dr.get_task_instance(task_id="task_1")
335+
ti_1.state = TaskInstanceState.SUCCESS
336+
session.add(TaskMap.from_task_instance_xcom(ti_1, [0, 1]))
337+
session.flush()
338+
339+
# Now task_2 in the mapped task group is expanded.
340+
decision = dr.task_instance_scheduling_decisions(session=session)
341+
for ti in decision.schedulable_tis:
342+
ti.state = TaskInstanceState.SUCCESS
343+
session.flush()
344+
345+
decision = dr.task_instance_scheduling_decisions(session=session)
346+
(task_3_ti,) = decision.schedulable_tis
347+
task_3_ti.set_state(State.QUEUED)
348+
349+
response = client.patch(
350+
f"/execution/task-instances/{task_3_ti.id}/run",
351+
json={
352+
"state": "running",
353+
"hostname": "random-hostname",
354+
"unixname": "random-unixname",
355+
"pid": 100,
356+
"start_date": "2024-09-30T12:00:00Z",
357+
},
358+
)
359+
assert response.json()["upstream_map_indexes"] == {"tg.task_2": [0, 1, 2, 3, 4, 5]}
360+
239361
def test_next_kwargs_still_encoded(self, client, session, create_task_instance, time_machine):
240362
instant_str = "2024-09-30T12:00:00Z"
241363
instant = timezone.parse(instant_str)

task-sdk/src/airflow/sdk/definitions/mappedoperator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,13 @@
6464
TaskStateChangeCallback,
6565
)
6666
from airflow.models.expandinput import (
67-
ExpandInput,
6867
OperatorExpandArgument,
6968
OperatorExpandKwargsArgument,
7069
)
7170
from airflow.models.xcom_arg import XComArg
7271
from airflow.sdk.bases.operator import BaseOperator
7372
from airflow.sdk.bases.operatorlink import BaseOperatorLink
73+
from airflow.sdk.definitions._internal.expandinput import ExpandInput
7474
from airflow.sdk.definitions.dag import DAG
7575
from airflow.sdk.definitions.param import ParamsDict
7676
from airflow.sdk.types import Operator

task-sdk/src/airflow/sdk/definitions/taskgroup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from airflow.utils.trigger_rule import TriggerRule
4141

4242
if TYPE_CHECKING:
43-
from airflow.models.expandinput import ExpandInput
43+
from airflow.models.expandinput import SchedulerExpandInput
4444
from airflow.sdk.bases.operator import BaseOperator
4545
from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
4646
from airflow.sdk.definitions._internal.mixins import DependencyMixin
@@ -613,7 +613,7 @@ class MappedTaskGroup(TaskGroup):
613613
a ``@task_group`` function instead.
614614
"""
615615

616-
def __init__(self, *, expand_input: ExpandInput, **kwargs: Any) -> None:
616+
def __init__(self, *, expand_input: SchedulerExpandInput, **kwargs: Any) -> None:
617617
super().__init__(**kwargs)
618618
self._expand_input = expand_input
619619

0 commit comments

Comments
 (0)