|
44 | 44 | from uuid6 import uuid7
|
45 | 45 |
|
46 | 46 | from airflow.executors.workloads import BundleInfo
|
47 |
| -from airflow.sdk import timezone |
| 47 | +from airflow.sdk import BaseOperator, timezone |
48 | 48 | from airflow.sdk.api import client as sdk_client
|
49 | 49 | from airflow.sdk.api.client import ServerResponseError
|
50 | 50 | from airflow.sdk.api.datamodels._generated import (
|
|
121 | 121 | set_supervisor_comms,
|
122 | 122 | supervise,
|
123 | 123 | )
|
| 124 | +from airflow.sdk.execution_time.task_runner import run |
124 | 125 |
|
125 | 126 | from tests_common.test_utils.config import conf_vars
|
126 | 127 |
|
@@ -341,6 +342,85 @@ def subprocess_main():
|
341 | 342 | ]
|
342 | 343 | )
|
343 | 344 |
|
| 345 | + def test_on_kill_hook_called_when_sigkilled( |
| 346 | + self, |
| 347 | + client_with_ti_start, |
| 348 | + mocked_parse, |
| 349 | + make_ti_context, |
| 350 | + mock_supervisor_comms, |
| 351 | + create_runtime_ti, |
| 352 | + make_ti_context_dict, |
| 353 | + capfd, |
| 354 | + ): |
| 355 | + main_pid = os.getpid() |
| 356 | + ti_id = "4d828a62-a417-4936-a7a6-2b3fabacecab" |
| 357 | + |
| 358 | + def handle_request(request: httpx.Request) -> httpx.Response: |
| 359 | + if request.url.path == f"/task-instances/{ti_id}/heartbeat": |
| 360 | + return httpx.Response( |
| 361 | + status_code=409, |
| 362 | + json={ |
| 363 | + "detail": { |
| 364 | + "reason": "not_running", |
| 365 | + "message": "TI is no longer in the 'running' state. Task state might be externally set and task should terminate", |
| 366 | + "current_state": "failed", |
| 367 | + } |
| 368 | + }, |
| 369 | + ) |
| 370 | + if request.url.path == f"/task-instances/{ti_id}/run": |
| 371 | + return httpx.Response(200, json=make_ti_context_dict()) |
| 372 | + return httpx.Response(status_code=204) |
| 373 | + |
| 374 | + def subprocess_main(): |
| 375 | + # Ensure we follow the "protocol" and get the startup message before we do anything |
| 376 | + CommsDecoder()._get_response() |
| 377 | + |
| 378 | + class CustomOperator(BaseOperator): |
| 379 | + def execute(self, context): |
| 380 | + for i in range(1000): |
| 381 | + print(f"Iteration {i}") |
| 382 | + sleep(1) |
| 383 | + |
| 384 | + def on_kill(self) -> None: |
| 385 | + print("On kill hook called!") |
| 386 | + |
| 387 | + task = CustomOperator(task_id="print-params") |
| 388 | + runtime_ti = create_runtime_ti( |
| 389 | + dag_id="c", |
| 390 | + task=task, |
| 391 | + conf={ |
| 392 | + "x": 3, |
| 393 | + "text": "Hello World!", |
| 394 | + "flag": False, |
| 395 | + "a_simple_list": ["one", "two", "three", "actually one value is made per line"], |
| 396 | + }, |
| 397 | + ) |
| 398 | + run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) |
| 399 | + |
| 400 | + assert os.getpid() != main_pid |
| 401 | + os.kill(os.getpid(), signal.SIGTERM) |
| 402 | + # Ensure that the signal is serviced before we finish and exit the subprocess. |
| 403 | + sleep(0.5) |
| 404 | + |
| 405 | + proc = ActivitySubprocess.start( |
| 406 | + dag_rel_path=os.devnull, |
| 407 | + bundle_info=FAKE_BUNDLE, |
| 408 | + what=TaskInstance( |
| 409 | + id=ti_id, |
| 410 | + task_id="b", |
| 411 | + dag_id="c", |
| 412 | + run_id="d", |
| 413 | + try_number=1, |
| 414 | + dag_version_id=uuid7(), |
| 415 | + ), |
| 416 | + client=make_client(transport=httpx.MockTransport(handle_request)), |
| 417 | + target=subprocess_main, |
| 418 | + ) |
| 419 | + |
| 420 | + proc.wait() |
| 421 | + captured = capfd.readouterr() |
| 422 | + assert "On kill hook called!" in captured.out |
| 423 | + |
344 | 424 | def test_subprocess_sigkilled(self, client_with_ti_start):
|
345 | 425 | main_pid = os.getpid()
|
346 | 426 |
|
|
0 commit comments