From 02915c9ce4199456c55cc59754857ebc431053fc Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 22 May 2025 18:53:39 +0800 Subject: [PATCH] [v3-0-test] fix(task_instances): handle upstream_mapped_index when xcom access is needed (#50641) * fix(task_instances): handle upstream_mapped_index when xcom access is needed * style(expand_input): fix expand_input and SchedulerExpandInput types * test(task_instances): add test_dynamic_task_mapping_with_parse_time_value * test(task_instance): add test_dynamic_task_mapping_with_xcom * style: import typing * style: move the SchedulerExpandInput into type checking block * Revert "style: move the SchedulerExpandInput into type checking block" This reverts commit c2c87ca304bfe721120bda19f3dcc3a0ddab8804. (cherry picked from commit 5458e7e7be86c6de034d7a589bd26db85c532308) Co-authored-by: Wei Lee --- .../execution_api/routes/task_instances.py | 20 ++- .../src/airflow/models/expandinput.py | 13 +- .../serialization/serialized_objects.py | 4 +- .../versions/head/test_task_instances.py | 124 +++++++++++++++++- .../airflow/sdk/definitions/mappedoperator.py | 2 +- .../src/airflow/sdk/definitions/taskgroup.py | 4 +- 6 files changed, 153 insertions(+), 14 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index a48070deb26c5..ac1d1602460de 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -57,6 +57,7 @@ from airflow.models.taskreschedule import TaskReschedule from airflow.models.trigger import Trigger from airflow.models.xcom import XComModel +from airflow.sdk.definitions._internal.expandinput import NotFullyPopulated from airflow.sdk.definitions.taskgroup import MappedTaskGroup from airflow.utils import timezone from airflow.utils.state import DagRunState, TaskInstanceState, 
TerminalTIState @@ -244,7 +245,9 @@ def ti_run( ) if dag := dag_bag.get_dag(ti.dag_id): - upstream_map_indexes = dict(_get_upstream_map_indexes(dag.get_task(ti.task_id), ti.map_index)) + upstream_map_indexes = dict( + _get_upstream_map_indexes(dag.get_task(ti.task_id), ti.map_index, ti.run_id, session) + ) else: upstream_map_indexes = None @@ -274,7 +277,7 @@ def ti_run( def _get_upstream_map_indexes( - task: Operator, ti_map_index: int + task: Operator, ti_map_index: int, run_id: str, session: SessionDep ) -> Iterator[tuple[str, int | list[int] | None]]: for upstream_task in task.upstream_list: map_indexes: int | list[int] | None @@ -287,8 +290,17 @@ def _get_upstream_map_indexes( map_indexes = ti_map_index else: # tasks not in the same mapped task group - # the upstream mapped task group should combine the xcom as a list and return it - mapped_ti_count: int = upstream_task.task_group.get_parse_time_mapped_ti_count() + # the upstream mapped task group should combine the return xcom as a list and return it + mapped_ti_count: int + upstream_mapped_group = upstream_task.task_group + try: + # for cases that do not need to resolve xcom + mapped_ti_count = upstream_mapped_group.get_parse_time_mapped_ti_count() + except NotFullyPopulated: + # for cases that need to resolve xcom to get the correct count + mapped_ti_count = upstream_mapped_group._expand_input.get_total_map_length( + run_id, session=session + ) map_indexes = list(range(mapped_ti_count)) if mapped_ti_count is not None else None yield upstream_task.task_id, map_indexes diff --git a/airflow-core/src/airflow/models/expandinput.py b/airflow-core/src/airflow/models/expandinput.py index f3e6aab168076..b126c6f24b07f 100644 --- a/airflow-core/src/airflow/models/expandinput.py +++ b/airflow-core/src/airflow/models/expandinput.py @@ -20,7 +20,7 @@ import functools import operator from collections.abc import Iterable, Sized -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar, 
Union import attrs @@ -32,7 +32,6 @@ from airflow.sdk.definitions._internal.expandinput import ( DictOfListsExpandInput, - ExpandInput, ListOfDictsExpandInput, MappedArgument, NotFullyPopulated, @@ -62,6 +61,8 @@ def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArg class SchedulerDictOfListsExpandInput: value: dict + EXPAND_INPUT_TYPE: ClassVar[str] = "dict-of-lists" + def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]: """Generate kwargs with values available on parse-time.""" return ((k, v) for k, v in self.value.items() if not _needs_run_time_resolution(v)) @@ -114,6 +115,8 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int: class SchedulerListOfDictsExpandInput: value: list + EXPAND_INPUT_TYPE: ClassVar[str] = "list-of-dicts" + def get_parse_time_mapped_ti_count(self) -> int: if isinstance(self.value, Sized): return len(self.value) @@ -130,11 +133,13 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int: return length -_EXPAND_INPUT_TYPES = { +_EXPAND_INPUT_TYPES: dict[str, type[SchedulerExpandInput]] = { "dict-of-lists": SchedulerDictOfListsExpandInput, "list-of-dicts": SchedulerListOfDictsExpandInput, } +SchedulerExpandInput = Union[SchedulerDictOfListsExpandInput, SchedulerListOfDictsExpandInput] + -def create_expand_input(kind: str, value: Any) -> ExpandInput: +def create_expand_input(kind: str, value: Any) -> SchedulerExpandInput: return _EXPAND_INPUT_TYPES[kind](value) diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index dbd55c1adde6f..2a8f3cd9da671 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -100,7 +100,7 @@ from inspect import Parameter from airflow.models import DagRun - from airflow.models.expandinput import ExpandInput + from airflow.models.expandinput import 
SchedulerExpandInput from airflow.sdk import BaseOperatorLink from airflow.sdk.definitions._internal.node import DAGNode from airflow.sdk.types import Operator @@ -577,7 +577,7 @@ def validate_expand_input_value(cls, value: _ExpandInputOriginalValue) -> None: possible ExpandInput cases. """ - def deref(self, dag: DAG) -> ExpandInput: + def deref(self, dag: DAG) -> SchedulerExpandInput: """ De-reference into a concrete ExpandInput object. diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index cc2b1baa64ac6..77e3e7df3ebbd 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -34,7 +34,7 @@ from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancehistory import TaskInstanceHistory from airflow.providers.standard.operators.empty import EmptyOperator -from airflow.sdk import TaskGroup +from airflow.sdk import TaskGroup, task, task_group from airflow.utils import timezone from airflow.utils.state import State, TaskInstanceState, TerminalTIState @@ -237,6 +237,128 @@ def test_ti_run_state_to_running( ) assert response.status_code == 409 + def test_dynamic_task_mapping_with_parse_time_value(self, client, dag_maker): + """ + Test that the Task Instance upstream_map_indexes is correctly fetched when running the Task Instances + """ + + with dag_maker("test_dynamic_task_mapping_with_parse_time_value", serialized=True): + + @task_group + def task_group_1(arg1): + @task + def group1_task_1(arg1): + return {"a": arg1} + + @task + def group1_task_2(arg2): + return arg2 + + group1_task_2(group1_task_1(arg1)) + + @task + def task2(): + return None + + task_group_1.expand(arg1=[0, 1]) >> task2() + + dr = dag_maker.create_dagrun() + for ti in dr.get_task_instances(): + 
ti.set_state(State.QUEUED) + dag_maker.session.flush() + + # key: (task_id, map_index) + # value: result upstream_map_indexes ({task_id: map_indexes}) + expected_upstream_map_indexes = { + # no upstream task for task_group_1.group1_task_1 + ("task_group_1.group1_task_1", 0): {}, + ("task_group_1.group1_task_1", 1): {}, + # the upstream task for task_group_1.group1_task_2 is task_group_1.group1_task_1 + # since they are in the same task group, the upstream map index should be the same as the task + ("task_group_1.group1_task_2", 0): {"task_group_1.group1_task_1": 0}, + ("task_group_1.group1_task_2", 1): {"task_group_1.group1_task_1": 1}, + # the upstream task for task2 is the last task of task_group_1, which is + # task_group_1.group1_task_2 + # since they are not in the same task group, the upstream map index should include all the + # expanded tasks + ("task2", -1): {"task_group_1.group1_task_2": [0, 1]}, + } + + for ti in dr.get_task_instances(): + response = client.patch( + f"/execution/task-instances/{ti.id}/run", + json={ + "state": "running", + "hostname": "random-hostname", + "unixname": "random-unixname", + "pid": 100, + "start_date": "2024-09-30T12:00:00Z", + }, + ) + + assert response.status_code == 200 + upstream_map_indexes = response.json()["upstream_map_indexes"] + assert upstream_map_indexes == expected_upstream_map_indexes[(ti.task_id, ti.map_index)] + + def test_dynamic_task_mapping_with_xcom(self, client, dag_maker, create_task_instance, session, run_task): + """ + Test that the Task Instance upstream_map_indexes is correctly fetched when running the Task Instances with xcom + """ + from airflow.models.taskmap import TaskMap + + with dag_maker(session=session): + + @task + def task_1(): + return [0, 1] + + @task_group + def tg(x, y): + @task + def task_2(): + pass + + task_2() + + @task + def task_3(): + pass + + tg.expand(x=task_1(), y=[1, 2, 3]) >> task_3() + + dr = dag_maker.create_dagrun() + + decision = 
dr.task_instance_scheduling_decisions(session=session) + + # Simulate task_1 execution to produce TaskMap. + (ti_1,) = decision.schedulable_tis + # ti_1 = dr.get_task_instance(task_id="task_1") + ti_1.state = TaskInstanceState.SUCCESS + session.add(TaskMap.from_task_instance_xcom(ti_1, [0, 1])) + session.flush() + + # Now task_2 in mapped task group is expanded. + decision = dr.task_instance_scheduling_decisions(session=session) + for ti in decision.schedulable_tis: + ti.state = TaskInstanceState.SUCCESS + session.flush() + + decision = dr.task_instance_scheduling_decisions(session=session) + (task_3_ti,) = decision.schedulable_tis + task_3_ti.set_state(State.QUEUED) + + response = client.patch( + f"/execution/task-instances/{task_3_ti.id}/run", + json={ + "state": "running", + "hostname": "random-hostname", + "unixname": "random-unixname", + "pid": 100, + "start_date": "2024-09-30T12:00:00Z", + }, + ) + assert response.json()["upstream_map_indexes"] == {"tg.task_2": [0, 1, 2, 3, 4, 5]} + def test_next_kwargs_still_encoded(self, client, session, create_task_instance, time_machine): instant_str = "2024-09-30T12:00:00Z" instant = timezone.parse(instant_str) diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py index b2e3baaeccf69..cb24a7cc6bdb7 100644 --- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py +++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py @@ -64,13 +64,13 @@ TaskStateChangeCallback, ) from airflow.models.expandinput import ( - ExpandInput, OperatorExpandArgument, OperatorExpandKwargsArgument, ) from airflow.models.xcom_arg import XComArg from airflow.sdk.bases.operator import BaseOperator from airflow.sdk.bases.operatorlink import BaseOperatorLink + from airflow.sdk.definitions._internal.expandinput import ExpandInput from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.param import ParamsDict from airflow.sdk.types import Operator diff --git 
a/task-sdk/src/airflow/sdk/definitions/taskgroup.py b/task-sdk/src/airflow/sdk/definitions/taskgroup.py index 3363424dee68e..03cc4bbad8d14 100644 --- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py +++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py @@ -40,7 +40,7 @@ from airflow.utils.trigger_rule import TriggerRule if TYPE_CHECKING: - from airflow.models.expandinput import ExpandInput + from airflow.models.expandinput import SchedulerExpandInput from airflow.sdk.bases.operator import BaseOperator from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator from airflow.sdk.definitions._internal.mixins import DependencyMixin @@ -613,7 +613,7 @@ class MappedTaskGroup(TaskGroup): a ``@task_group`` function instead. """ - def __init__(self, *, expand_input: ExpandInput, **kwargs: Any) -> None: + def __init__(self, *, expand_input: SchedulerExpandInput, **kwargs: Any) -> None: super().__init__(**kwargs) self._expand_input = expand_input