Skip to content

Commit 19c840a

Browse files
authored
Merge branch 'main' into issue_50330
2 parents e96c9b1 + 1ab2474 commit 19c840a

File tree

3 files changed

+105
-8
lines changed

3 files changed

+105
-8
lines changed

providers/amazon/tests/system/amazon/aws/utils/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,13 +113,13 @@ def _fetch_from_ssm(key: str, test_name: str | None = None) -> str:
113113
log.info("No boto credentials found: %s", e)
114114
except ClientError as e:
115115
log.info("Client error when connecting to SSM: %s", e)
116-
except hook.conn.exceptions.ParameterNotFound as e:
117-
log.info("SSM does not contain any parameter for this test: %s", e)
118116
except KeyError as e:
119117
log.info(
120118
"SSM contains one parameter for this test, but not the requested value: %s",
121119
e,
122120
)
121+
except Exception as e:
122+
log.info("SSM does not contain any parameter for this test: %s", e)
123123
return value
124124

125125

task-sdk/src/airflow/sdk/execution_time/supervisor.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1320,15 +1320,22 @@ def set_supervisor_comms(temp_comms):
13201320
"""
13211321
from airflow.sdk.execution_time import task_runner
13221322

1323-
old = getattr(task_runner, "SUPERVISOR_COMMS", None)
1324-
task_runner.SUPERVISOR_COMMS = temp_comms
1323+
sentinel = object()
1324+
old = getattr(task_runner, "SUPERVISOR_COMMS", sentinel)
1325+
1326+
if temp_comms is not None:
1327+
task_runner.SUPERVISOR_COMMS = temp_comms
1328+
elif old is not sentinel:
1329+
delattr(task_runner, "SUPERVISOR_COMMS")
1330+
13251331
try:
13261332
yield
13271333
finally:
1328-
if old is not None:
1329-
task_runner.SUPERVISOR_COMMS = old
1334+
if old is sentinel:
1335+
if hasattr(task_runner, "SUPERVISOR_COMMS"):
1336+
delattr(task_runner, "SUPERVISOR_COMMS")
13301337
else:
1331-
delattr(task_runner, "SUPERVISOR_COMMS")
1338+
task_runner.SUPERVISOR_COMMS = old
13321339

13331340

13341341
def run_task_in_process(ti: TaskInstance, task) -> TaskRunResult:

task-sdk/tests/task_sdk/execution_time/test_supervisor.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
TaskInstanceState,
5252
)
5353
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
54+
from airflow.sdk.execution_time import task_runner
5455
from airflow.sdk.execution_time.comms import (
5556
AssetEventsResult,
5657
AssetResult,
@@ -92,7 +93,15 @@
9293
XComResult,
9394
)
9495
from airflow.sdk.execution_time.secrets_masker import SecretsMasker
95-
from airflow.sdk.execution_time.supervisor import BUFFER_SIZE, ActivitySubprocess, mkpipe, supervise
96+
from airflow.sdk.execution_time.supervisor import (
97+
BUFFER_SIZE,
98+
ActivitySubprocess,
99+
InProcessSupervisorComms,
100+
InProcessTestSupervisor,
101+
mkpipe,
102+
set_supervisor_comms,
103+
supervise,
104+
)
96105
from airflow.sdk.execution_time.task_runner import CommsDecoder
97106
from airflow.utils import timezone, timezone as tz
98107

@@ -1600,3 +1609,84 @@ def test_handle_requests_api_server_error(self, watched_subprocess, mocker):
16001609
"message": str(error),
16011610
"detail": error.response.json(),
16021611
}
1612+
1613+
1614+
class TestSetSupervisorComms:
1615+
class DummyComms:
1616+
pass
1617+
1618+
@pytest.fixture(autouse=True)
1619+
def cleanup_supervisor_comms(self):
1620+
# Ensure clean state before/after test
1621+
if hasattr(task_runner, "SUPERVISOR_COMMS"):
1622+
delattr(task_runner, "SUPERVISOR_COMMS")
1623+
yield
1624+
if hasattr(task_runner, "SUPERVISOR_COMMS"):
1625+
delattr(task_runner, "SUPERVISOR_COMMS")
1626+
1627+
def test_set_supervisor_comms_overrides_and_restores(self):
1628+
task_runner.SUPERVISOR_COMMS = self.DummyComms()
1629+
original = task_runner.SUPERVISOR_COMMS
1630+
replacement = self.DummyComms()
1631+
1632+
with set_supervisor_comms(replacement):
1633+
assert task_runner.SUPERVISOR_COMMS is replacement
1634+
assert task_runner.SUPERVISOR_COMMS is original
1635+
1636+
def test_set_supervisor_comms_sets_temporarily_when_not_set(self):
1637+
assert not hasattr(task_runner, "SUPERVISOR_COMMS")
1638+
replacement = self.DummyComms()
1639+
1640+
with set_supervisor_comms(replacement):
1641+
assert task_runner.SUPERVISOR_COMMS is replacement
1642+
assert not hasattr(task_runner, "SUPERVISOR_COMMS")
1643+
1644+
def test_set_supervisor_comms_unsets_temporarily_when_not_set(self):
1645+
assert not hasattr(task_runner, "SUPERVISOR_COMMS")
1646+
1647+
# This will delete an attribute that isn't set, and restore it likewise
1648+
with set_supervisor_comms(None):
1649+
assert not hasattr(task_runner, "SUPERVISOR_COMMS")
1650+
1651+
assert not hasattr(task_runner, "SUPERVISOR_COMMS")
1652+
1653+
1654+
class TestInProcessTestSupervisor:
1655+
def test_inprocess_supervisor_comms_roundtrip(self):
1656+
"""
1657+
Test that InProcessSupervisorComms correctly sends a message to the supervisor,
1658+
and that the supervisor's response is received via the message queue.
1659+
1660+
This verifies the end-to-end communication flow:
1661+
- send_request() dispatches a message to the supervisor
1662+
- the supervisor handles the request and appends a response via send_msg()
1663+
- get_message() returns the enqueued response
1664+
1665+
This test mocks the supervisor's `_handle_request()` method to simulate
1666+
a simple echo-style response, avoiding full task execution.
1667+
"""
1668+
1669+
class MinimalSupervisor(InProcessTestSupervisor):
1670+
def _handle_request(self, msg, log):
1671+
resp = VariableResult(key=msg.key, value="value")
1672+
self.send_msg(resp)
1673+
1674+
supervisor = MinimalSupervisor(
1675+
id="test",
1676+
pid=123,
1677+
requests_fd=-1,
1678+
process=MagicMock(),
1679+
process_log=MagicMock(),
1680+
client=MagicMock(),
1681+
)
1682+
comms = InProcessSupervisorComms(supervisor=supervisor)
1683+
supervisor.comms = comms
1684+
1685+
test_msg = GetVariable(key="test_key")
1686+
1687+
comms.send_request(log=MagicMock(), msg=test_msg)
1688+
1689+
# Ensure we got back what we expect
1690+
response = comms.get_message()
1691+
assert isinstance(response, VariableResult)
1692+
assert response.value == "value"

0 commit comments

Comments
 (0)