Skip to content

Commit 7a39b55

Browse files
Lee-WRoyLee1224
authored andcommitted
fix(hitl): handle hitl details when task instance is retried (apache#53824)
1 parent 3f710ae commit 7a39b55

File tree

5 files changed

+110
-71
lines changed
  • airflow-core
    • src/airflow/api_fastapi/execution_api/routes
    • tests/unit/api_fastapi/execution_api/versions/head
  • providers/standard/src/airflow/providers/standard/operators
  • task-sdk

5 files changed

+110
-71
lines changed

airflow-core/src/airflow/api_fastapi/execution_api/routes/hitl.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,32 +40,49 @@
4040
"/{task_instance_id}",
4141
status_code=status.HTTP_201_CREATED,
4242
)
43-
def add_hitl_detail(
43+
def upsert_hitl_detail(
4444
task_instance_id: UUID,
4545
payload: HITLDetailRequest,
4646
session: SessionDep,
4747
) -> HITLDetailRequest:
48-
"""Get Human-in-the-loop detail for a specific Task Instance."""
48+
"""
49+
Create a Human-in-the-loop detail for a specific Task Instance.
50+
51+
There're 3 cases handled here.
52+
53+
1. If a HITLOperator task instance does not have a HITLDetail,
54+
a new HITLDetail is created without a response section.
55+
2. If a HITLOperator task instance has a HITLDetail but lacks a response,
56+
the existing HITLDetail is returned.
57+
This situation occurs when a task instance is cleared before a response is received.
58+
3. If a HITLOperator task instance has both a HITLDetail and a response section,
59+
the existing response is removed, and the HITLDetail is returned.
60+
This happens when a task instance is cleared after a response has been received.
61+
This design ensures that each task instance has only one HITLDetail.
62+
"""
4963
ti_id_str = str(task_instance_id)
5064
hitl_detail_model = session.scalar(select(HITLDetail).where(HITLDetail.ti_id == ti_id_str))
51-
if hitl_detail_model:
52-
raise HTTPException(
53-
status.HTTP_409_CONFLICT,
54-
f"Human-in-the-loop detail for Task Instance with id {ti_id_str} already exists.",
65+
if not hitl_detail_model:
66+
hitl_detail_model = HITLDetail(
67+
ti_id=ti_id_str,
68+
options=payload.options,
69+
subject=payload.subject,
70+
body=payload.body,
71+
defaults=payload.defaults,
72+
multiple=payload.multiple,
73+
params=payload.params,
5574
)
75+
session.add(hitl_detail_model)
76+
elif hitl_detail_model.response_received:
77+
# Cleanup the response part of HITLDetail as we only store one response for one task instance.
78+
# It normally happens after retry, we keep only the latest response.
79+
hitl_detail_model.user_id = None
80+
hitl_detail_model.response_at = None
81+
hitl_detail_model.chosen_options = None
82+
hitl_detail_model.params_input = {}
83+
session.add(hitl_detail_model)
5684

57-
hitl_detail = HITLDetail(
58-
ti_id=ti_id_str,
59-
options=payload.options,
60-
subject=payload.subject,
61-
body=payload.body,
62-
defaults=payload.defaults,
63-
multiple=payload.multiple,
64-
params=payload.params,
65-
)
66-
session.add(hitl_detail)
67-
session.commit()
68-
return HITLDetailRequest.model_validate(hitl_detail)
85+
return HITLDetailRequest.model_validate(hitl_detail_model)
6986

7087

7188
@router.patch("/{task_instance_id}")

airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_hitl.py

Lines changed: 69 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -16,43 +16,60 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19-
from datetime import datetime
20-
2119
import pytest
22-
import time_machine
23-
from uuid6 import uuid7
20+
from httpx import Client
2421

2522
from tests_common.test_utils.db import AIRFLOW_V_3_1_PLUS
2623

2724
if not AIRFLOW_V_3_1_PLUS:
2825
pytest.skip("Human in the loop public API compatible with Airflow >= 3.0.1", allow_module_level=True)
2926

27+
from datetime import datetime
3028
from typing import TYPE_CHECKING, Any
3129

30+
import time_machine
31+
32+
from airflow._shared.timezones.timezone import convert_to_utc
3233
from airflow.models.hitl import HITLDetail
3334

3435
if TYPE_CHECKING:
36+
from fastapi.testclient import TestClient
37+
from sqlalchemy.orm import Session
38+
3539
from airflow.models.taskinstance import TaskInstance
3640

41+
from tests_common.pytest_plugin import CreateTaskInstance
42+
3743
pytestmark = pytest.mark.db_test
38-
TI_ID = uuid7()
44+
45+
default_hitl_detail_request_kwargs: dict[str, Any] = {
46+
# ti_id decided at a later stage
47+
"subject": "This is subject",
48+
"body": "this is body",
49+
"options": ["Approve", "Reject"],
50+
"defaults": ["Approve"],
51+
"multiple": False,
52+
"params": {"input_1": 1},
53+
}
54+
expected_empty_hitl_detail_response_part: dict[str, Any] = {
55+
"response_at": None,
56+
"chosen_options": None,
57+
"user_id": None,
58+
"params_input": {},
59+
"response_received": False,
60+
}
3961

4062

4163
@pytest.fixture
42-
def sample_ti(create_task_instance) -> TaskInstance:
64+
def sample_ti(create_task_instance: CreateTaskInstance) -> TaskInstance:
4365
return create_task_instance()
4466

4567

4668
@pytest.fixture
47-
def sample_hitl_detail(session, sample_ti) -> HITLDetail:
69+
def sample_hitl_detail(session: Session, sample_ti: TaskInstance) -> HITLDetail:
4870
hitl_detail_model = HITLDetail(
4971
ti_id=sample_ti.id,
50-
options=["Approve", "Reject"],
51-
subject="This is subject",
52-
body="this is body",
53-
defaults=["Approve"],
54-
multiple=False,
55-
params={"input_1": 1},
72+
**default_hitl_detail_request_kwargs,
5673
)
5774
session.add(hitl_detail_model)
5875
session.commit()
@@ -61,54 +78,65 @@ def sample_hitl_detail(session, sample_ti) -> HITLDetail:
6178

6279

6380
@pytest.fixture
64-
def expected_sample_hitl_detail_dict(sample_ti) -> dict[str, Any]:
81+
def expected_sample_hitl_detail_dict(sample_ti: TaskInstance) -> dict[str, Any]:
6582
return {
66-
"body": "this is body",
67-
"defaults": ["Approve"],
68-
"multiple": False,
69-
"options": ["Approve", "Reject"],
70-
"params": {"input_1": 1},
71-
"params_input": {},
72-
"response_at": None,
73-
"chosen_options": None,
74-
"response_received": False,
75-
"subject": "This is subject",
7683
"ti_id": sample_ti.id,
77-
"user_id": None,
84+
**default_hitl_detail_request_kwargs,
85+
**expected_empty_hitl_detail_response_part,
7886
}
7987

8088

81-
def test_add_hitl_detail(client, create_task_instance, session) -> None:
89+
@pytest.mark.parametrize(
90+
"existing_hitl_detail_args",
91+
[
92+
None,
93+
default_hitl_detail_request_kwargs,
94+
{
95+
**default_hitl_detail_request_kwargs,
96+
**{
97+
"params_input": {"input_1": 2},
98+
"response_at": convert_to_utc(datetime(2025, 7, 3, 0, 0, 0)),
99+
"chosen_options": ["Reject"],
100+
"user_id": "Fallback to defaults",
101+
},
102+
},
103+
],
104+
ids=[
105+
"no existing hitl detail",
106+
"existing hitl detail without response",
107+
"existing hitl detail with response",
108+
],
109+
)
110+
def test_upsert_hitl_detail(
111+
client: TestClient,
112+
create_task_instance: CreateTaskInstance,
113+
session: Session,
114+
existing_hitl_detail_args: dict[str, Any],
115+
) -> None:
82116
ti = create_task_instance()
83117
session.commit()
84118

119+
if existing_hitl_detail_args:
120+
session.add(HITLDetail(ti_id=ti.id, **existing_hitl_detail_args))
121+
session.commit()
122+
85123
response = client.post(
86124
f"/execution/hitl-details/{ti.id}",
87125
json={
88126
"ti_id": ti.id,
89-
"options": ["Approve", "Reject"],
90-
"subject": "This is subject",
91-
"body": "this is body",
92-
"defaults": ["Approve"],
93-
"multiple": False,
94-
"params": {"input_1": 1},
127+
**default_hitl_detail_request_kwargs,
95128
},
96129
)
97130
assert response.status_code == 201
98131
assert response.json() == {
99132
"ti_id": ti.id,
100-
"options": ["Approve", "Reject"],
101-
"subject": "This is subject",
102-
"body": "this is body",
103-
"defaults": ["Approve"],
104-
"multiple": False,
105-
"params": {"input_1": 1},
133+
**default_hitl_detail_request_kwargs,
106134
}
107135

108136

109137
@time_machine.travel(datetime(2025, 7, 3, 0, 0, 0), tick=False)
110138
@pytest.mark.usefixtures("sample_hitl_detail")
111-
def test_update_hitl_detail(client, sample_ti) -> None:
139+
def test_update_hitl_detail(client: Client, sample_ti: TaskInstance) -> None:
112140
response = client.patch(
113141
f"/execution/hitl-details/{sample_ti.id}",
114142
json={
@@ -128,13 +156,7 @@ def test_update_hitl_detail(client, sample_ti) -> None:
128156

129157

130158
@pytest.mark.usefixtures("sample_hitl_detail")
131-
def test_get_hitl_detail(client, sample_ti) -> None:
159+
def test_get_hitl_detail(client: Client, sample_ti: TaskInstance) -> None:
132160
response = client.get(f"/execution/hitl-details/{sample_ti.id}")
133161
assert response.status_code == 200
134-
assert response.json() == {
135-
"params_input": {},
136-
"response_at": None,
137-
"chosen_options": None,
138-
"response_received": False,
139-
"user_id": None,
140-
}
162+
assert response.json() == expected_empty_hitl_detail_response_part

providers/standard/src/airflow/providers/standard/operators/hitl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from airflow.providers.standard.triggers.hitl import HITLTrigger, HITLTriggerEventSuccessPayload
3535
from airflow.providers.standard.utils.skipmixin import SkipMixin
3636
from airflow.sdk.definitions.param import ParamsDict
37-
from airflow.sdk.execution_time.hitl import add_hitl_detail
37+
from airflow.sdk.execution_time.hitl import upsert_hitl_detail
3838

3939
if TYPE_CHECKING:
4040
from airflow.sdk.definitions.context import Context
@@ -98,7 +98,7 @@ def execute(self, context: Context):
9898
"""Add a Human-in-the-loop Response and then defer to HITLTrigger and wait for user input."""
9999
ti_id = context["task_instance"].id
100100
# Write Human-in-the-loop input request to DB
101-
add_hitl_detail(
101+
upsert_hitl_detail(
102102
ti_id=ti_id,
103103
options=self.options,
104104
subject=self.subject,

task-sdk/src/airflow/sdk/execution_time/hitl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from airflow.api_fastapi.execution_api.datamodels.hitl import HITLDetailResponse
3131

3232

33-
def add_hitl_detail(
33+
def upsert_hitl_detail(
3434
ti_id: UUID,
3535
options: list[str],
3636
subject: str,

task-sdk/tests/task_sdk/execution_time/test_hitl.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,16 @@
2323
from airflow.sdk.api.datamodels._generated import HITLDetailResponse
2424
from airflow.sdk.execution_time.comms import CreateHITLDetailPayload
2525
from airflow.sdk.execution_time.hitl import (
26-
add_hitl_detail,
2726
get_hitl_detail_content_detail,
2827
update_htil_detail_response,
28+
upsert_hitl_detail,
2929
)
3030

3131
TI_ID = uuid7()
3232

3333

34-
def test_add_hitl_detail(mock_supervisor_comms) -> None:
35-
add_hitl_detail(
34+
def test_upsert_hitl_detail(mock_supervisor_comms) -> None:
35+
upsert_hitl_detail(
3636
ti_id=TI_ID,
3737
options=["Approve", "Reject"],
3838
subject="Subject",

0 commit comments

Comments
 (0)