Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class GetXcomFilterParams(BaseModel):

map_index: int = -1
include_prior_dates: bool = False
offset: int | None = None


@router.get(
Expand All @@ -141,32 +142,43 @@ def get_xcom(
params: Annotated[GetXcomFilterParams, Query()],
) -> XComResponse:
"""Get an Airflow XCom from database - not other XCom Backends."""
# The xcom_query allows no map_index to be passed. This endpoint should always return just a single item,
# so we override that query value
xcom_query = XComModel.get_many(
run_id=run_id,
key=key,
task_ids=task_id,
dag_ids=dag_id,
map_indexes=params.map_index,
include_prior_dates=params.include_prior_dates,
session=session,
)
xcom_query = xcom_query.filter(XComModel.map_index == params.map_index)
if params.offset is not None:
xcom_query = xcom_query.filter(XComModel.value.is_not(None)).order_by(None)
if params.offset >= 0:
xcom_query = xcom_query.order_by(XComModel.map_index.asc()).offset(params.offset)
else:
xcom_query = xcom_query.order_by(XComModel.map_index.desc()).offset(-1 - params.offset)
else:
xcom_query = xcom_query.filter(XComModel.map_index == params.map_index)

# We use `XComModel.get_many` to fetch XComs directly from the database, bypassing the XCom Backend.
# This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead
# retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one`
# (which automatically deserializes using the backend), we avoid potential
# performance hits from retrieving large data files into the API server.
result = xcom_query.limit(1).first()
if result is None:
map_index = params.map_index
if params.offset is None:
message = (
f"XCom with {key=} map_index={params.map_index} not found for "
f"task {task_id!r} in DAG run {run_id!r} of {dag_id!r}"
)
else:
message = (
f"XCom with {key=} offset={params.offset} not found for "
f"task {task_id!r} in DAG run {run_id!r} of {dag_id!r}"
)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"reason": "not_found",
"message": f"XCom with {key=} {map_index=} not found for task {task_id!r} in DAG run {run_id!r} of {dag_id!r}",
},
detail={"reason": "not_found", "message": message},
)

return XComResponse(key=key, value=result.value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from airflow.models.dagrun import DagRun
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import XComModel
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.serialization.serde import deserialize, serialize
from airflow.utils.session import create_session

Expand Down Expand Up @@ -130,6 +131,86 @@ def test_xcom_access_denied(self, client, caplog):
}
assert any(msg.startswith("Checking read XCom access") for msg in caplog.messages)

# Parametrized check of the `offset` query parameter on the XCom GET endpoint.
# Three XCom rows exist for the mapped task (values "f", "o", "b"), so offsets
# 0..2 and -3..-1 are expected to return 200, while 3 and -4 are expected to
# return 404 with an "offset=..." not-found message.
@pytest.mark.parametrize(
"offset, expected_status, expected_json",
[
# One step before the first row when counting from the end -> not found.
pytest.param(
-4,
404,
{
"detail": {
"reason": "not_found",
"message": (
"XCom with key='xcom_1' offset=-4 not found "
"for task 'task' in DAG run 'runid' of 'dag'"
),
},
},
id="-4",
),
# Negative offsets: -1 is expected to be the last stored value, -3 the first.
pytest.param(-3, 200, {"key": "xcom_1", "value": "f"}, id="-3"),
pytest.param(-2, 200, {"key": "xcom_1", "value": "o"}, id="-2"),
pytest.param(-1, 200, {"key": "xcom_1", "value": "b"}, id="-1"),
# Non-negative offsets: 0 is expected to be the first stored value.
pytest.param(0, 200, {"key": "xcom_1", "value": "f"}, id="0"),
pytest.param(1, 200, {"key": "xcom_1", "value": "o"}, id="1"),
pytest.param(2, 200, {"key": "xcom_1", "value": "b"}, id="2"),
# One step past the last row when counting from the start -> not found.
pytest.param(
3,
404,
{
"detail": {
"reason": "not_found",
"message": (
"XCom with key='xcom_1' offset=3 not found "
"for task 'task' in DAG run 'runid' of 'dag'"
),
},
},
id="3",
),
],
)
def test_xcom_get_with_offset(
self,
client,
dag_maker,
session,
offset,
expected_status,
expected_json,
):
# Four mapped inputs, but the None entry gets no XCom row (see the loop
# below), leaving rows at map_index 0, 2 and 3. The expected values above
# therefore index the stored rows, not the map_index values.
xcom_values = ["f", None, "o", "b"]

# Minimal operator that accepts an `x` argument so the task can be expanded over it.
class MyOperator(EmptyOperator):
def __init__(self, *, x, **kwargs):
super().__init__(**kwargs)
self.x = x

