Skip to content

Commit 0c3dcf1

Browse files
amoghrajeshkaxil
authored andcommitted
Add user impersonation (run_as_user) support for task execution (#51780)
(cherry picked from commit c8343d9)
1 parent 4b39087 commit 0c3dcf1

File tree

3 files changed

+129
-6
lines changed

3 files changed

+129
-6
lines changed

airflow-core/src/airflow/settings.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,9 @@ def initialize():
613613
# The webservers import this file from models.py with the default settings.
614614

615615
if not os.environ.get("PYTHON_OPERATORS_VIRTUAL_ENV_MODE", None):
616-
configure_orm()
616+
is_worker = os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1"
617+
if not is_worker:
618+
configure_orm()
617619
configure_action_logging()
618620

619621
# mask the sensitive_config_values

task-sdk/src/airflow/sdk/execution_time/task_runner.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,9 @@
3535
import attrs
3636
import lazy_object_proxy
3737
import structlog
38-
from pydantic import AwareDatetime, ConfigDict, Field, JsonValue
38+
from pydantic import AwareDatetime, ConfigDict, Field, JsonValue, TypeAdapter
3939

40+
from airflow.configuration import conf
4041
from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock
4142
from airflow.dag_processing.bundles.manager import DagBundlesManager
4243
from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException
@@ -97,6 +98,7 @@
9798
)
9899
from airflow.sdk.execution_time.xcom import XCom
99100
from airflow.utils.net import get_hostname
101+
from airflow.utils.platform import getuser
100102
from airflow.utils.timezone import coerce_datetime
101103

102104
if TYPE_CHECKING:
@@ -623,6 +625,7 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
623625
# accessible wherever needed during task execution without modifying every layer of the call stack.
624626
SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor]
625627

628+
626629
# State machine!
627630
# 1. Start up (receive details from supervisor)
628631
# 2. Execution (run task code, possibly send requests)
@@ -632,13 +635,18 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
632635
def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
633636
# The parent sends us a StartupDetails message un-prompted. After this, every single message is only sent
634637
# in response to us sending a request.
635-
msg = SUPERVISOR_COMMS._get_response()
638+
log = structlog.get_logger(logger_name="task")
639+
640+
if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1" and os.environ.get("_AIRFLOW__STARTUP_MSG"):
641+
# entrypoint of re-exec process
642+
msg = TypeAdapter(StartupDetails).validate_json(os.environ["_AIRFLOW__STARTUP_MSG"])
643+
log.debug("Using serialized startup message from environment", msg=msg)
644+
else:
645+
# normal entry point
646+
msg = SUPERVISOR_COMMS._get_response() # type: ignore[assignment]
636647

637648
if not isinstance(msg, StartupDetails):
638649
raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}")
639-
640-
log = structlog.get_logger(logger_name="task")
641-
642650
# setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021
643651
os_type = sys.platform
644652
if os_type == "darwin":
@@ -657,6 +665,34 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
657665
ti = parse(msg, log)
658666
log.debug("DAG file parsed", file=msg.dag_rel_path)
659667

668+
run_as_user = getattr(ti.task, "run_as_user", None) or conf.get(
669+
"core", "default_impersonation", fallback=None
670+
)
671+
672+
if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") != "1" and run_as_user and run_as_user != getuser():
673+
# enters here for re-exec process
674+
os.environ["_AIRFLOW__REEXECUTED_PROCESS"] = "1"
675+
# store startup message in environment for re-exec process
676+
os.environ["_AIRFLOW__STARTUP_MSG"] = msg.model_dump_json()
677+
os.set_inheritable(SUPERVISOR_COMMS.socket.fileno(), True)
678+
679+
# Import main directly from the module instead of re-executing the file.
680+
# This ensures that when other parts modules import
681+
# airflow.sdk.execution_time.task_runner, they get the same module instance
682+
# with the properly initialized SUPERVISOR_COMMS global variable.
683+
# If we re-executed the module with `python -m`, it would load as __main__ and future
684+
# imports would get a fresh copy without the initialized globals.
685+
rexec_python_code = "from airflow.sdk.execution_time.task_runner import main; main()"
686+
cmd = ["sudo", "-E", "-H", "-u", run_as_user, sys.executable, "-c", rexec_python_code]
687+
log.info(
688+
"Running command",
689+
command=cmd,
690+
)
691+
os.execvp("sudo", cmd)
692+
693+
# ideally, we should never reach here, but if we do, we should return None, None, None
694+
return None, None, None
695+
660696
return ti, ti.get_template_context(), log
661697

