Skip to content

Commit c8343d9

Browse files
authored
Add user impersonation (run_as_user) support for task execution (#51780)
1 parent b8b7f4a commit c8343d9

File tree

3 files changed

+129
-6
lines changed

3 files changed

+129
-6
lines changed

airflow-core/src/airflow/settings.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,9 @@ def initialize():
616616
# The webservers import this file from models.py with the default settings.
617617

618618
if not os.environ.get("PYTHON_OPERATORS_VIRTUAL_ENV_MODE", None):
619-
configure_orm()
619+
is_worker = os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1"
620+
if not is_worker:
621+
configure_orm()
620622
configure_action_logging()
621623

622624
# mask the sensitive_config_values

task-sdk/src/airflow/sdk/execution_time/task_runner.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,9 @@
3535
import attrs
3636
import lazy_object_proxy
3737
import structlog
38-
from pydantic import AwareDatetime, ConfigDict, Field, JsonValue
38+
from pydantic import AwareDatetime, ConfigDict, Field, JsonValue, TypeAdapter
3939

40+
from airflow.configuration import conf
4041
from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock
4142
from airflow.dag_processing.bundles.manager import DagBundlesManager
4243
from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException
@@ -97,6 +98,7 @@
9798
)
9899
from airflow.sdk.execution_time.xcom import XCom
99100
from airflow.utils.net import get_hostname
101+
from airflow.utils.platform import getuser
100102
from airflow.utils.timezone import coerce_datetime
101103

102104
if TYPE_CHECKING:
@@ -642,6 +644,7 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
642644
# accessible wherever needed during task execution without modifying every layer of the call stack.
643645
SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor]
644646

647+
645648
# State machine!
646649
# 1. Start up (receive details from supervisor)
647650
# 2. Execution (run task code, possibly send requests)
@@ -651,13 +654,18 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
651654
def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
652655
# The parent sends us a StartupDetails message un-prompted. After this, every single message is only sent
653656
# in response to us sending a request.
654-
msg = SUPERVISOR_COMMS._get_response()
657+
log = structlog.get_logger(logger_name="task")
658+
659+
if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1" and os.environ.get("_AIRFLOW__STARTUP_MSG"):
660+
# entrypoint of re-exec process
661+
msg = TypeAdapter(StartupDetails).validate_json(os.environ["_AIRFLOW__STARTUP_MSG"])
662+
log.debug("Using serialized startup message from environment", msg=msg)
663+
else:
664+
# normal entry point
665+
msg = SUPERVISOR_COMMS._get_response() # type: ignore[assignment]
655666

656667
if not isinstance(msg, StartupDetails):
657668
raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}")
658-
659-
log = structlog.get_logger(logger_name="task")
660-
661669
# setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021
662670
os_type = sys.platform
663671
if os_type == "darwin":
@@ -677,6 +685,34 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
677685
ti.log_url = get_log_url_from_ti(ti)
678686
log.debug("DAG file parsed", file=msg.dag_rel_path)
679687

688+
run_as_user = getattr(ti.task, "run_as_user", None) or conf.get(
689+
"core", "default_impersonation", fallback=None
690+
)
691+
692+
if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") != "1" and run_as_user and run_as_user != getuser():
693+
# enters here for re-exec process
694+
os.environ["_AIRFLOW__REEXECUTED_PROCESS"] = "1"
695+
# store startup message in environment for re-exec process
696+
os.environ["_AIRFLOW__STARTUP_MSG"] = msg.model_dump_json()
697+
os.set_inheritable(SUPERVISOR_COMMS.socket.fileno(), True)
698+
699+
# Import main directly from the module instead of re-executing the file.
700+
# This ensures that when other parts modules import
701+
# airflow.sdk.execution_time.task_runner, they get the same module instance
702+
# with the properly initialized SUPERVISOR_COMMS global variable.
703+
# If we re-executed the module with `python -m`, it would load as __main__ and future
704+
# imports would get a fresh copy without the initialized globals.
705+
rexec_python_code = "from airflow.sdk.execution_time.task_runner import main; main()"
706+
cmd = ["sudo", "-E", "-H", "-u", run_as_user, sys.executable, "-c", rexec_python_code]
707+
log.info(
708+
"Running command",
709+
command=cmd,
710+
)
711+
os.execvp("sudo", cmd)
712+
713+
# ideally, we should never reach here, but if we do, we should return None, None, None
714+
return None, None, None
715+
680716
return ti, ti.get_template_context(), log
681717

