Skip to content

Commit 3dc597f

Browse files
authored
Fix memory leak in dag-processor (apache#50558)
closes apache#50097 closes apache#49887 Previously, each `DagFileProcessorProcess` created its own `InProcessExecutionAPI` client instance, leading to unnecessary thread creation and resource use. This commit ensures that a single `Client` backed by `InProcessExecutionAPI` is created and owned by `DagFileProcessorManager`, and passed into all DAG file processor subprocesses.
1 parent 7c29214 commit 3dc597f

File tree

4 files changed

+80
-24
lines changed

4 files changed

+80
-24
lines changed

airflow-core/src/airflow/dag_processing/manager.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from uuid6 import uuid7
4949

5050
import airflow.models
51+
from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI
5152
from airflow.configuration import conf
5253
from airflow.dag_processing.bundles.manager import DagBundlesManager
5354
from airflow.dag_processing.collection import update_dag_parsing_results_in_db
@@ -80,6 +81,7 @@
8081

8182
from airflow.callbacks.callback_requests import CallbackRequest
8283
from airflow.dag_processing.bundles.base import BaseDagBundle
84+
from airflow.sdk.api.client import Client
8385

8486

8587
class DagParsingStat(NamedTuple):
@@ -213,6 +215,9 @@ class DagFileProcessorManager(LoggingMixin):
213215
_force_refresh_bundles: set[str] = attrs.field(factory=set, init=False)
214216
"""List of bundles that need to be force refreshed in the next loop"""
215217

218+
_api_server: InProcessExecutionAPI = attrs.field(init=False, factory=InProcessExecutionAPI)
219+
"""API server to interact with Metadata DB"""
220+
216221
def register_exit_signals(self):
217222
"""Register signals that stop child processes."""
218223
signal.signal(signal.SIGINT, self._exit_gracefully)
@@ -867,6 +872,15 @@ def _get_logger_for_dag_file(self, dag_file: DagFileInfo):
867872
underlying_logger, processors=processors, logger_name="processor"
868873
).bind(), logger_filehandle
869874

875+
@functools.cached_property
876+
def client(self) -> Client:
877+
from airflow.sdk.api.client import Client
878+
879+
client = Client(base_url=None, token="", dry_run=True, transport=self._api_server.transport)
880+
# Mypy is wrong -- the setter accepts a string on the property setter! `URLType = URL | str`
881+
client.base_url = "http://in-process.invalid./" # type: ignore[assignment]
882+
return client
883+
870884
def _create_process(self, dag_file: DagFileInfo) -> DagFileProcessorProcess:
871885
id = uuid7()
872886

@@ -881,6 +895,7 @@ def _create_process(self, dag_file: DagFileInfo) -> DagFileProcessorProcess:
881895
selector=self.selector,
882896
logger=logger,
883897
logger_filehandle=logger_filehandle,
898+
client=self.client,
884899
)
885900

886901
def _start_new_processes(self):

airflow-core/src/airflow/dag_processing/processor.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19-
import functools
2019
import os
2120
import sys
2221
import traceback
@@ -239,6 +238,9 @@ class DagFileProcessorProcess(WatchedSubprocess):
239238
parsing_result: DagFileParsingResult | None = None
240239
decoder: ClassVar[TypeAdapter[ToManager]] = TypeAdapter[ToManager](ToManager)
241240

241+
client: Client
242+
"""The HTTP client to use for communication with the API server."""
243+
242244
@classmethod
243245
def start( # type: ignore[override]
244246
cls,
@@ -247,9 +249,10 @@ def start( # type: ignore[override]
247249
bundle_path: Path,
248250
callbacks: list[CallbackRequest],
249251
target: Callable[[], None] = _parse_file_entrypoint,
252+
client: Client,
250253
**kwargs,
251254
) -> Self:
252-
proc: Self = super().start(target=target, **kwargs)
255+
proc: Self = super().start(target=target, client=client, **kwargs)
253256
proc._on_child_started(callbacks, path, bundle_path)
254257
return proc
255258

@@ -267,15 +270,6 @@ def _on_child_started(
267270
)
268271
self.send_msg(msg)
269272

270-
@functools.cached_property
271-
def client(self) -> Client:
272-
from airflow.sdk.api.client import Client
273-
274-
client = Client(base_url=None, token="", dry_run=True, transport=in_process_api_server().transport)
275-
# Mypy is wrong -- the setter accepts a string on the property setter! `URLType = URL | str`
276-
client.base_url = "http://in-process.invalid./" # type: ignore[assignment]
277-
return client
278-
279273
def _handle_request(self, msg: ToManager, log: FilteringBoundLogger) -> None: # type: ignore[override]
280274
from airflow.sdk.api.datamodels._generated import ConnectionResponse, VariableResponse
281275

airflow-core/tests/unit/dag_processing/test_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def mock_processor(self, start_time: float | None = None) -> tuple[DagFileProces
147147
stdin=write_end,
148148
requests_fd=123,
149149
logger_filehandle=logger_filehandle,
150+
client=MagicMock(),
150151
)
151152
if start_time:
152153
ret.start_time = start_time
@@ -899,6 +900,7 @@ def test_callback_queue(self, mock_get_logger, configure_testing_dag_bundle):
899900
selector=mock.ANY,
900901
logger=mock_logger,
901902
logger_filehandle=mock_filehandle,
903+
client=mock.ANY,
902904
),
903905
mock.call(
904906
id=mock.ANY,
@@ -908,6 +910,7 @@ def test_callback_queue(self, mock_get_logger, configure_testing_dag_bundle):
908910
selector=mock.ANY,
909911
logger=mock_logger,
910912
logger_filehandle=mock_filehandle,
913+
client=mock.ANY,
911914
),
912915
]
913916
# And removed from the queue

