Skip to content

Commit 9a18b9f

Browse files
github-actions[bot]uranusjrtirkarthi
authored
[v3-0-test] Implement offset to get the xcom for a given task by offset. (#50011) (#50048)
(cherry picked from commit 58c736a) Co-authored-by: Tzu-ping Chung <[email protected]> Co-authored-by: Karthikeyan Singaravelan <[email protected]>
1 parent 9ff257f commit 9a18b9f

File tree

10 files changed

+386
-60
lines changed

10 files changed

+386
-60
lines changed

airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ class GetXcomFilterParams(BaseModel):
126126

127127
map_index: int = -1
128128
include_prior_dates: bool = False
129+
offset: int | None = None
129130

130131

131132
@router.get(
@@ -141,32 +142,43 @@ def get_xcom(
141142
params: Annotated[GetXcomFilterParams, Query()],
142143
) -> XComResponse:
143144
"""Get an Airflow XCom from database - not other XCom Backends."""
144-
# The xcom_query allows no map_index to be passed. This endpoint should always return just a single item,
145-
# so we override that query value
146145
xcom_query = XComModel.get_many(
147146
run_id=run_id,
148147
key=key,
149148
task_ids=task_id,
150149
dag_ids=dag_id,
151-
map_indexes=params.map_index,
152150
include_prior_dates=params.include_prior_dates,
153151
session=session,
154152
)
155-
xcom_query = xcom_query.filter(XComModel.map_index == params.map_index)
153+
if params.offset is not None:
154+
xcom_query = xcom_query.filter(XComModel.value.is_not(None)).order_by(None)
155+
if params.offset >= 0:
156+
xcom_query = xcom_query.order_by(XComModel.map_index.asc()).offset(params.offset)
157+
else:
158+
xcom_query = xcom_query.order_by(XComModel.map_index.desc()).offset(-1 - params.offset)
159+
else:
160+
xcom_query = xcom_query.filter(XComModel.map_index == params.map_index)
161+
156162
# We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend.
157163
# This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead
158164
# retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one`
159165
# (which automatically deserializes using the backend), we avoid potential
160166
# performance hits from retrieving large data files into the API server.
161167
result = xcom_query.limit(1).first()
162168
if result is None:
163-
map_index = params.map_index
169+
if params.offset is None:
170+
message = (
171+
f"XCom with {key=} map_index={params.map_index} not found for "
172+
f"task {task_id!r} in DAG run {run_id!r} of {dag_id!r}"
173+
)
174+
else:
175+
message = (
176+
f"XCom with {key=} offset={params.offset} not found for "
177+
f"task {task_id!r} in DAG run {run_id!r} of {dag_id!r}"
178+
)
164179
raise HTTPException(
165180
status_code=status.HTTP_404_NOT_FOUND,
166-
detail={
167-
"reason": "not_found",
168-
"message": f"XCom with {key=} {map_index=} not found for task {task_id!r} in DAG run {run_id!r} of {dag_id!r}",
169-
},
181+
detail={"reason": "not_found", "message": message},
170182
)
171183

172184
return XComResponse(key=key, value=result.value)

airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from airflow.models.dagrun import DagRun
3030
from airflow.models.taskmap import TaskMap
3131
from airflow.models.xcom import XComModel
32+
from airflow.providers.standard.operators.empty import EmptyOperator
3233
from airflow.serialization.serde import deserialize, serialize
3334
from airflow.utils.session import create_session
3435

@@ -130,6 +131,86 @@ def test_xcom_access_denied(self, client, caplog):
130131
}
131132
assert any(msg.startswith("Checking read XCom access") for msg in caplog.messages)
132133

134+
@pytest.mark.parametrize(
135+
"offset, expected_status, expected_json",
136+
[
137+
pytest.param(
138+
-4,
139+
404,
140+
{
141+
"detail": {
142+
"reason": "not_found",
143+
"message": (
144+
"XCom with key='xcom_1' offset=-4 not found "
145+
"for task 'task' in DAG run 'runid' of 'dag'"
146+
),
147+
},
148+
},
149+
id="-4",
150+
),
151+
pytest.param(-3, 200, {"key": "xcom_1", "value": "f"}, id="-3"),
152+
pytest.param(-2, 200, {"key": "xcom_1", "value": "o"}, id="-2"),
153+
pytest.param(-1, 200, {"key": "xcom_1", "value": "b"}, id="-1"),
154+
pytest.param(0, 200, {"key": "xcom_1", "value": "f"}, id="0"),
155+
pytest.param(1, 200, {"key": "xcom_1", "value": "o"}, id="1"),
156+
pytest.param(2, 200, {"key": "xcom_1", "value": "b"}, id="2"),
157+
pytest.param(
158+
3,
159+
404,
160+
{
161+
"detail": {
162+
"reason": "not_found",
163+
"message": (
164+
"XCom with key='xcom_1' offset=3 not found "
165+
"for task 'task' in DAG run 'runid' of 'dag'"
166+
),
167+
},
168+
},
169+
id="3",
170+
),
171+
],
172+
)
173+
def test_xcom_get_with_offset(
174+
self,
175+
client,
176+
dag_maker,
177+
session,
178+
offset,
179+
expected_status,
180+
expected_json,
181+
):
182+
xcom_values = ["f", None, "o", "b"]
183+
184+
class MyOperator(EmptyOperator):
185+
def __init__(self, *, x, **kwargs):
186+
super().__init__(**kwargs)
187+
self.x = x
188+
189+
with dag_maker(dag_id="dag"):
190+
MyOperator.partial(task_id="task").expand(x=xcom_values)
191+
dag_run = dag_maker.create_dagrun(run_id="runid")
192+
tis = {ti.map_index: ti for ti in dag_run.task_instances}
193+
194+
for map_index, db_value in enumerate(xcom_values):
195+
if db_value is None: # We don't put None to XCom.
196+
continue
197+
ti = tis[map_index]
198+
x = XComModel(
199+
key="xcom_1",
200+
value=db_value,
201+
dag_run_id=ti.dag_run.id,
202+
run_id=ti.run_id,
203+
task_id=ti.task_id,
204+
dag_id=ti.dag_id,
205+
map_index=map_index,
206+
)
207+
session.add(x)
208+
session.commit()
209+
210+
response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1?offset={offset}")
211+
assert response.status_code == expected_status
212+
assert response.json() == expected_json
213+
133214

134215
class TestXComsSetEndpoint:
135216
@pytest.mark.parametrize(

airflow-core/tests/unit/models/test_mappedoperator.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from airflow.models.taskmap import TaskMap
3333
from airflow.providers.standard.operators.python import PythonOperator
3434
from airflow.sdk import setup, task, task_group, teardown
35-
from airflow.sdk.execution_time.comms import XComCountResponse
35+
from airflow.sdk.execution_time.comms import XComCountResponse, XComResult
3636
from airflow.utils.state import TaskInstanceState
3737
from airflow.utils.task_group import TaskGroup
3838
from airflow.utils.trigger_rule import TriggerRule
@@ -1270,8 +1270,16 @@ def my_teardown(val):
12701270
) as supervisor_comms:
12711271
# TODO: TaskSDK: this is a bit of a hack that we need to stub this at all. `dag.test()` should
12721272
# really work without this!
1273-
supervisor_comms.get_message.return_value = XComCountResponse(len=3)
1273+
supervisor_comms.get_message.side_effect = [
1274+
XComCountResponse(len=3),
1275+
XComResult(key="return_value", value=1),
1276+
XComCountResponse(len=3),
1277+
XComResult(key="return_value", value=2),
1278+
XComCountResponse(len=3),
1279+
XComResult(key="return_value", value=3),
1280+
]
12741281
dr = dag.test()
1282+
assert supervisor_comms.get_message.call_count == 6
12751283
states = self.get_states(dr)
12761284
expected = {
12771285
"tg_1.my_pre_setup": "success",

task-sdk/src/airflow/sdk/api/client.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,42 @@ def delete(
429429
# decouple from the server response string
430430
return OKResponse(ok=True)
431431

432+
def get_sequence_item(
433+
self,
434+
dag_id: str,
435+
run_id: str,
436+
task_id: str,
437+
key: str,
438+
offset: int,
439+
) -> XComResponse | ErrorResponse:
440+
params = {"offset": offset}
441+
try:
442+
resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params)
443+
except ServerResponseError as e:
444+
if e.response.status_code == HTTPStatus.NOT_FOUND:
445+
log.error(
446+
"XCom not found",
447+
dag_id=dag_id,
448+
run_id=run_id,
449+
task_id=task_id,
450+
key=key,
451+
offset=offset,
452+
detail=e.detail,
453+
status_code=e.response.status_code,
454+
)
455+
return ErrorResponse(
456+
error=ErrorType.XCOM_NOT_FOUND,
457+
detail={
458+
"dag_id": dag_id,
459+
"run_id": run_id,
460+
"task_id": task_id,
461+
"key": key,
462+
"offset": offset,
463+
},
464+
)
465+
raise
466+
return XComResponse.model_validate_json(resp.read())
467+
432468

433469
class AssetOperations:
434470
__slots__ = ("client",)

task-sdk/src/airflow/sdk/definitions/xcom_arg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def resolve(self, context: Mapping[str, Any]) -> Any:
337337
task_id = self.operator.task_id
338338

339339
if self.operator.is_mapped:
340-
return LazyXComSequence[Any](xcom_arg=self, ti=ti)
340+
return LazyXComSequence(xcom_arg=self, ti=ti)
341341
tg = self.operator.get_closest_mapped_task_group()
342342
result = None
343343
if tg is None:

task-sdk/src/airflow/sdk/execution_time/comms.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,15 @@ class GetXComCount(BaseModel):
441441
type: Literal["GetNumberXComs"] = "GetNumberXComs"
442442

443443

444+
class GetXComSequenceItem(BaseModel):
445+
key: str
446+
dag_id: str
447+
run_id: str
448+
task_id: str
449+
offset: int
450+
type: Literal["GetXComSequenceItem"] = "GetXComSequenceItem"
451+
452+
444453
class SetXCom(BaseModel):
445454
key: str
446455
value: Annotated[
@@ -605,6 +614,7 @@ class GetDRCount(BaseModel):
605614
GetVariable,
606615
GetXCom,
607616
GetXComCount,
617+
GetXComSequenceItem,
608618
PutVariable,
609619
RescheduleTask,
610620
RetryTask,

task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py

Lines changed: 55 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ def __next__(self) -> T:
4343
if self.index < 0:
4444
# When iterating backwards, avoid extra HTTP request
4545
raise StopIteration()
46-
val = self.seq._get_item(self.index)
47-
if val is None:
48-
# None isn't the best signal (it's bad in fact) but it's the best we can do until https://github.com/apache/airflow/issues/46426
49-
raise StopIteration()
46+
try:
47+
val = self.seq[self.index]
48+
except IndexError:
49+
raise StopIteration from None
5050
self.index += self.dir
5151
return val
5252

@@ -109,52 +109,59 @@ def __getitem__(self, key: int) -> T: ...
109109
def __getitem__(self, key: slice) -> Sequence[T]: ...
110110

111111
def __getitem__(self, key: int | slice) -> T | Sequence[T]:
112-
if isinstance(key, int):
113-
if key >= 0:
114-
return self._get_item(key)
115-
# val[-1] etc.
116-
return self._get_item(len(self) + key)
112+
if not isinstance(key, (int, slice)):
113+
raise TypeError(f"Sequence indices must be integers or slices, not {type(key).__name__}")
117114

118115
if isinstance(key, slice):
119-
# This implements the slicing syntax. We want to optimize negative slicing (e.g. seq[-10:]) by not
120-
# doing an additional COUNT query (via HEAD http request) if possible. We can do this unless the
121-
# start and stop have different signs (i.e. one is positive and another negative).
122-
...
123-
"""
124-
Todo?
125-
elif isinstance(key, slice):
126-
start, stop, reverse = _coerce_slice(key)
127-
if start >= 0:
128-
if stop is None:
129-
stmt = self._select_asc.offset(start)
130-
elif stop >= 0:
131-
stmt = self._select_asc.slice(start, stop)
132-
else:
133-
stmt = self._select_asc.slice(start, len(self) + stop)
134-
rows = [self._process_row(row) for row in self._session.execute(stmt)]
135-
if reverse:
136-
rows.reverse()
137-
else:
138-
if stop is None:
139-
stmt = self._select_desc.limit(-start)
140-
elif stop < 0:
141-
stmt = self._select_desc.slice(-stop, -start)
142-
else:
143-
stmt = self._select_desc.slice(len(self) - stop, -start)
144-
rows = [self._process_row(row) for row in self._session.execute(stmt)]
145-
if not reverse:
146-
rows.reverse()
147-
return rows
148-
"""
149-
raise TypeError(f"Sequence indices must be integers or slices, not {type(key).__name__}")
150-
151-
def _get_item(self, index: int) -> T:
152-
# TODO: maybe we need to call SUPERVISOR_COMMS manually so we can handle not found here?
153-
return self._ti.xcom_pull(
154-
task_ids=self._xcom_arg.operator.task_id,
155-
key=self._xcom_arg.key,
156-
map_indexes=index,
157-
)
116+
raise TypeError("slice is not implemented yet")
117+
# TODO...
118+
# This implements the slicing syntax. We want to optimize negative slicing (e.g. seq[-10:]) by not
119+
# doing an additional COUNT query (via HEAD http request) if possible. We can do this unless the
120+
# start and stop have different signs (i.e. one is positive and another negative).
121+
# start, stop, reverse = _coerce_slice(key)
122+
# if start >= 0:
123+
# if stop is None:
124+
# stmt = self._select_asc.offset(start)
125+
# elif stop >= 0:
126+
# stmt = self._select_asc.slice(start, stop)
127+
# else:
128+
# stmt = self._select_asc.slice(start, len(self) + stop)
129+
# rows = [self._process_row(row) for row in self._session.execute(stmt)]
130+
# if reverse:
131+
# rows.reverse()
132+
# else:
133+
# if stop is None:
134+
# stmt = self._select_desc.limit(-start)
135+
# elif stop < 0:
136+
# stmt = self._select_desc.slice(-stop, -start)
137+
# else:
138+
# stmt = self._select_desc.slice(len(self) - stop, -start)
139+
# rows = [self._process_row(row) for row in self._session.execute(stmt)]
140+
# if not reverse:
141+
# rows.reverse()
142+
# return rows
143+
144+
from airflow.sdk.bases.xcom import BaseXCom
145+
from airflow.sdk.execution_time.comms import GetXComSequenceItem, XComResult
146+
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
147+
148+
with SUPERVISOR_COMMS.lock:
149+
source = (xcom_arg := self._xcom_arg).operator
150+
SUPERVISOR_COMMS.send_request(
151+
log=log,
152+
msg=GetXComSequenceItem(
153+
key=xcom_arg.key,
154+
dag_id=source.dag_id,
155+
task_id=source.task_id,
156+
run_id=self._ti.run_id,
157+
offset=key,
158+
),
159+
)
160+
msg = SUPERVISOR_COMMS.get_message()
161+
162+
if not isinstance(msg, XComResult):
163+
raise IndexError(key)
164+
return BaseXCom.deserialize_value(msg)
158165

159166

160167
def _coerce_index(value: Any) -> int | None:

0 commit comments

Comments
 (0)