Skip to content

Commit d506349

Browse files
gopidesupavanashb
andauthored
Make BaseOperator on_kill functionality work with TaskSDK (#53718)
* Make BaseOperator on_kill functionality work with TaskSDK * Fix static checks * Resolve review comments * Update task-sdk/tests/task_sdk/execution_time/test_supervisor.py Co-authored-by: Ash Berlin-Taylor <[email protected]> --------- Co-authored-by: Ash Berlin-Taylor <[email protected]>
1 parent 1752339 commit d506349

File tree

2 files changed

+94
-1
lines changed

2 files changed

+94
-1
lines changed

task-sdk/src/airflow/sdk/execution_time/task_runner.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -860,6 +860,8 @@ def run(
860860
log: Logger,
861861
) -> tuple[TaskInstanceState, ToSupervisor | None, BaseException | None]:
862862
"""Run the task in this process."""
863+
import signal
864+
863865
from airflow.exceptions import (
864866
AirflowException,
865867
AirflowFailException,
@@ -877,6 +879,17 @@ def run(
877879
assert ti.task is not None
878880
assert isinstance(ti.task, BaseOperator)
879881

882+
parent_pid = os.getpid()
883+
884+
def _on_term(signum, frame):
885+
pid = os.getpid()
886+
if pid != parent_pid:
887+
return
888+
889+
ti.task.on_kill()
890+
891+
signal.signal(signal.SIGTERM, _on_term)
892+
880893
msg: ToSupervisor | None = None
881894
state: TaskInstanceState
882895
error: BaseException | None = None

task-sdk/tests/task_sdk/execution_time/test_supervisor.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from uuid6 import uuid7
4545

4646
from airflow.executors.workloads import BundleInfo
47-
from airflow.sdk import timezone
47+
from airflow.sdk import BaseOperator, timezone
4848
from airflow.sdk.api import client as sdk_client
4949
from airflow.sdk.api.client import ServerResponseError
5050
from airflow.sdk.api.datamodels._generated import (
@@ -121,6 +121,7 @@
121121
set_supervisor_comms,
122122
supervise,
123123
)
124+
from airflow.sdk.execution_time.task_runner import run
124125

125126
from tests_common.test_utils.config import conf_vars
126127

@@ -341,6 +342,85 @@ def subprocess_main():
341342
]
342343
)
343344

345+
def test_on_kill_hook_called_when_sigkilled(
346+
self,
347+
client_with_ti_start,
348+
mocked_parse,
349+
make_ti_context,
350+
mock_supervisor_comms,
351+
create_runtime_ti,
352+
make_ti_context_dict,
353+
capfd,
354+
):
355+
main_pid = os.getpid()
356+
ti_id = "4d828a62-a417-4936-a7a6-2b3fabacecab"
357+
358+
def handle_request(request: httpx.Request) -> httpx.Response:
359+
if request.url.path == f"/task-instances/{ti_id}/heartbeat":
360+
return httpx.Response(
361+
status_code=409,
362+
json={
363+
"detail": {
364+
"reason": "not_running",
365+
"message": "TI is no longer in the 'running' state. Task state might be externally set and task should terminate",
366+
"current_state": "failed",
367+
}
368+
},
369+
)
370+
if request.url.path == f"/task-instances/{ti_id}/run":
371+
return httpx.Response(200, json=make_ti_context_dict())
372+
return httpx.Response(status_code=204)
373+
374+
def subprocess_main():
375+
# Ensure we follow the "protocol" and get the startup message before we do anything
376+
CommsDecoder()._get_response()
377+
378+
class CustomOperator(BaseOperator):
379+
def execute(self, context):
380+
for i in range(1000):
381+
print(f"Iteration {i}")
382+
sleep(1)
383+
384+
def on_kill(self) -> None:
385+
print("On kill hook called!")
386+
387+
task = CustomOperator(task_id="print-params")
388+
runtime_ti = create_runtime_ti(
389+
dag_id="c",
390+
task=task,
391+
conf={
392+
"x": 3,
393+
"text": "Hello World!",
394+
"flag": False,
395+
"a_simple_list": ["one", "two", "three", "actually one value is made per line"],
396+
},
397+
)
398+
run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
399+
400+
assert os.getpid() != main_pid
401+
os.kill(os.getpid(), signal.SIGTERM)
402+
# Ensure that the signal is serviced before we finish and exit the subprocess.
403+
sleep(0.5)
404+
405+
proc = ActivitySubprocess.start(
406+
dag_rel_path=os.devnull,
407+
bundle_info=FAKE_BUNDLE,
408+
what=TaskInstance(
409+
id=ti_id,
410+
task_id="b",
411+
dag_id="c",
412+
run_id="d",
413+
try_number=1,
414+
dag_version_id=uuid7(),
415+
),
416+
client=make_client(transport=httpx.MockTransport(handle_request)),
417+
target=subprocess_main,
418+
)
419+
420+
proc.wait()
421+
captured = capfd.readouterr()
422+
assert "On kill hook called!" in captured.out
423+
344424
def test_subprocess_sigkilled(self, client_with_ti_start):
345425
main_pid = os.getpid()
346426

0 commit comments

Comments
 (0)