Skip to content

Commit 10fee47

Browse files
gopidesupavanashb
andcommitted
Make BaseOperator on_kill functionality work with TaskSDK (apache#53718)
* Make BaseOperator on_kill functionality work with TaskSDK * Fix static checks * Resolve review comments * Update task-sdk/tests/task_sdk/execution_time/test_supervisor.py Co-authored-by: Ash Berlin-Taylor <[email protected]> --------- Co-authored-by: Ash Berlin-Taylor <[email protected]>
1 parent 7484e9b commit 10fee47

File tree

2 files changed

+94
-0
lines changed

2 files changed

+94
-0
lines changed

task-sdk/src/airflow/sdk/execution_time/task_runner.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,6 +832,8 @@ def run(
832832
log: Logger,
833833
) -> tuple[TaskInstanceState, ToSupervisor | None, BaseException | None]:
834834
"""Run the task in this process."""
835+
import signal
836+
835837
from airflow.exceptions import (
836838
AirflowException,
837839
AirflowFailException,
@@ -849,6 +851,17 @@ def run(
849851
assert ti.task is not None
850852
assert isinstance(ti.task, BaseOperator)
851853

854+
parent_pid = os.getpid()
855+
856+
def _on_term(signum, frame):
857+
pid = os.getpid()
858+
if pid != parent_pid:
859+
return
860+
861+
ti.task.on_kill()
862+
863+
signal.signal(signal.SIGTERM, _on_term)
864+
852865
msg: ToSupervisor | None = None
853866
state: TaskInstanceState
854867
error: BaseException | None = None

task-sdk/tests/task_sdk/execution_time/test_supervisor.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from uuid6 import uuid7
4545

4646
from airflow.executors.workloads import BundleInfo
47+
from airflow.sdk import BaseOperator
4748
from airflow.sdk.api import client as sdk_client
4849
from airflow.sdk.api.client import ServerResponseError
4950
from airflow.sdk.api.datamodels._generated import (
@@ -113,6 +114,7 @@
113114
set_supervisor_comms,
114115
supervise,
115116
)
117+
from airflow.sdk.execution_time.task_runner import run
116118
from airflow.utils import timezone, timezone as tz
117119

118120
if TYPE_CHECKING:
@@ -330,6 +332,85 @@ def subprocess_main():
330332
]
331333
)
332334

335+
def test_on_kill_hook_called_when_sigkilled(
336+
self,
337+
client_with_ti_start,
338+
mocked_parse,
339+
make_ti_context,
340+
mock_supervisor_comms,
341+
create_runtime_ti,
342+
make_ti_context_dict,
343+
capfd,
344+
):
345+
main_pid = os.getpid()
346+
ti_id = "4d828a62-a417-4936-a7a6-2b3fabacecab"
347+
348+
def handle_request(request: httpx.Request) -> httpx.Response:
349+
if request.url.path == f"/task-instances/{ti_id}/heartbeat":
350+
return httpx.Response(
351+
status_code=409,
352+
json={
353+
"detail": {
354+
"reason": "not_running",
355+
"message": "TI is no longer in the 'running' state. Task state might be externally set and task should terminate",
356+
"current_state": "failed",
357+
}
358+
},
359+
)
360+
if request.url.path == f"/task-instances/{ti_id}/run":
361+
return httpx.Response(200, json=make_ti_context_dict())
362+
return httpx.Response(status_code=204)
363+
364+
def subprocess_main():
365+
# Ensure we follow the "protocol" and get the startup message before we do anything
366+
CommsDecoder()._get_response()
367+
368+
class CustomOperator(BaseOperator):
369+
def execute(self, context):
370+
for i in range(1000):
371+
print(f"Iteration {i}")
372+
sleep(1)
373+
374+
def on_kill(self) -> None:
375+
print("On kill hook called!")
376+
377+
task = CustomOperator(task_id="print-params")
378+
runtime_ti = create_runtime_ti(
379+
dag_id="c",
380+
task=task,
381+
conf={
382+
"x": 3,
383+
"text": "Hello World!",
384+
"flag": False,
385+
"a_simple_list": ["one", "two", "three", "actually one value is made per line"],
386+
},
387+
)
388+
run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
389+
390+
assert os.getpid() != main_pid
391+
os.kill(os.getpid(), signal.SIGTERM)
392+
# Ensure that the signal is serviced before we finish and exit the subprocess.
393+
sleep(0.5)
394+
395+
proc = ActivitySubprocess.start(
396+
dag_rel_path=os.devnull,
397+
bundle_info=FAKE_BUNDLE,
398+
what=TaskInstance(
399+
id=ti_id,
400+
task_id="b",
401+
dag_id="c",
402+
run_id="d",
403+
try_number=1,
404+
dag_version_id=uuid7(),
405+
),
406+
client=make_client(transport=httpx.MockTransport(handle_request)),
407+
target=subprocess_main,
408+
)
409+
410+
proc.wait()
411+
captured = capfd.readouterr()
412+
assert "On kill hook called!" in captured.out
413+
333414
def test_subprocess_sigkilled(self, client_with_ti_start):
334415
main_pid = os.getpid()
335416

0 commit comments

Comments
 (0)