|
51 | 51 | TaskInstanceState,
|
52 | 52 | )
|
53 | 53 | from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
|
| 54 | +from airflow.sdk.execution_time import task_runner |
54 | 55 | from airflow.sdk.execution_time.comms import (
|
55 | 56 | AssetEventsResult,
|
56 | 57 | AssetResult,
|
|
92 | 93 | XComResult,
|
93 | 94 | )
|
94 | 95 | from airflow.sdk.execution_time.secrets_masker import SecretsMasker
|
95 |
| -from airflow.sdk.execution_time.supervisor import BUFFER_SIZE, ActivitySubprocess, mkpipe, supervise |
| 96 | +from airflow.sdk.execution_time.supervisor import ( |
| 97 | + BUFFER_SIZE, |
| 98 | + ActivitySubprocess, |
| 99 | + InProcessSupervisorComms, |
| 100 | + InProcessTestSupervisor, |
| 101 | + mkpipe, |
| 102 | + set_supervisor_comms, |
| 103 | + supervise, |
| 104 | +) |
96 | 105 | from airflow.sdk.execution_time.task_runner import CommsDecoder
|
97 | 106 | from airflow.utils import timezone, timezone as tz
|
98 | 107 |
|
@@ -1600,3 +1609,84 @@ def test_handle_requests_api_server_error(self, watched_subprocess, mocker):
|
1600 | 1609 | "message": str(error),
|
1601 | 1610 | "detail": error.response.json(),
|
1602 | 1611 | }
|
| 1612 | + |
| 1613 | + |
| 1614 | +class TestSetSupervisorComms: |
| 1615 | + class DummyComms: |
| 1616 | + pass |
| 1617 | + |
| 1618 | + @pytest.fixture(autouse=True) |
| 1619 | + def cleanup_supervisor_comms(self): |
| 1620 | + # Ensure clean state before/after test |
| 1621 | + if hasattr(task_runner, "SUPERVISOR_COMMS"): |
| 1622 | + delattr(task_runner, "SUPERVISOR_COMMS") |
| 1623 | + yield |
| 1624 | + if hasattr(task_runner, "SUPERVISOR_COMMS"): |
| 1625 | + delattr(task_runner, "SUPERVISOR_COMMS") |
| 1626 | + |
| 1627 | + def test_set_supervisor_comms_overrides_and_restores(self): |
| 1628 | + task_runner.SUPERVISOR_COMMS = self.DummyComms() |
| 1629 | + original = task_runner.SUPERVISOR_COMMS |
| 1630 | + replacement = self.DummyComms() |
| 1631 | + |
| 1632 | + with set_supervisor_comms(replacement): |
| 1633 | + assert task_runner.SUPERVISOR_COMMS is replacement |
| 1634 | + assert task_runner.SUPERVISOR_COMMS is original |
| 1635 | + |
| 1636 | + def test_set_supervisor_comms_sets_temporarily_when_not_set(self): |
| 1637 | + assert not hasattr(task_runner, "SUPERVISOR_COMMS") |
| 1638 | + replacement = self.DummyComms() |
| 1639 | + |
| 1640 | + with set_supervisor_comms(replacement): |
| 1641 | + assert task_runner.SUPERVISOR_COMMS is replacement |
| 1642 | + assert not hasattr(task_runner, "SUPERVISOR_COMMS") |
| 1643 | + |
| 1644 | + def test_set_supervisor_comms_unsets_temporarily_when_not_set(self): |
| 1645 | + assert not hasattr(task_runner, "SUPERVISOR_COMMS") |
| 1646 | + |
| 1647 | + # This will delete an attribute that isn't set, and restore it likewise |
| 1648 | + with set_supervisor_comms(None): |
| 1649 | + assert not hasattr(task_runner, "SUPERVISOR_COMMS") |
| 1650 | + |
| 1651 | + assert not hasattr(task_runner, "SUPERVISOR_COMMS") |
| 1652 | + |
| 1653 | + |
| 1654 | +class TestInProcessTestSupervisor: |
| 1655 | + def test_inprocess_supervisor_comms_roundtrip(self): |
| 1656 | + """ |
| 1657 | + Test that InProcessSupervisorComms correctly sends a message to the supervisor, |
| 1658 | + and that the supervisor's response is received via the message queue. |
| 1659 | +
|
| 1660 | + This verifies the end-to-end communication flow: |
| 1661 | + - send_request() dispatches a message to the supervisor |
| 1662 | + - the supervisor handles the request and appends a response via send_msg() |
| 1663 | + - get_message() returns the enqueued response |
| 1664 | +
|
| 1665 | + This test mocks the supervisor's `_handle_request()` method to simulate |
| 1666 | + a simple echo-style response, avoiding full task execution. |
| 1667 | + """ |
| 1668 | + |
| 1669 | + class MinimalSupervisor(InProcessTestSupervisor): |
| 1670 | + def _handle_request(self, msg, log): |
| 1671 | + resp = VariableResult(key=msg.key, value="value") |
| 1672 | + self.send_msg(resp) |
| 1673 | + |
| 1674 | + supervisor = MinimalSupervisor( |
| 1675 | + id="test", |
| 1676 | + pid=123, |
| 1677 | + requests_fd=-1, |
| 1678 | + process=MagicMock(), |
| 1679 | + process_log=MagicMock(), |
| 1680 | + client=MagicMock(), |
| 1681 | + ) |
| 1682 | + comms = InProcessSupervisorComms(supervisor=supervisor) |
| 1683 | + supervisor.comms = comms |
| 1684 | + |
| 1685 | + test_msg = GetVariable(key="test_key") |
| 1686 | + |
| 1687 | + comms.send_request(log=MagicMock(), msg=test_msg) |
| 1688 | + |
| 1689 | + # Ensure we got back what we expect |
| 1690 | + response = comms.get_message() |
| 1691 | + assert isinstance(response, VariableResult) |
| 1692 | + assert response.value == "value" |
0 commit comments