662698

task-sdk/tests/task_sdk/execution_time/test_task_runner.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,91 @@ def execute(self, context):
643643
mock_supervisor_comms.assert_has_calls(expected_calls)
644644

645645

646+
@patch("os.execvp")
647+
@patch("os.set_inheritable")
648+
def test_task_run_with_user_impersonation(
649+
mock_set_inheritable, mock_execvp, mocked_parse, make_ti_context, time_machine, mock_supervisor_comms
650+
):
651+
class CustomOperator(BaseOperator):
652+
def execute(self, context):
653+
print("Hi from CustomOperator!")
654+
655+
task = CustomOperator(task_id="impersonation_task", run_as_user="airflowuser")
656+
instant = timezone.datetime(2024, 12, 3, 10, 0)
657+
658+
what = StartupDetails(
659+
ti=TaskInstance(
660+
id=uuid7(),
661+
task_id="impersonation_task",
662+
dag_id="basic_dag",
663+
run_id="c",
664+
try_number=1,
665+
),
666+
dag_rel_path="",
667+
bundle_info=FAKE_BUNDLE,
668+
ti_context=make_ti_context(),
669+
start_date=timezone.utcnow(),
670+
)
671+
672+
mocked_parse(what, "basic_dag", task)
673+
time_machine.move_to(instant, tick=False)
674+
675+
mock_supervisor_comms._get_response.return_value = what
676+
mock_supervisor_comms.socket.fileno.return_value = 42
677+
678+
with mock.patch.dict(os.environ, {}, clear=True):
679+
startup()
680+
681+
assert os.environ["_AIRFLOW__REEXECUTED_PROCESS"] == "1"
682+
assert "_AIRFLOW__STARTUP_MSG" in os.environ
683+
684+
mock_set_inheritable.assert_called_once_with(42, True)
685+
actual_cmd = mock_execvp.call_args.args[1]
686+
687+
assert actual_cmd[:5] == ["sudo", "-E", "-H", "-u", "airflowuser"]
688+
assert "python -c" in actual_cmd[5] + " " + actual_cmd[6]
689+
assert actual_cmd[7] == "from airflow.sdk.execution_time.task_runner import main; main()"
690+
691+
692+
@patch("airflow.sdk.execution_time.task_runner.getuser")
693+
def test_task_run_with_user_impersonation_default_user(
694+
mock_get_user, mocked_parse, make_ti_context, time_machine, mock_supervisor_comms
695+
):
696+
class CustomOperator(BaseOperator):
697+
def execute(self, context):
698+
print("Hi from CustomOperator!")
699+
700+
task = CustomOperator(task_id="impersonation_task", run_as_user="default_user")
701+
instant = timezone.datetime(2024, 12, 3, 10, 0)
702+
703+
what = StartupDetails(
704+
ti=TaskInstance(
705+
id=uuid7(),
706+
task_id="impersonation_task",
707+
dag_id="basic_dag",
708+
run_id="c",
709+
try_number=1,
710+
),
711+
dag_rel_path="",
712+
bundle_info=FAKE_BUNDLE,
713+
ti_context=make_ti_context(),
714+
start_date=timezone.utcnow(),
715+
)
716+
717+
mocked_parse(what, "basic_dag", task)
718+
time_machine.move_to(instant, tick=False)
719+
720+
mock_supervisor_comms._get_response.return_value = what
721+
mock_supervisor_comms.socket.fileno.return_value = 42
722+
mock_get_user.return_value = "default_user"
723+
724+
with mock.patch.dict(os.environ, {}, clear=True):
725+
startup()
726+
727+
assert "_AIRFLOW__REEXECUTED_PROCESS" not in os.environ
728+
assert "_AIRFLOW__STARTUP_MSG" not in os.environ
729+
730+
646731
@pytest.mark.parametrize(
647732
["command", "rendered_command"],
648733
[

0 commit comments

Comments
 (0)