Commit 9459eaa

ashb authored and committed
[v3-0-test] Allow Remote logging providers to load connections from the API Server (#53719)
Often remote logging is done using automatic instance profiles, but not always. If you tried to configure a logger via a connection defined in the metadata DB, it would not have worked: depending on the hook's behaviour, it either caused the supervise job to fail early, or to behave as if the connection didn't exist.

Unfortunately, the default connection ID that the various hooks use is not easily discoverable, at least not from the outside (we can't look at `remote.hook`, as for most log providers that would try to load the connection, failing in exactly the way we are trying to fix), so I updated the log config module to keep track of the default conn ID for the modern log providers.

Once we have the connection ID (or at least a good idea that we've got the right one), we pre-emptively check the secrets backends for it; if it is not found there, we load it from the API server. Either way, if we find a connection we put it in the env variable so that it is available. The reason we use this approach is that we are running in the supervisor process itself, so SUPERVISOR_COMMS is not and cannot be set yet.

(cherry picked from commit e4fb686)

Co-authored-by: Ash Berlin-Taylor <[email protected]>
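For illustration, the hand-off described above relies on Airflow's standard environment-variable connection lookup: a connection URI stored under AIRFLOW_CONN_<CONN_ID> is visible to any hook. A minimal sketch (the conn id and URI here are illustrative, not from this commit):

    # Sketch: stash a connection URI where hooks can find it without SUPERVISOR_COMMS
    import os

    conn_id = "my_aws"  # illustrative conn id
    os.environ[f"AIRFLOW_CONN_{conn_id.upper()}"] = "aws://"  # URI fetched from a secrets backend or the API server
    # From here on, hooks resolving "my_aws" can read it from the environment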
1 parent ba1968f commit 9459eaa

File tree

5 files changed: +184 -20 lines changed

airflow-core/src/airflow/config_templates/airflow_local_settings.py

Lines changed: 29 additions & 0 deletions
@@ -128,6 +128,27 @@
 
 REMOTE_LOGGING: bool = conf.getboolean("logging", "remote_logging")
 REMOTE_TASK_LOG: RemoteLogIO | None = None
+DEFAULT_REMOTE_CONN_ID: str | None = None
+
+
+def _default_conn_name_from(mod_path, hook_name):
+    # Try to set the default conn name from a hook, but don't error if something goes wrong at runtime
+    from importlib import import_module
+
+    global DEFAULT_REMOTE_CONN_ID
+
+    try:
+        mod = import_module(mod_path)
+
+        hook = getattr(mod, hook_name)
+
+        DEFAULT_REMOTE_CONN_ID = getattr(hook, "default_conn_name")
+    except Exception:
+        # Lets error in tests though!
+        if "PYTEST_CURRENT_TEST" in os.environ:
+            raise
+    return None
+
 
 if REMOTE_LOGGING:
     ELASTICSEARCH_HOST: str | None = conf.get("elasticsearch", "HOST")
@@ -151,6 +172,7 @@
     if remote_base_log_folder.startswith("s3://"):
         from airflow.providers.amazon.aws.log.s3_task_handler import S3RemoteLogIO
 
+        _default_conn_name_from("airflow.providers.amazon.aws.hooks.s3", "S3Hook")
         REMOTE_TASK_LOG = S3RemoteLogIO(
             **(
                 {
@@ -166,6 +188,7 @@
     elif remote_base_log_folder.startswith("cloudwatch://"):
         from airflow.providers.amazon.aws.log.cloudwatch_task_handler import CloudWatchRemoteLogIO
 
+        _default_conn_name_from("airflow.providers.amazon.aws.hooks.logs", "AwsLogsHook")
         url_parts = urlsplit(remote_base_log_folder)
         REMOTE_TASK_LOG = CloudWatchRemoteLogIO(
             **(
@@ -182,6 +205,7 @@
     elif remote_base_log_folder.startswith("gs://"):
         from airflow.providers.google.cloud.log.gcs_task_handler import GCSRemoteLogIO
 
+        _default_conn_name_from("airflow.providers.google.cloud.hooks.gcs", "GCSHook")
         key_path = conf.get_mandatory_value("logging", "google_key_path", fallback=None)
 
         REMOTE_TASK_LOG = GCSRemoteLogIO(
@@ -199,6 +223,7 @@
     elif remote_base_log_folder.startswith("wasb"):
         from airflow.providers.microsoft.azure.log.wasb_task_handler import WasbRemoteLogIO
 
+        _default_conn_name_from("airflow.providers.microsoft.azure.hooks.wasb", "WasbHook")
         wasb_log_container = conf.get_mandatory_value(
             "azure_remote_logging", "remote_wasb_log_container", fallback="airflow-logs"
         )
@@ -232,6 +257,8 @@
     elif remote_base_log_folder.startswith("oss://"):
         from airflow.providers.alibaba.cloud.log.oss_task_handler import OSSRemoteLogIO
 
+        _default_conn_name_from("airflow.providers.alibaba.cloud.hooks.oss", "OSSHook")
+
         REMOTE_TASK_LOG = OSSRemoteLogIO(
             **(
                 {
@@ -246,6 +273,8 @@
     elif remote_base_log_folder.startswith("hdfs://"):
        from airflow.providers.apache.hdfs.log.hdfs_task_handler import HdfsRemoteLogIO
 
+        _default_conn_name_from("airflow.providers.apache.hdfs.hooks.webhdfs", "WebHDFSHook")
+
         REMOTE_TASK_LOG = HdfsRemoteLogIO(
             **(
                 {
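For reference, each of these hooks exposes its default connection ID as the class attribute default_conn_name, which is all the helper reads. A quick sketch of the lookup it performs, assuming the amazon provider is installed:

    # Sketch of what _default_conn_name_from("airflow.providers.amazon.aws.hooks.s3", "S3Hook") resolves
    from importlib import import_module

    mod = import_module("airflow.providers.amazon.aws.hooks.s3")
    hook = getattr(mod, "S3Hook")
    print(hook.default_conn_name)  # "aws_default"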

airflow-core/src/airflow/logging_config.py

Lines changed: 6 additions & 4 deletions
@@ -33,6 +33,7 @@
 
 
 REMOTE_TASK_LOG: RemoteLogIO | None
+DEFAULT_REMOTE_CONN_ID: str | None = None
 
 
 def __getattr__(name: str):
@@ -44,7 +45,7 @@ def __getattr__(name: str):
 
 def load_logging_config() -> tuple[dict[str, Any], str]:
     """Configure & Validate Airflow Logging."""
-    global REMOTE_TASK_LOG
+    global REMOTE_TASK_LOG, DEFAULT_REMOTE_CONN_ID
     fallback = "airflow.config_templates.airflow_local_settings.DEFAULT_LOGGING_CONFIG"
     logging_class_path = conf.get("logging", "logging_config_class", fallback=fallback)
 
@@ -70,10 +71,11 @@ def load_logging_config() -> tuple[dict[str, Any], str]:
                 f"to: {type(err).__name__}:{err}"
             )
     else:
-        mod = logging_class_path.rsplit(".", 1)[0]
+        modpath = logging_class_path.rsplit(".", 1)[0]
         try:
-            remote_task_log = import_string(f"{mod}.REMOTE_TASK_LOG")
-            REMOTE_TASK_LOG = remote_task_log
+            mod = import_string(modpath)
+            REMOTE_TASK_LOG = getattr(mod, "REMOTE_TASK_LOG")
+            DEFAULT_REMOTE_CONN_ID = getattr(mod, "DEFAULT_REMOTE_CONN_ID", None)
         except Exception as err:
             log.info("Remote task logs will not be available due to an error: %s", err)
 
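Because load_logging_config() now reads these names via getattr() on the configured module, a user-supplied logging_config_class module can export them too. A hedged sketch of such a module (the module and conn id names are hypothetical):

    # my_log_config.py -- hypothetical custom logging config module
    from copy import deepcopy

    from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG

    LOGGING_CONFIG = deepcopy(DEFAULT_LOGGING_CONFIG)

    # Both attributes are read with getattr() in load_logging_config() above
    REMOTE_TASK_LOG = None  # or a RemoteLogIO instance
    DEFAULT_REMOTE_CONN_ID = "my_logging_conn"  # hypothetical conn id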

task-sdk/src/airflow/sdk/execution_time/supervisor.py

Lines changed: 90 additions & 16 deletions
@@ -20,6 +20,7 @@
 from __future__ import annotations
 
 import atexit
+import contextlib
 import io
 import logging
 import os
@@ -127,6 +128,7 @@
     from structlog.typing import FilteringBoundLogger, WrappedLogger
 
     from airflow.executors.workloads import BundleInfo
+    from airflow.sdk.definitions.connection import Connection
     from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
     from airflow.secrets import BaseSecretsBackend
     from airflow.typing_compat import Self
@@ -1615,6 +1617,93 @@ def ensure_secrets_backend_loaded() -> list[BaseSecretsBackend]:
     return backends
 
 
+@contextlib.contextmanager
+def _remote_logging_conn(client: Client):
+    """
+    Pre-fetch the needed remote logging connection.
+
+    If a remote logger is in use, and has the logging/remote_logging option set, we try to fetch the
+    connection it needs, now, directly from the API client, and store it in an env var, so that when
+    the logging hook tries to get the connection it can find it easily from the env vars.
+
+    This is needed as the BaseHook.get_connection looks for SUPERVISOR_COMMS, but we are still in the
+    supervisor process when this is needed, so that doesn't exist yet.
+    """
+    from airflow.sdk.log import load_remote_conn_id, load_remote_log_handler
+
+    if load_remote_log_handler() is None or not (conn_id := load_remote_conn_id()):
+        # Nothing to do
+        yield
+        return
+
+    # Since we need to use the API Client directly, we can't use Connection.get as that would try to
+    # use SUPERVISOR_COMMS
+
+    # TODO: Store in the SecretsCache if its enabled - see #48858
+
+    def _get_conn() -> Connection | None:
+        backends = ensure_secrets_backend_loaded()
+        for secrets_backend in backends:
+            try:
+                conn = secrets_backend.get_connection(conn_id=conn_id)
+                if conn:
+                    return conn
+            except Exception:
+                log.exception(
+                    "Unable to retrieve connection from secrets backend (%s). "
+                    "Checking subsequent secrets backend.",
+                    type(secrets_backend).__name__,
+                )
+
+        conn = client.connections.get(conn_id)
+        if isinstance(conn, ConnectionResponse):
+            conn_result = ConnectionResult.from_conn_response(conn)
+            from airflow.sdk.definitions.connection import Connection
+
+            return Connection(**conn_result.model_dump(exclude={"type"}, by_alias=True))
+        return None
+
+    if conn := _get_conn():
+        key = f"AIRFLOW_CONN_{conn_id.upper()}"
+        old = os.getenv(key)
+
+        os.environ[key] = conn.get_uri()
+
+        try:
+            yield
+        finally:
+            if old is None:
+                del os.environ[key]
+            else:
+                os.environ[key] = old
+
+
+def _configure_logging(log_path: str, client: Client) -> tuple[FilteringBoundLogger, BinaryIO | TextIO]:
+    # If we are told to write logs to a file, redirect the task logger to it. Make sure we append to the
+    # file though, otherwise when we resume we would lose the logs from the start->deferral segment if it
+    # lands on the same node as before.
+    from airflow.sdk.log import init_log_file, logging_processors
+
+    log_file_descriptor: BinaryIO | TextIO | None = None
+
+    log_file = init_log_file(log_path)
+
+    pretty_logs = False
+    if pretty_logs:
+        log_file_descriptor = log_file.open("a", buffering=1)
+        underlying_logger: WrappedLogger = structlog.WriteLogger(cast("TextIO", log_file_descriptor))
+    else:
+        log_file_descriptor = log_file.open("ab")
+        underlying_logger = structlog.BytesLogger(cast("BinaryIO", log_file_descriptor))
+
+    with _remote_logging_conn(client):
+        processors = logging_processors(enable_pretty_log=pretty_logs)[0]
+        logger = structlog.wrap_logger(underlying_logger, processors=processors, logger_name="task").bind()
+
+    return logger, log_file_descriptor
+
+
 def supervise(
     *,
     ti: TaskInstance,
@@ -1690,22 +1779,7 @@ def supervise(
     logger: FilteringBoundLogger | None = None
     log_file_descriptor: BinaryIO | TextIO | None = None
     if log_path:
-        # If we are told to write logs to a file, redirect the task logger to it. Make sure we append to the
-        # file though, otherwise when we resume we would lose the logs from the start->deferral segment if it
-        # lands on the same node as before.
-        from airflow.sdk.log import init_log_file, logging_processors
-
-        log_file = init_log_file(log_path)
-
-        pretty_logs = False
-        if pretty_logs:
-            log_file_descriptor = log_file.open("a", buffering=1)
-            underlying_logger: WrappedLogger = structlog.WriteLogger(cast("TextIO", log_file_descriptor))
-        else:
-            log_file_descriptor = log_file.open("ab")
-            underlying_logger = structlog.BytesLogger(cast("BinaryIO", log_file_descriptor))
-        processors = logging_processors(enable_pretty_log=pretty_logs)[0]
-        logger = structlog.wrap_logger(underlying_logger, processors=processors, logger_name="task").bind()
+        logger, log_file_descriptor = _configure_logging(log_path, client)
 
     backends = ensure_secrets_backend_loaded()
     log.info(
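A usage sketch of the new context manager, mirroring _configure_logging above (assumes a task-sdk API Client and that remote logging is configured):

    from airflow.sdk.execution_time.supervisor import _remote_logging_conn
    from airflow.sdk.log import logging_processors

    with _remote_logging_conn(client):
        # While inside the block, AIRFLOW_CONN_<CONN_ID> is exported (if a
        # connection was found via a secrets backend or the API server), so the
        # remote log handler built here can resolve its connection; the previous
        # env value, if any, is restored on exit
        processors = logging_processors(enable_pretty_log=False)[0]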

task-sdk/src/airflow/sdk/log.py

Lines changed: 10 additions & 0 deletions
@@ -523,6 +523,16 @@ def load_remote_log_handler() -> RemoteLogIO | None:
     return airflow.logging_config.REMOTE_TASK_LOG
 
 
+def load_remote_conn_id() -> str | None:
+    import airflow.logging_config
+    from airflow.configuration import conf
+
+    if conn_id := conf.get("logging", "remote_log_conn_id", fallback=None):
+        return conn_id
+
+    return airflow.logging_config.DEFAULT_REMOTE_CONN_ID
+
+
 def relative_path_from_logger(logger) -> Path | None:
     if not logger:
         return None
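The resulting precedence: an explicit [logging] remote_log_conn_id always wins; otherwise the hook default recorded by _default_conn_name_from is returned. For example, with CloudWatch remote logging configured and no explicit conn id set (the value shown assumes the amazon provider's usual default):

    from airflow.sdk.log import load_remote_conn_id

    print(load_remote_conn_id())  # "aws_default", via AwsLogsHook.default_conn_name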

task-sdk/tests/task_sdk/execution_time/test_supervisor.py

Lines changed: 49 additions & 0 deletions
@@ -110,11 +110,14 @@
     ActivitySubprocess,
     InProcessSupervisorComms,
     InProcessTestSupervisor,
+    _remote_logging_conn,
     set_supervisor_comms,
     supervise,
 )
 from airflow.utils import timezone, timezone as tz
 
+from tests_common.test_utils.config import conf_vars
+
 if TYPE_CHECKING:
     import kgb
 
@@ -2010,3 +2013,49 @@ def _handle_request(self, msg, log, req_id):
     # Ensure we got back what we expect
     assert isinstance(response, VariableResult)
     assert response.value == "value"
+
+
+@pytest.mark.parametrize(
+    ("remote_logging", "remote_conn", "expected_env"),
+    (
+        pytest.param(True, "", "AIRFLOW_CONN_AWS_DEFAULT", id="no-conn-id"),
+        pytest.param(True, "aws_default", "AIRFLOW_CONN_AWS_DEFAULT", id="explicit-default"),
+        pytest.param(True, "my_aws", "AIRFLOW_CONN_MY_AWS", id="other"),
+        pytest.param(False, "", "", id="no-remote-logging"),
+    ),
+)
+def test_remote_logging_conn(remote_logging, remote_conn, expected_env, monkeypatch):
+    # This doesn't strictly need the AWS provider, but it does need something that
+    # airflow.config_templates.airflow_local_settings.DEFAULT_LOGGING_CONFIG knows about
+    pytest.importorskip("airflow.providers.amazon", reason="'amazon' provider not installed")
+
+    # This test is a little bit overly specific to how the logging is currently configured :/
+    monkeypatch.delitem(sys.modules, "airflow.logging_config")
+    monkeypatch.delitem(sys.modules, "airflow.config_templates.airflow_local_settings", raising=False)
+
+    def handle_request(request: httpx.Request) -> httpx.Response:
+        return httpx.Response(
+            status_code=200,
+            json={
+                # Minimal enough to pass validation, we don't care what fields are in here for the tests
+                "conn_id": remote_conn,
+                "conn_type": "aws",
+            },
+        )
+
+    with conf_vars(
+        {
+            ("logging", "remote_logging"): str(remote_logging),
+            ("logging", "remote_base_log_folder"): "cloudwatch://arn:aws:logs:::log-group:test",
+            ("logging", "remote_log_conn_id"): remote_conn,
+        }
+    ):
+        env = os.environ.copy()
+        client = make_client(transport=httpx.MockTransport(handle_request))
+
+        with _remote_logging_conn(client):
+            new_keys = os.environ.keys() - env.keys()
+            if remote_logging:
+                assert new_keys == {expected_env}
+            else:
+                assert not new_keys
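The new cases can be run in isolation with, e.g.:

    pytest task-sdk/tests/task_sdk/execution_time/test_supervisor.py -k test_remote_logging_conn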
