Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
48f33a0
Switch the Supervisor/task process from line-based to length-prefixed
ashb Jun 10, 2025
22a44cb
fixup! Switch the Supervisor/task process from line-based to length-p…
ashb Jun 13, 2025
1372214
User impersonation wip
amoghrajesh Jun 13, 2025
dcc491c
xcom push doesnt work -- rest is ok
amoghrajesh Jun 13, 2025
d486082
cleaning up the code
amoghrajesh Jun 16, 2025
595b550
using getattr instead
amoghrajesh Jun 16, 2025
4975725
Switch the Supervisor/task process from line-based to length-prefixed
ashb Jun 10, 2025
c014ebe
Merge branch 'rework-tasksdk-supervisor-comms-protocol' into user-imp…
amoghrajesh Jun 16, 2025
8ad4db5
Switch the Supervisor/task process from line-based to length-prefixed
ashb Jun 10, 2025
05c78f1
Deal with compat in tests
ashb Jun 16, 2025
f45647a
Code review
ashb Jun 16, 2025
a1c0dd2
Merge branch 'rework-tasksdk-supervisor-comms-protocol' into user-imp…
amoghrajesh Jun 17, 2025
6b37dcd
adding a testin task runner
amoghrajesh Jun 17, 2025
f3b7239
fixing mock
amoghrajesh Jun 17, 2025
5ba66b5
adding unit tests
amoghrajesh Jun 17, 2025
a5c2a1e
Merge branch 'main' into user-impersonation-reworked
amoghrajesh Jun 17, 2025
54ee8c5
rebase errors
amoghrajesh Jun 17, 2025
e890bad
nuke getattr and use function instead
amoghrajesh Jun 18, 2025
b84df96
removing unwanted code
amoghrajesh Jun 18, 2025
5a7f6d6
fixing tests
amoghrajesh Jun 18, 2025
5a8e13e
Apply suggestions from code review
amoghrajesh Jun 18, 2025
010d2b2
nits from ash
amoghrajesh Jun 18, 2025
6ee1cb0
nits from ash
amoghrajesh Jun 18, 2025
706542f
addressing an edge case
amoghrajesh Jun 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion airflow-core/src/airflow/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,9 @@ def initialize():
# The webservers import this file from models.py with the default settings.

if not os.environ.get("PYTHON_OPERATORS_VIRTUAL_ENV_MODE", None):
configure_orm()
is_worker = os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1"
if not is_worker:
configure_orm()
configure_action_logging()

# mask the sensitive_config_values
Expand Down
46 changes: 41 additions & 5 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@
import attrs
import lazy_object_proxy
import structlog
from pydantic import AwareDatetime, ConfigDict, Field, JsonValue
from pydantic import AwareDatetime, ConfigDict, Field, JsonValue, TypeAdapter

from airflow.configuration import conf
from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock
from airflow.dag_processing.bundles.manager import DagBundlesManager
from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException
Expand Down Expand Up @@ -97,6 +98,7 @@
)
from airflow.sdk.execution_time.xcom import XCom
from airflow.utils.net import get_hostname
from airflow.utils.platform import getuser
from airflow.utils.timezone import coerce_datetime

if TYPE_CHECKING:
Expand Down Expand Up @@ -642,6 +644,7 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
# accessible wherever needed during task execution without modifying every layer of the call stack.
SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor]


# State machine!
# 1. Start up (receive details from supervisor)
# 2. Execution (run task code, possibly send requests)
Expand All @@ -651,13 +654,18 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
# The parent sends us a StartupDetails message un-prompted. After this, every single message is only sent
# in response to us sending a request.
msg = SUPERVISOR_COMMS._get_response()
log = structlog.get_logger(logger_name="task")

if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1" and os.environ.get("_AIRFLOW__STARTUP_MSG"):
# entrypoint of re-exec process
msg = TypeAdapter(StartupDetails).validate_json(os.environ["_AIRFLOW__STARTUP_MSG"])
log.debug("Using serialized startup message from environment", msg=msg)
else:
# normal entry point
msg = SUPERVISOR_COMMS._get_response() # type: ignore[assignment]

if not isinstance(msg, StartupDetails):
raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}")

log = structlog.get_logger(logger_name="task")

# setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021
os_type = sys.platform
if os_type == "darwin":
Expand All @@ -677,6 +685,34 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
ti.log_url = get_log_url_from_ti(ti)
log.debug("DAG file parsed", file=msg.dag_rel_path)

run_as_user = getattr(ti.task, "run_as_user", None) or conf.get(
"core", "default_impersonation", fallback=None
)

if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") != "1" and run_as_user and run_as_user != getuser():
# enters here for re-exec process
os.environ["_AIRFLOW__REEXECUTED_PROCESS"] = "1"
# store startup message in environment for re-exec process
os.environ["_AIRFLOW__STARTUP_MSG"] = msg.model_dump_json()
os.set_inheritable(SUPERVISOR_COMMS.socket.fileno(), True)