682718

task-sdk/tests/task_sdk/execution_time/test_task_runner.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,91 @@ def execute(self, context):
644644
mock_supervisor_comms.assert_has_calls(expected_calls)
645645

646646

647+
@patch("os.execvp")
648+
@patch("os.set_inheritable")
649+
def test_task_run_with_user_impersonation(
650+
mock_set_inheritable, mock_execvp, mocked_parse, make_ti_context, time_machine, mock_supervisor_comms
651+
):
652+
class CustomOperator(BaseOperator):
653+
def execute(self, context):
654+
print("Hi from CustomOperator!")
655+
656+
task = CustomOperator(task_id="impersonation_task", run_as_user="airflowuser")
657+
instant = timezone.datetime(2024, 12, 3, 10, 0)
658+
659+
what = StartupDetails(
660+
ti=TaskInstance(
661+
id=uuid7(),
662+
task_id="impersonation_task",
663+
dag_id="basic_dag",
664+
run_id="c",
665+
try_number=1,
666+
),
667+
dag_rel_path="",
668+
bundle_info=FAKE_BUNDLE,
669+
ti_context=make_ti_context(),
670+
start_date=timezone.utcnow(),
671+
)
672+
673+
mocked_parse(what, "basic_dag", task)
674+
time_machine.move_to(instant, tick=False)
675+
676+
mock_supervisor_comms._get_response.return_value = what
677+
mock_supervisor_comms.socket.fileno.return_value = 42
678+
679+
with mock.patch.dict(os.environ, {}, clear=True):
680+
startup()
681+
682+
assert os.environ["_AIRFLOW__REEXECUTED_PROCESS"] == "1"
683+
assert "_AIRFLOW__STARTUP_MSG" in os.environ
684+
685+
mock_set_inheritable.assert_called_once_with(42, True)
686+
actual_cmd = mock_execvp.call_args.args[1]
687+
688+
assert actual_cmd[:5] == ["sudo", "-E", "-H", "-u", "airflowuser"]
689+
assert "python -c" in actual_cmd[5] + " " + actual_cmd[6]
690+
assert actual_cmd[7] == "from airflow.sdk.execution_time.task_runner import main; main()"
691+
692+
693+
@patch("airflow.sdk.execution_time.task_runner.getuser")
694+
def test_task_run_with_user_impersonation_default_user(
695+
mock_get_user, mocked_parse, make_ti_context, time_machine, mock_supervisor_comms
696+
):
697+
class CustomOperator(BaseOperator):
698+
def execute(self, context):
699+
print("Hi from CustomOperator!")
700+
701+
task = CustomOperator(task_id="impersonation_task", run_as_user="default_user")
702+
instant = timezone.datetime(2024, 12, 3, 10, 0)
703+
704+
what = StartupDetails(
705+
ti=TaskInstance(
706+
id=uuid7(),
707+
task_id="impersonation_task",
708+
dag_id="basic_dag",
709+
run_id="c",
710+
try_number=1,
711+
),
712+
dag_rel_path="",
713+
bundle_info=FAKE_BUNDLE,
714+
ti_context=make_ti_context(),
715+
start_date=timezone.utcnow(),
716+
)
717+
718+
mocked_parse(what, "basic_dag", task)
719+
time_machine.move_to(instant, tick=False)
720+
721+
mock_supervisor_comms._get_response.return_value = what
722+
mock_supervisor_comms.socket.fileno.return_value = 42
723+
mock_get_user.return_value = "default_user"
724+
725+
with mock.patch.dict(os.environ, {}, clear=True):
726+
startup()
727+
728+
assert "_AIRFLOW__REEXECUTED_PROCESS" not in os.environ
729+
assert "_AIRFLOW__STARTUP_MSG" not in os.environ
730+
731+
647732
@pytest.mark.parametrize(
648733
["command", "rendered_command"],
649734
[

0 commit comments

Comments
 (0)