@@ -128,6 +128,27 @@

REMOTE_LOGGING: bool = conf.getboolean("logging", "remote_logging")
REMOTE_TASK_LOG: RemoteLogIO | None = None
DEFAULT_REMOTE_CONN_ID: str | None = None


def _default_conn_name_from(mod_path, hook_name):
# Try to set the default conn name from a hook, but don't error if something goes wrong at runtime
from importlib import import_module

global DEFAULT_REMOTE_CONN_ID

try:
mod = import_module(mod_path)

hook = getattr(mod, hook_name)

DEFAULT_REMOTE_CONN_ID = getattr(hook, "default_conn_name")
except Exception:
# Let's error in tests though!
if "PYTEST_CURRENT_TEST" in os.environ:
raise
return None


if REMOTE_LOGGING:
ELASTICSEARCH_HOST: str | None = conf.get("elasticsearch", "HOST")
@@ -151,6 +172,7 @@
if remote_base_log_folder.startswith("s3://"):
from airflow.providers.amazon.aws.log.s3_task_handler import S3RemoteLogIO

_default_conn_name_from("airflow.providers.amazon.aws.hooks.s3", "S3Hook")
REMOTE_TASK_LOG = S3RemoteLogIO(
**(
{
@@ -166,6 +188,7 @@
elif remote_base_log_folder.startswith("cloudwatch://"):
from airflow.providers.amazon.aws.log.cloudwatch_task_handler import CloudWatchRemoteLogIO

_default_conn_name_from("airflow.providers.amazon.aws.hooks.logs", "AwsLogsHook")
url_parts = urlsplit(remote_base_log_folder)
REMOTE_TASK_LOG = CloudWatchRemoteLogIO(
**(
@@ -182,6 +205,7 @@
elif remote_base_log_folder.startswith("gs://"):
from airflow.providers.google.cloud.log.gcs_task_handler import GCSRemoteLogIO

_default_conn_name_from("airflow.providers.google.cloud.hooks.gcs", "GCSHook")
key_path = conf.get_mandatory_value("logging", "google_key_path", fallback=None)

REMOTE_TASK_LOG = GCSRemoteLogIO(
@@ -199,6 +223,7 @@
elif remote_base_log_folder.startswith("wasb"):
from airflow.providers.microsoft.azure.log.wasb_task_handler import WasbRemoteLogIO

_default_conn_name_from("airflow.providers.microsoft.azure.hooks.wasb", "WasbHook")
wasb_log_container = conf.get_mandatory_value(
"azure_remote_logging", "remote_wasb_log_container", fallback="airflow-logs"
)
@@ -232,6 +257,8 @@
elif remote_base_log_folder.startswith("oss://"):
from airflow.providers.alibaba.cloud.log.oss_task_handler import OSSRemoteLogIO

_default_conn_name_from("airflow.providers.alibaba.cloud.hooks.oss", "OSSHook")

REMOTE_TASK_LOG = OSSRemoteLogIO(
**(
{
@@ -246,6 +273,8 @@
elif remote_base_log_folder.startswith("hdfs://"):
from airflow.providers.apache.hdfs.log.hdfs_task_handler import HdfsRemoteLogIO

_default_conn_name_from("airflow.providers.apache.hdfs.hooks.webhdfs", "WebHDFSHook")

REMOTE_TASK_LOG = HdfsRemoteLogIO(
**(
{
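
For reference, a minimal sketch of what the helper above resolves to for the AWS-backed handlers, assuming the amazon provider is installed; the value is read straight from the hook's default_conn_name class attribute, which is why the supervisor test further down expects AIRFLOW_CONN_AWS_DEFAULT when no explicit conn id is configured:

from airflow.providers.amazon.aws.hooks.s3 import S3Hook

# _default_conn_name_from("airflow.providers.amazon.aws.hooks.s3", "S3Hook") only reads this class
# attribute, so DEFAULT_REMOTE_CONN_ID ends up as the hook's default connection id.
print(S3Hook.default_conn_name)  # "aws_default"
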
10 changes: 6 additions & 4 deletions airflow-core/src/airflow/logging_config.py
@@ -33,6 +33,7 @@


REMOTE_TASK_LOG: RemoteLogIO | None
DEFAULT_REMOTE_CONN_ID: str | None = None


def __getattr__(name: str):
@@ -44,7 +45,7 @@ def __getattr__(name: str):

def load_logging_config() -> tuple[dict[str, Any], str]:
"""Configure & Validate Airflow Logging."""
global REMOTE_TASK_LOG
global REMOTE_TASK_LOG, DEFAULT_REMOTE_CONN_ID
fallback = "airflow.config_templates.airflow_local_settings.DEFAULT_LOGGING_CONFIG"
logging_class_path = conf.get("logging", "logging_config_class", fallback=fallback)

@@ -70,10 +71,11 @@ def load_logging_config() -> tuple[dict[str, Any], str]:
f"to: {type(err).__name__}:{err}"
)
else:
mod = logging_class_path.rsplit(".", 1)[0]
modpath = logging_class_path.rsplit(".", 1)[0]
try:
remote_task_log = import_string(f"{mod}.REMOTE_TASK_LOG")
REMOTE_TASK_LOG = remote_task_log
mod = import_string(modpath)
REMOTE_TASK_LOG = getattr(mod, "REMOTE_TASK_LOG")
DEFAULT_REMOTE_CONN_ID = getattr(mod, "DEFAULT_REMOTE_CONN_ID", None)
except Exception as err:
log.info("Remote task logs will not be available due to an error: %s", err)

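
As an illustration of what load_logging_config() now looks for, here is a hypothetical custom logging module (the module name and the example values are assumptions; only the attribute names REMOTE_TASK_LOG and DEFAULT_REMOTE_CONN_ID come from the lookup above):

# my_log_config.py, referenced via [logging] logging_config_class = my_log_config.LOGGING_CONFIG
from copy import deepcopy

from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG

LOGGING_CONFIG = deepcopy(DEFAULT_LOGGING_CONFIG)

# Optional module-level attributes; load_logging_config() reads them with getattr(), and
# DEFAULT_REMOTE_CONN_ID falls back to None when the attribute is absent.
REMOTE_TASK_LOG = None
DEFAULT_REMOTE_CONN_ID = "aws_default"
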
106 changes: 90 additions & 16 deletions task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -20,6 +20,7 @@
from __future__ import annotations

import atexit
import contextlib
import io
import logging
import os
@@ -127,6 +128,7 @@
from structlog.typing import FilteringBoundLogger, WrappedLogger

from airflow.executors.workloads import BundleInfo
from airflow.sdk.definitions.connection import Connection
from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
from airflow.secrets import BaseSecretsBackend
from airflow.typing_compat import Self
@@ -1615,6 +1617,93 @@ def ensure_secrets_backend_loaded() -> list[BaseSecretsBackend]:
return backends


@contextlib.contextmanager
def _remote_logging_conn(client: Client):
"""
Pre-fetch the needed remote logging connection.

If a remote log handler is in use and the logging/remote_logging option is set, we fetch the
connection it needs now, directly from the API client, and store it in an env var so that the
logging hook can find it there when it later asks for the connection.

This is needed as the BaseHook.get_connection looks for SUPERVISOR_COMMS, but we are still in the
supervisor process when this is needed, so that doesn't exist yet.
"""
from airflow.sdk.log import load_remote_conn_id, load_remote_log_handler

if load_remote_log_handler() is None or not (conn_id := load_remote_conn_id()):
# Nothing to do
yield
return

# Since we need to use the API Client directly, we can't use Connection.get as that would try to use
# SUPERVISOR_COMMS

# TODO: Store in the SecretsCache if it's enabled - see #48858

def _get_conn() -> Connection | None:
backends = ensure_secrets_backend_loaded()
for secrets_backend in backends:
try:
conn = secrets_backend.get_connection(conn_id=conn_id)
if conn:
return conn
except Exception:
log.exception(
"Unable to retrieve connection from secrets backend (%s). "
"Checking subsequent secrets backend.",
type(secrets_backend).__name__,
)

conn = client.connections.get(conn_id)
if isinstance(conn, ConnectionResponse):
conn_result = ConnectionResult.from_conn_response(conn)
from airflow.sdk.definitions.connection import Connection

return Connection(**conn_result.model_dump(exclude={"type"}, by_alias=True))
return None

if conn := _get_conn():
key = f"AIRFLOW_CONN_{conn_id.upper()}"
old = os.getenv(key)

os.environ[key] = conn.get_uri()

try:
yield
finally:
if old is None:
del os.environ[key]
else:
os.environ[key] = old


def _configure_logging(log_path: str, client: Client) -> tuple[FilteringBoundLogger, BinaryIO | TextIO]:
# If we are told to write logs to a file, redirect the task logger to it. Make sure we append to the
# file though, otherwise when we resume we would lose the logs from the start->deferral segment if it
# lands on the same node as before.
from airflow.sdk.log import init_log_file, logging_processors

log_file_descriptor: BinaryIO | TextIO | None = None

log_file = init_log_file(log_path)

pretty_logs = False
if pretty_logs:
log_file_descriptor = log_file.open("a", buffering=1)
underlying_logger: WrappedLogger = structlog.WriteLogger(cast("TextIO", log_file_descriptor))
else:
log_file_descriptor = log_file.open("ab")
underlying_logger = structlog.BytesLogger(cast("BinaryIO", log_file_descriptor))

with _remote_logging_conn(client):
processors = logging_processors(enable_pretty_log=pretty_logs)[0]
logger = structlog.wrap_logger(underlying_logger, processors=processors, logger_name="task").bind()

return logger, log_file_descriptor


def supervise(
*,
ti: TaskInstance,
@@ -1690,22 +1779,7 @@ def supervise(
logger: FilteringBoundLogger | None = None
log_file_descriptor: BinaryIO | TextIO | None = None
if log_path:
# If we are told to write logs to a file, redirect the task logger to it. Make sure we append to the
# file though, otherwise when we resume we would lose the logs from the start->deferral segment if it
# lands on the same node as before.
from airflow.sdk.log import init_log_file, logging_processors

log_file = init_log_file(log_path)

pretty_logs = False
if pretty_logs:
log_file_descriptor = log_file.open("a", buffering=1)
underlying_logger: WrappedLogger = structlog.WriteLogger(cast("TextIO", log_file_descriptor))
else:
log_file_descriptor = log_file.open("ab")
underlying_logger = structlog.BytesLogger(cast("BinaryIO", log_file_descriptor))
processors = logging_processors(enable_pretty_log=pretty_logs)[0]
logger = structlog.wrap_logger(underlying_logger, processors=processors, logger_name="task").bind()
logger, log_file_descriptor = _configure_logging(log_path, client)

backends = ensure_secrets_backend_loaded()
log.info(
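
A minimal sketch of the env-var handoff that _remote_logging_conn relies on; the connection id and fields are made up for illustration:

import os

from airflow.sdk.definitions.connection import Connection

conn = Connection(conn_id="my_aws", conn_type="aws")
key = f"AIRFLOW_CONN_{conn.conn_id.upper()}"  # -> "AIRFLOW_CONN_MY_AWS"
os.environ[key] = conn.get_uri()
# The remote log hook can now resolve "my_aws" through the normal env-var lookup, even though
# SUPERVISOR_COMMS does not exist yet in the supervisor process.
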
10 changes: 10 additions & 0 deletions task-sdk/src/airflow/sdk/log.py
@@ -523,6 +523,16 @@ def load_remote_log_handler() -> RemoteLogIO | None:
return airflow.logging_config.REMOTE_TASK_LOG


def load_remote_conn_id() -> str | None:
import airflow.logging_config
from airflow.configuration import conf

if conn_id := conf.get("logging", "remote_log_conn_id", fallback=None):
return conn_id

return airflow.logging_config.DEFAULT_REMOTE_CONN_ID


def relative_path_from_logger(logger) -> Path | None:
if not logger:
return None
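
The resolution order, sketched with illustrative values: an explicit [logging] remote_log_conn_id always wins, and the hook-derived default only applies when that option is unset or empty.

from airflow.sdk.log import load_remote_conn_id

# [logging] remote_log_conn_id = "my_aws"               -> "my_aws"
# option empty, DEFAULT_REMOTE_CONN_ID = "aws_default"  -> "aws_default"
# option empty, no remote handler default               -> None
conn_id = load_remote_conn_id()
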
49 changes: 49 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -111,12 +111,15 @@
ActivitySubprocess,
InProcessSupervisorComms,
InProcessTestSupervisor,
_remote_logging_conn,
set_supervisor_comms,
supervise,
)
from airflow.sdk.execution_time.task_runner import run
from airflow.utils import timezone, timezone as tz

from tests_common.test_utils.config import conf_vars

if TYPE_CHECKING:
import kgb

@@ -2091,3 +2094,49 @@ def _handle_request(self, msg, log, req_id):
# Ensure we got back what we expect
assert isinstance(response, VariableResult)
assert response.value == "value"


@pytest.mark.parametrize(
("remote_logging", "remote_conn", "expected_env"),
(
pytest.param(True, "", "AIRFLOW_CONN_AWS_DEFAULT", id="no-conn-id"),
pytest.param(True, "aws_default", "AIRFLOW_CONN_AWS_DEFAULT", id="explicit-default"),
pytest.param(True, "my_aws", "AIRFLOW_CONN_MY_AWS", id="other"),
pytest.param(False, "", "", id="no-remote-logging"),
),
)
def test_remote_logging_conn(remote_logging, remote_conn, expected_env, monkeypatch):
# This doesn't strictly need the AWS provider, but it does need something that
# airflow.config_templates.airflow_local_settings.DEFAULT_LOGGING_CONFIG knows about
pytest.importorskip("airflow.providers.amazon", reason="'amazon' provider not installed")

# This test is a little bit overly specific to how the logging is currently configured :/
monkeypatch.delitem(sys.modules, "airflow.logging_config")
monkeypatch.delitem(sys.modules, "airflow.config_templates.airflow_local_settings", raising=False)

def handle_request(request: httpx.Request) -> httpx.Response:
return httpx.Response(
status_code=200,
json={
# Minimal enough to pass validation, we don't care what fields are in here for the tests
"conn_id": remote_conn,
"conn_type": "aws",
},
)

with conf_vars(
{
("logging", "remote_logging"): str(remote_logging),
("logging", "remote_base_log_folder"): "cloudwatch://arn:aws:logs:::log-group:test",
("logging", "remote_log_conn_id"): remote_conn,
}
):
env = os.environ.copy()
client = make_client(transport=httpx.MockTransport(handle_request))

with _remote_logging_conn(client):
new_keys = os.environ.keys() - env.keys()
if remote_logging:
assert new_keys == {expected_env}
else:
assert not new_keys