# Import main directly from the module instead of re-executing the file.
# This ensures that when other modules import
# airflow.sdk.execution_time.task_runner, they get the same module instance
# with the properly initialized SUPERVISOR_COMMS global variable.
# If we re-executed the module with `python -m`, it would load as __main__ and future
# imports would get a fresh copy without the initialized globals.
rexec_python_code = "from airflow.sdk.execution_time.task_runner import main; main()"
cmd = ["sudo", "-E", "-H", "-u", run_as_user, sys.executable, "-c", rexec_python_code]
log.info(
"Running command",
command=cmd,
)
os.execvp("sudo", cmd)

# ideally, we should never reach here, but if we do, we should return None, None, None
return None, None, None

return ti, ti.get_template_context(), log


Expand Down
85 changes: 85 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,91 @@ def execute(self, context):
mock_supervisor_comms.assert_has_calls(expected_calls)


@patch("os.execvp")
@patch("os.set_inheritable")
def test_task_run_with_user_impersonation(
    mock_set_inheritable, mock_execvp, mocked_parse, make_ti_context, time_machine, mock_supervisor_comms
):
    """Check that ``startup()`` re-executes the task runner through ``sudo`` for ``run_as_user``.

    ``os.execvp`` and ``os.set_inheritable`` are patched so no process is actually replaced and no
    real file descriptor is touched. ``getuser`` is patched as well, so that the outcome does not
    depend on the OS account running the test suite.
    """
    # Local import: this test compares against the interpreter path that startup() passes to sudo.
    import sys

    class CustomOperator(BaseOperator):
        def execute(self, context):
            print("Hi from CustomOperator!")

    task = CustomOperator(task_id="impersonation_task", run_as_user="airflowuser")
    instant = timezone.datetime(2024, 12, 3, 10, 0)

    what = StartupDetails(
        ti=TaskInstance(
            id=uuid7(),
            task_id="impersonation_task",
            dag_id="basic_dag",
            run_id="c",
            try_number=1,
        ),
        dag_rel_path="",
        bundle_info=FAKE_BUNDLE,
        ti_context=make_ti_context(),
        start_date=timezone.utcnow(),
    )

    mocked_parse(what, "basic_dag", task)
    time_machine.move_to(instant, tick=False)

    mock_supervisor_comms._get_response.return_value = what
    mock_supervisor_comms.socket.fileno.return_value = 42

    # The current user must differ from ``run_as_user`` for the re-exec branch to be taken.
    getuser_patch = mock.patch(
        "airflow.sdk.execution_time.task_runner.getuser", return_value="not_airflowuser"
    )

    with mock.patch.dict(os.environ, {}, clear=True), getuser_patch:
        startup()

        # These checks have to run while the environment patch is still active: patch.dict
        # restores the original os.environ as soon as the block exits.
        assert os.environ["_AIRFLOW__REEXECUTED_PROCESS"] == "1"
        assert "_AIRFLOW__STARTUP_MSG" in os.environ

        mock_set_inheritable.assert_called_once_with(42, True)
        actual_cmd = mock_execvp.call_args.args[1]

        # Compare against sys.executable itself rather than looking for the substring "python -c",
        # which does not match interpreters named e.g. "python3" or "python3.11".
        assert actual_cmd == [
            "sudo",
            "-E",
            "-H",
            "-u",
            "airflowuser",
            sys.executable,
            "-c",
            "from airflow.sdk.execution_time.task_runner import main; main()",
        ]


@patch("airflow.sdk.execution_time.task_runner.getuser")
def test_task_run_with_user_impersonation_default_user(
    mock_get_user, mocked_parse, make_ti_context, time_machine, mock_supervisor_comms
):
    """Check that ``startup()`` does not prepare a re-exec when already running as ``run_as_user``.

    ``getuser`` is patched to report the same name the operator asks for, so neither of the
    re-exec environment variables should be set afterwards.
    """

    class CustomOperator(BaseOperator):
        def execute(self, context):
            print("Hi from CustomOperator!")

    # The reported current user matches the operator's ``run_as_user`` below.
    mock_get_user.return_value = "default_user"

    operator = CustomOperator(task_id="impersonation_task", run_as_user="default_user")
    frozen_now = timezone.datetime(2024, 12, 3, 10, 0)

    task_instance = TaskInstance(
        id=uuid7(),
        task_id="impersonation_task",
        dag_id="basic_dag",
        run_id="c",
        try_number=1,
    )
    startup_details = StartupDetails(
        ti=task_instance,
        dag_rel_path="",
        bundle_info=FAKE_BUNDLE,
        ti_context=make_ti_context(),
        start_date=timezone.utcnow(),
    )

    mocked_parse(startup_details, "basic_dag", operator)
    time_machine.move_to(frozen_now, tick=False)

    # The supervisor hands over the startup message and exposes a fake socket descriptor.
    mock_supervisor_comms._get_response.return_value = startup_details
    mock_supervisor_comms.socket.fileno.return_value = 42

    with mock.patch.dict(os.environ, {}, clear=True):
        startup()

        # Checked while the cleared environment is still in place.
        for env_var in ("_AIRFLOW__REEXECUTED_PROCESS", "_AIRFLOW__STARTUP_MSG"):
            assert env_var not in os.environ


@pytest.mark.parametrize(
["command", "rendered_command"],
[
Expand Down