with dag_maker(dag_id="dag"):
MyOperator.partial(task_id="task").expand(x=xcom_values)
dag_run = dag_maker.create_dagrun(run_id="runid")
# Index the run's task instances by map_index for the lookup in the loop.
tis = {ti.map_index: ti for ti in dag_run.task_instances}

for map_index, db_value in enumerate(xcom_values):
if db_value is None: # We don't put None to XCom.
continue
ti = tis[map_index]
x = XComModel(
key="xcom_1",
value=db_value,
dag_run_id=ti.dag_run.id,
run_id=ti.run_id,
task_id=ti.task_id,
dag_id=ti.dag_id,
map_index=map_index,
)
session.add(x)
# NOTE(review): indentation is lost in this view; the commit presumably runs
# once after the loop rather than per row - confirm against the real file.
session.commit()

# Only `offset` is sent; no map_index is passed in the query string.
response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1?offset={offset}")
assert response.status_code == expected_status
assert response.json() == expected_json


class TestXComsSetEndpoint:
@pytest.mark.parametrize(
Expand Down
12 changes: 10 additions & 2 deletions airflow-core/tests/unit/models/test_mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from airflow.models.taskmap import TaskMap
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk import setup, task, task_group, teardown
from airflow.sdk.execution_time.comms import XComCountResponse
from airflow.sdk.execution_time.comms import XComCountResponse, XComResult
from airflow.utils.state import TaskInstanceState
from airflow.utils.task_group import TaskGroup
from airflow.utils.trigger_rule import TriggerRule
Expand Down Expand Up @@ -1270,8 +1270,16 @@ def my_teardown(val):
) as supervisor_comms:
# TODO: TaskSDK: this is a bit of a hack that we need to stub this at all. `dag.test()` should
# really work without this!
supervisor_comms.get_message.return_value = XComCountResponse(len=3)
supervisor_comms.get_message.side_effect = [
XComCountResponse(len=3),
XComResult(key="return_value", value=1),
XComCountResponse(len=3),
XComResult(key="return_value", value=2),
XComCountResponse(len=3),
XComResult(key="return_value", value=3),
]
dr = dag.test()
assert supervisor_comms.get_message.call_count == 6
states = self.get_states(dr)
expected = {
"tg_1.my_pre_setup": "success",
Expand Down
36 changes: 36 additions & 0 deletions task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,42 @@ def delete(
# decouple from the server response string
return OKResponse(ok=True)

def get_sequence_item(
    self,
    dag_id: str,
    run_id: str,
    task_id: str,
    key: str,
    offset: int,
) -> XComResponse | ErrorResponse:
    """
    Fetch a single XCom by its position in the task's XCom sequence.

    Issues ``GET xcoms/{dag_id}/{run_id}/{task_id}/{key}`` with ``offset``
    as the only query parameter; the value is passed through unchanged.

    :param dag_id: DAG that owns the producing task.
    :param run_id: DAG run to look in.
    :param task_id: Task that pushed the XCom.
    :param key: XCom key.
    :param offset: Position of the wanted item, sent as the ``offset`` query parameter.
    :return: The parsed ``XComResponse``, or an ``ErrorResponse`` with
        ``ErrorType.XCOM_NOT_FOUND`` when the server answers 404.
    :raises ServerResponseError: For any server error other than 404.
    """
    # Identifying fields shared by the log record and the error payload.
    lookup = {
        "dag_id": dag_id,
        "run_id": run_id,
        "task_id": task_id,
        "key": key,
        "offset": offset,
    }
    try:
        resp = self.client.get(
            f"xcoms/{dag_id}/{run_id}/{task_id}/{key}",
            params={"offset": offset},
        )
    except ServerResponseError as err:
        status_code = err.response.status_code
        if status_code != HTTPStatus.NOT_FOUND:
            # Anything other than "not found" is unexpected; let the caller see it.
            raise
        log.error("XCom not found", **lookup, detail=err.detail, status_code=status_code)
        return ErrorResponse(error=ErrorType.XCOM_NOT_FOUND, detail=lookup)
    return XComResponse.model_validate_json(resp.read())


class AssetOperations:
__slots__ = ("client",)
Expand Down
2 changes: 1 addition & 1 deletion task-sdk/src/airflow/sdk/definitions/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def resolve(self, context: Mapping[str, Any]) -> Any:
task_id = self.operator.task_id

if self.operator.is_mapped:
return LazyXComSequence[Any](xcom_arg=self, ti=ti)
return LazyXComSequence(xcom_arg=self, ti=ti)
tg = self.operator.get_closest_mapped_task_group()
result = None
if tg is None:
Expand Down
10 changes: 10 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,15 @@ class GetXComCount(BaseModel):
type: Literal["GetNumberXComs"] = "GetNumberXComs"


# Request message asking for one XCom item selected by `offset`.
# `key`, `dag_id`, `run_id` and `task_id` identify the XCom; `offset` is the
# position of the wanted item. `type` is the fixed literal tag carried by
# this message kind.
class GetXComSequenceItem(BaseModel):
key: str
dag_id: str
run_id: str
task_id: str
offset: int
type: Literal["GetXComSequenceItem"] = "GetXComSequenceItem"


class SetXCom(BaseModel):
key: str
value: Annotated[
Expand Down Expand Up @@ -605,6 +614,7 @@ class GetDRCount(BaseModel):
GetVariable,
GetXCom,
GetXComCount,
GetXComSequenceItem,
PutVariable,
RescheduleTask,
RetryTask,
Expand Down
103 changes: 55 additions & 48 deletions task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ def __next__(self) -> T:
if self.index < 0:
# When iterating backwards, avoid extra HTTP request
raise StopIteration()
val = self.seq._get_item(self.index)
if val is None:
# None isn't the best signal (it's bad in fact) but it's the best we can do until https://github.com/apache/airflow/issues/46426
raise StopIteration()
try:
val = self.seq[self.index]
except IndexError:
raise StopIteration from None
self.index += self.dir
return val

Expand Down Expand Up @@ -109,52 +109,59 @@ def __getitem__(self, key: int) -> T: ...
def __getitem__(self, key: slice) -> Sequence[T]: ...

def __getitem__(self, key: int | slice) -> T | Sequence[T]:
if isinstance(key, int):
if key >= 0:
return self._get_item(key)
# val[-1] etc.
return self._get_item(len(self) + key)
if not isinstance(key, (int, slice)):
raise TypeError(f"Sequence indices must be integers or slices, not {type(key).__name__}")

if isinstance(key, slice):
# This implements the slicing syntax. We want to optimize negative slicing (e.g. seq[-10:]) by not
# doing an additional COUNT query (via HEAD http request) if possible. We can do this unless the
# start and stop have different signs (i.e. one is positive and another negative).
...
"""
Todo?
elif isinstance(key, slice):
start, stop, reverse = _coerce_slice(key)
if start >= 0:
if stop is None:
stmt = self._select_asc.offset(start)
elif stop >= 0:
stmt = self._select_asc.slice(start, stop)
else:
stmt = self._select_asc.slice(start, len(self) + stop)
rows = [self._process_row(row) for row in self._session.execute(stmt)]
if reverse:
rows.reverse()
else:
if stop is None:
stmt = self._select_desc.limit(-start)
elif stop < 0:
stmt = self._select_desc.slice(-stop, -start)
else:
stmt = self._select_desc.slice(len(self) - stop, -start)
rows = [self._process_row(row) for row in self._session.execute(stmt)]
if not reverse:
rows.reverse()
return rows
"""
raise TypeError(f"Sequence indices must be integers or slices, not {type(key).__name__}")

def _get_item(self, index: int) -> T:
# TODO: maybe we need to call SUPERVISOR_COMMS manually so we can handle not found here?
return self._ti.xcom_pull(
task_ids=self._xcom_arg.operator.task_id,
key=self._xcom_arg.key,
map_indexes=index,
)
raise TypeError("slice is not implemented yet")
# TODO...
# This implements the slicing syntax. We want to optimize negative slicing (e.g. seq[-10:]) by not
# doing an additional COUNT query (via HEAD http request) if possible. We can do this unless the
# start and stop have different signs (i.e. one is positive and another negative).
# start, stop, reverse = _coerce_slice(key)
# if start >= 0:
# if stop is None:
# stmt = self._select_asc.offset(start)
# elif stop >= 0:
# stmt = self._select_asc.slice(start, stop)
# else:
# stmt = self._select_asc.slice(start, len(self) + stop)
# rows = [self._process_row(row) for row in self._session.execute(stmt)]
# if reverse:
# rows.reverse()
# else:
# if stop is None:
# stmt = self._select_desc.limit(-start)
# elif stop < 0:
# stmt = self._select_desc.slice(-stop, -start)
# else:
# stmt = self._select_desc.slice(len(self) - stop, -start)
# rows = [self._process_row(row) for row in self._session.execute(stmt)]
# if not reverse:
# rows.reverse()
# return rows

from airflow.sdk.bases.xcom import BaseXCom
from airflow.sdk.execution_time.comms import GetXComSequenceItem, XComResult
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

with SUPERVISOR_COMMS.lock:
source = (xcom_arg := self._xcom_arg).operator
SUPERVISOR_COMMS.send_request(
log=log,
msg=GetXComSequenceItem(
key=xcom_arg.key,
dag_id=source.dag_id,
task_id=source.task_id,
run_id=self._ti.run_id,
offset=key,
),
)
msg = SUPERVISOR_COMMS.get_message()

if not isinstance(msg, XComResult):
raise IndexError(key)
return BaseXCom.deserialize_value(msg)


def _coerce_index(value: Any) -> int | None:
Expand Down
Loading