airflow-core/tests/unit/dag_processing/test_processor.py

Lines changed: 57 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import structlog
3030
from pydantic import TypeAdapter
3131

32+
from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI
3233
from airflow.callbacks.callback_requests import CallbackRequest, DagCallbackRequest, TaskCallbackRequest
3334
from airflow.configuration import conf
3435
from airflow.dag_processing.processor import (
@@ -40,6 +41,7 @@
4041
from airflow.models import DagBag, TaskInstance
4142
from airflow.models.baseoperator import BaseOperator
4243
from airflow.models.serialized_dag import SerializedDagModel
44+
from airflow.sdk.api.client import Client
4345
from airflow.sdk.execution_time.task_runner import CommsDecoder
4446
from airflow.utils import timezone
4547
from airflow.utils.session import create_session
@@ -67,6 +69,15 @@ def disable_load_example():
6769
yield
6870

6971

72+
@pytest.fixture
73+
def inprocess_client():
74+
"""Provides an in-process Client backed by a single API server."""
75+
api = InProcessExecutionAPI()
76+
client = Client(base_url=None, token="", dry_run=True, transport=api.transport)
77+
client.base_url = "http://in-process.invalid/" # type: ignore[assignment]
78+
return client
79+
80+
7081
@pytest.mark.usefixtures("disable_load_example")
7182
class TestDagFileProcessor:
7283
def _process_file(
@@ -130,7 +141,7 @@ def fake_collect_dags(dagbag: DagBag, *args, **kwargs):
130141
assert "a.py" in resp.import_errors
131142

132143
def test_top_level_variable_access(
133-
self, spy_agency: SpyAgency, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch
144+
self, spy_agency: SpyAgency, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch, inprocess_client
134145
):
135146
logger_filehandle = MagicMock()
136147

@@ -144,7 +155,12 @@ def dag_in_a_fn():
144155

145156
monkeypatch.setenv("AIRFLOW_VAR_MYVAR", "abc")
146157
proc = DagFileProcessorProcess.start(
147-
id=1, path=path, bundle_path=tmp_path, callbacks=[], logger_filehandle=logger_filehandle
158+
id=1,
159+
path=path,
160+
bundle_path=tmp_path,
161+
callbacks=[],
162+
logger_filehandle=logger_filehandle,
163+
client=inprocess_client,
148164
)
149165

150166
while not proc.is_ready:
@@ -156,7 +172,7 @@ def dag_in_a_fn():
156172
assert result.serialized_dags[0].dag_id == "test_abc"
157173

158174
def test_top_level_variable_access_not_found(
159-
self, spy_agency: SpyAgency, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch
175+
self, spy_agency: SpyAgency, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch, inprocess_client
160176
):
161177
logger_filehandle = MagicMock()
162178

@@ -168,7 +184,12 @@ def dag_in_a_fn():
168184

169185
path = write_dag_in_a_fn_to_file(dag_in_a_fn, tmp_path)
170186
proc = DagFileProcessorProcess.start(
171-
id=1, path=path, bundle_path=tmp_path, callbacks=[], logger_filehandle=logger_filehandle
187+
id=1,
188+
path=path,
189+
bundle_path=tmp_path,
190+
callbacks=[],
191+
logger_filehandle=logger_filehandle,
192+
client=inprocess_client,
172193
)
173194

174195
while not proc.is_ready:
@@ -180,7 +201,7 @@ def dag_in_a_fn():
180201
if result.import_errors:
181202
assert "VARIABLE_NOT_FOUND" in next(iter(result.import_errors.values()))
182203

183-
def test_top_level_variable_set(self, tmp_path: pathlib.Path):
204+
def test_top_level_variable_set(self, tmp_path: pathlib.Path, inprocess_client):
184205
from airflow.models.variable import Variable as VariableORM
185206

186207
logger_filehandle = MagicMock()
@@ -194,7 +215,12 @@ def dag_in_a_fn():
194215

195216
path = write_dag_in_a_fn_to_file(dag_in_a_fn, tmp_path)
196217
proc = DagFileProcessorProcess.start(
197-
id=1, path=path, bundle_path=tmp_path, callbacks=[], logger_filehandle=logger_filehandle
218+
id=1,
219+
path=path,
220+
bundle_path=tmp_path,
221+
callbacks=[],
222+
logger_filehandle=logger_filehandle,
223+
client=inprocess_client,
198224
)
199225

200226
while not proc.is_ready:
@@ -210,7 +236,7 @@ def dag_in_a_fn():
210236
assert len(all_vars) == 1
211237
assert all_vars[0].key == "mykey"
212238

213-
def test_top_level_variable_delete(self, tmp_path: pathlib.Path):
239+
def test_top_level_variable_delete(self, tmp_path: pathlib.Path, inprocess_client):
214240
from airflow.models.variable import Variable as VariableORM
215241

216242
logger_filehandle = MagicMock()
@@ -230,7 +256,12 @@ def dag_in_a_fn():
230256

231257
path = write_dag_in_a_fn_to_file(dag_in_a_fn, tmp_path)
232258
proc = DagFileProcessorProcess.start(
233-
id=1, path=path, bundle_path=tmp_path, callbacks=[], logger_filehandle=logger_filehandle
259+
id=1,
260+
path=path,
261+
bundle_path=tmp_path,
262+
callbacks=[],
263+
logger_filehandle=logger_filehandle,
264+
client=inprocess_client,
234265
)
235266

236267
while not proc.is_ready:
@@ -245,7 +276,9 @@ def dag_in_a_fn():
245276
all_vars = session.query(VariableORM).all()
246277
assert len(all_vars) == 0
247278

248-
def test_top_level_connection_access(self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch):
279+
def test_top_level_connection_access(
280+
self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch, inprocess_client
281+
):
249282
logger_filehandle = MagicMock()
250283

251284
def dag_in_a_fn():
@@ -259,7 +292,12 @@ def dag_in_a_fn():
259292

260293
monkeypatch.setenv("AIRFLOW_CONN_MY_CONN", '{"conn_type": "aws"}')
261294
proc = DagFileProcessorProcess.start(
262-
id=1, path=path, bundle_path=tmp_path, callbacks=[], logger_filehandle=logger_filehandle
295+
id=1,
296+
path=path,
297+
bundle_path=tmp_path,
298+
callbacks=[],
299+
logger_filehandle=logger_filehandle,
300+
client=inprocess_client,
263301
)
264302

265303
while not proc.is_ready:
@@ -270,7 +308,7 @@ def dag_in_a_fn():
270308
assert result.import_errors == {}
271309
assert result.serialized_dags[0].dag_id == "test_my_conn"
272310

273-
def test_top_level_connection_access_not_found(self, tmp_path: pathlib.Path):
311+
def test_top_level_connection_access_not_found(self, tmp_path: pathlib.Path, inprocess_client):
274312
logger_filehandle = MagicMock()
275313

276314
def dag_in_a_fn():
@@ -282,7 +320,12 @@ def dag_in_a_fn():
282320

283321
path = write_dag_in_a_fn_to_file(dag_in_a_fn, tmp_path)
284322
proc = DagFileProcessorProcess.start(
285-
id=1, path=path, bundle_path=tmp_path, callbacks=[], logger_filehandle=logger_filehandle
323+
id=1,
324+
path=path,
325+
bundle_path=tmp_path,
326+
callbacks=[],
327+
logger_filehandle=logger_filehandle,
328+
client=inprocess_client,
286329
)
287330

288331
while not proc.is_ready:
@@ -294,7 +337,7 @@ def dag_in_a_fn():
294337
if result.import_errors:
295338
assert "CONNECTION_NOT_FOUND" in next(iter(result.import_errors.values()))
296339

297-
def test_import_module_in_bundle_root(self, tmp_path: pathlib.Path):
340+
def test_import_module_in_bundle_root(self, tmp_path: pathlib.Path, inprocess_client):
298341
tmp_path.joinpath("util.py").write_text("NAME = 'dag_name'")
299342

300343
dag1_path = tmp_path.joinpath("dag1.py")
@@ -314,6 +357,7 @@ def test_import_module_in_bundle_root(self, tmp_path: pathlib.Path):
314357
bundle_path=tmp_path,
315358
callbacks=[],
316359
logger_filehandle=MagicMock(),
360+
client=inprocess_client,
317361
)
318362
while not proc.is_ready:
319363
proc._service_subprocess(0.1)

0 commit comments

Comments
 (0)