Skip to content

Commit b2a9f24

Browse files
committed
Fix test mock using relevant interface
1 parent 5df6a79 commit b2a9f24

File tree

6 files changed

+241
-21
lines changed

6 files changed

+241
-21
lines changed

airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -142,20 +142,20 @@ def get_xcom(
142142
params: Annotated[GetXcomFilterParams, Query()],
143143
) -> XComResponse:
144144
"""Get an Airflow XCom from database - not other XCom Backends."""
145-
# The xcom_query allows no map_index to be passed. This endpoint should always return just a single item,
146-
# so we override that query value
147145
xcom_query = XComModel.get_many(
148146
run_id=run_id,
149147
key=key,
150148
task_ids=task_id,
151149
dag_ids=dag_id,
152-
map_indexes=None if params.offset is not None else params.map_index,
153150
include_prior_dates=params.include_prior_dates,
154-
latest_first=(params.offset is None or params.offset >= 0),
155151
session=session,
156152
)
157153
if params.offset is not None:
158-
xcom_query = xcom_query.offset(abs(params.offset))
154+
xcom_query = xcom_query.filter(XComModel.value.is_not(None)).order_by(None)
155+
if params.offset >= 0:
156+
xcom_query = xcom_query.order_by(XComModel.map_index.asc()).offset(params.offset)
157+
else:
158+
xcom_query = xcom_query.order_by(XComModel.map_index.desc()).offset(-1 - params.offset)
159159
else:
160160
xcom_query = xcom_query.filter(XComModel.map_index == params.map_index)
161161

@@ -166,13 +166,19 @@ def get_xcom(
166166
# performance hits from retrieving large data files into the API server.
167167
result = xcom_query.limit(1).first()
168168
if result is None:
169-
map_index = params.map_index
169+
if params.offset is None:
170+
message = (
171+
f"XCom with {key=} map_index={params.map_index} not found for "
172+
f"task {task_id!r} in DAG run {run_id!r} of {dag_id!r}"
173+
)
174+
else:
175+
message = (
176+
f"XCom with {key=} offset={params.offset} not found for "
177+
f"task {task_id!r} in DAG run {run_id!r} of {dag_id!r}"
178+
)
170179
raise HTTPException(
171180
status_code=status.HTTP_404_NOT_FOUND,
172-
detail={
173-
"reason": "not_found",
174-
"message": f"XCom with {key=} {map_index=} not found for task {task_id!r} in DAG run {run_id!r} of {dag_id!r}",
175-
},
181+
detail={"reason": "not_found", "message": message},
176182
)
177183

178184
return XComResponse(key=key, value=result.value)

airflow-core/src/airflow/models/xcom.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,6 @@ def get_many(
252252
map_indexes: int | Iterable[int] | None = None,
253253
include_prior_dates: bool = False,
254254
limit: int | None = None,
255-
latest_first: bool = True,
256255
session: Session = NEW_SESSION,
257256
) -> Query:
258257
"""
@@ -275,9 +274,7 @@ def get_many(
275274
returned regardless of the run it belongs to.
276275
:param session: Database session. If not given, a new session will be
277276
created for this function.
278-
:param limit: Limiting returning XComs.
279-
:param latest_first: If *True* (default), returning XComs are ordered
280-
latest-first. Otherwise earlier XComs are returned first.
277+
:param limit: Limiting returning XComs
281278
"""
282279
from airflow.models.dagrun import DagRun
283280

@@ -321,10 +318,7 @@ def get_many(
321318
else:
322319
query = query.filter(cls.run_id == run_id)
323320

324-
if latest_first:
325-
query = query.order_by(DagRun.logical_date.desc(), cls.timestamp.desc())
326-
else:
327-
query = query.order_by(DagRun.logical_date.asc(), cls.timestamp.asc())
321+
query = query.order_by(DagRun.logical_date.desc(), cls.timestamp.desc())
328322
if limit:
329323
return query.limit(limit)
330324
return query

airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from airflow.models.dagrun import DagRun
3030
from airflow.models.taskmap import TaskMap
3131
from airflow.models.xcom import XComModel
32+
from airflow.providers.standard.operators.empty import EmptyOperator
3233
from airflow.serialization.serde import deserialize, serialize
3334
from airflow.utils.session import create_session
3435

@@ -130,6 +131,86 @@ def test_xcom_access_denied(self, client, caplog):
130131
}
131132
assert any(msg.startswith("Checking read XCom access") for msg in caplog.messages)
132133

134+
@pytest.mark.parametrize(
135+
"offset, expected_status, expected_json",
136+
[
137+
pytest.param(
138+
-4,
139+
404,
140+
{
141+
"detail": {
142+
"reason": "not_found",
143+
"message": (
144+
"XCom with key='xcom_1' offset=-4 not found "
145+
"for task 'task' in DAG run 'runid' of 'dag'"
146+
),
147+
},
148+
},
149+
id="-4",
150+
),
151+
pytest.param(-3, 200, {"key": "xcom_1", "value": "f"}, id="-3"),
152+
pytest.param(-2, 200, {"key": "xcom_1", "value": "o"}, id="-2"),
153+
pytest.param(-1, 200, {"key": "xcom_1", "value": "b"}, id="-1"),
154+
pytest.param(0, 200, {"key": "xcom_1", "value": "f"}, id="0"),
155+
pytest.param(1, 200, {"key": "xcom_1", "value": "o"}, id="1"),
156+
pytest.param(2, 200, {"key": "xcom_1", "value": "b"}, id="2"),
157+
pytest.param(
158+
3,
159+
404,
160+
{
161+
"detail": {
162+
"reason": "not_found",
163+
"message": (
164+
"XCom with key='xcom_1' offset=3 not found "
165+
"for task 'task' in DAG run 'runid' of 'dag'"
166+
),
167+
},
168+
},
169+
id="3",
170+
),
171+
],
172+
)
173+
def test_xcom_get_with_offset(
174+
self,
175+
client,
176+
dag_maker,
177+
session,
178+
offset,
179+
expected_status,
180+
expected_json,
181+
):
182+
xcom_values = ["f", None, "o", "b"]
183+
184+
class MyOperator(EmptyOperator):
185+
def __init__(self, *, x, **kwargs):
186+
super().__init__(**kwargs)
187+
self.x = x
188+
189+
with dag_maker(dag_id="dag"):
190+
MyOperator.partial(task_id="task").expand(x=xcom_values)
191+
dag_run = dag_maker.create_dagrun(run_id="runid")
192+
tis = {ti.map_index: ti for ti in dag_run.task_instances}
193+
194+
for map_index, db_value in enumerate(xcom_values):
195+
if db_value is None: # We don't put None to XCom.
196+
continue
197+
ti = tis[map_index]
198+
x = XComModel(
199+
key="xcom_1",
200+
value=db_value,
201+
dag_run_id=ti.dag_run.id,
202+
run_id=ti.run_id,
203+
task_id=ti.task_id,
204+
dag_id=ti.dag_id,
205+
map_index=map_index,
206+
)
207+
session.add(x)
208+
session.commit()
209+
210+
response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1?offset={offset}")
211+
assert response.status_code == expected_status
212+
assert response.json() == expected_json
213+
133214

134215
class TestXComsSetEndpoint:
135216
@pytest.mark.parametrize(

airflow-core/tests/unit/models/test_mappedoperator.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from airflow.models.taskmap import TaskMap
3333
from airflow.providers.standard.operators.python import PythonOperator
3434
from airflow.sdk import setup, task, task_group, teardown
35-
from airflow.sdk.execution_time.comms import XComCountResponse
35+
from airflow.sdk.execution_time.comms import XComCountResponse, XComResult
3636
from airflow.utils.state import TaskInstanceState
3737
from airflow.utils.task_group import TaskGroup
3838
from airflow.utils.trigger_rule import TriggerRule
@@ -1270,8 +1270,16 @@ def my_teardown(val):
12701270
) as supervisor_comms:
12711271
# TODO: TaskSDK: this is a bit of a hack that we need to stub this at all. `dag.test()` should
12721272
# really work without this!
1273-
supervisor_comms.get_message.return_value = XComCountResponse(len=3)
1273+
supervisor_comms.get_message.side_effect = [
1274+
XComCountResponse(len=3),
1275+
XComResult(key="return_value", value=1),
1276+
XComCountResponse(len=3),
1277+
XComResult(key="return_value", value=2),
1278+
XComCountResponse(len=3),
1279+
XComResult(key="return_value", value=3),
1280+
]
12741281
dr = dag.test()
1282+
assert supervisor_comms.get_message.call_count == 6
12751283
states = self.get_states(dr)
12761284
expected = {
12771285
"tg_1.my_pre_setup": "success",

task-sdk/src/airflow/sdk/definitions/xcom_arg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def resolve(self, context: Mapping[str, Any]) -> Any:
337337
task_id = self.operator.task_id
338338

339339
if self.operator.is_mapped:
340-
return LazyXComSequence[Any](xcom_arg=self, ti=ti)
340+
return LazyXComSequence(xcom_arg=self, ti=ti)
341341
tg = self.operator.get_closest_mapped_task_group()
342342
result = None
343343
if tg is None:
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import annotations
19+
20+
from unittest.mock import ANY, Mock, call
21+
22+
import pytest
23+
24+
from airflow.sdk.exceptions import ErrorType
25+
from airflow.sdk.execution_time.comms import (
26+
ErrorResponse,
27+
GetXComCount,
28+
GetXComSequenceItem,
29+
XComCountResponse,
30+
XComResult,
31+
)
32+
from airflow.sdk.execution_time.lazy_sequence import LazyXComSequence
33+
34+
35+
@pytest.fixture
36+
def mock_operator():
37+
return Mock(spec=["dag_id", "task_id"], dag_id="dag", task_id="task")
38+
39+
40+
@pytest.fixture
41+
def mock_xcom_arg(mock_operator):
42+
return Mock(spec=["operator", "key"], operator=mock_operator, key="return_value")
43+
44+
45+
@pytest.fixture
46+
def mock_ti():
47+
return Mock(spec=["run_id"], run_id="run")
48+
49+
50+
@pytest.fixture
51+
def lazy_sequence(mock_xcom_arg, mock_ti):
52+
return LazyXComSequence(mock_xcom_arg, mock_ti)
53+
54+
55+
def test_len(mock_supervisor_comms, lazy_sequence):
56+
mock_supervisor_comms.get_message.return_value = XComCountResponse(len=3)
57+
assert len(lazy_sequence) == 3
58+
assert mock_supervisor_comms.send_request.mock_calls == [
59+
call(log=ANY, msg=GetXComCount(key="return_value", dag_id="dag", task_id="task", run_id="run")),
60+
]
61+
62+
63+
def test_iter(mock_supervisor_comms, lazy_sequence):
64+
it = iter(lazy_sequence)
65+
66+
mock_supervisor_comms.get_message.side_effect = [
67+
XComResult(key="return_value", value="f"),
68+
ErrorResponse(error=ErrorType.XCOM_NOT_FOUND, detail={"oops": "sorry!"}),
69+
]
70+
assert list(it) == ["f"]
71+
assert mock_supervisor_comms.send_request.mock_calls == [
72+
call(
73+
log=ANY,
74+
msg=GetXComSequenceItem(
75+
key="return_value",
76+
dag_id="dag",
77+
task_id="task",
78+
run_id="run",
79+
offset=0,
80+
),
81+
),
82+
call(
83+
log=ANY,
84+
msg=GetXComSequenceItem(
85+
key="return_value",
86+
dag_id="dag",
87+
task_id="task",
88+
run_id="run",
89+
offset=1,
90+
),
91+
),
92+
]
93+
94+
95+
def test_getitem_index(mock_supervisor_comms, lazy_sequence):
96+
mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value="f")
97+
assert lazy_sequence[4] == "f"
98+
assert mock_supervisor_comms.send_request.mock_calls == [
99+
call(
100+
log=ANY,
101+
msg=GetXComSequenceItem(
102+
key="return_value",
103+
dag_id="dag",
104+
task_id="task",
105+
run_id="run",
106+
offset=4,
107+
),
108+
),
109+
]
110+
111+
112+
def test_getitem_indexerror(mock_supervisor_comms, lazy_sequence):
113+
mock_supervisor_comms.get_message.return_value = ErrorResponse(
114+
error=ErrorType.XCOM_NOT_FOUND,
115+
detail={"oops": "sorry!"},
116+
)
117+
with pytest.raises(IndexError) as ctx:
118+
lazy_sequence[4]
119+
assert ctx.value.args == (4,)
120+
assert mock_supervisor_comms.send_request.mock_calls == [
121+
call(
122+
log=ANY,
123+
msg=GetXComSequenceItem(
124+
key="return_value",
125+
dag_id="dag",
126+
task_id="task",
127+
run_id="run",
128+
offset=4,
129+
),
130+
),
131+
]

0 commit comments

Comments
 (0)