29
29
import structlog
30
30
from pydantic import TypeAdapter
31
31
32
+ from airflow .api_fastapi .execution_api .app import InProcessExecutionAPI
32
33
from airflow .callbacks .callback_requests import CallbackRequest , DagCallbackRequest , TaskCallbackRequest
33
34
from airflow .configuration import conf
34
35
from airflow .dag_processing .processor import (
40
41
from airflow .models import DagBag , TaskInstance
41
42
from airflow .models .baseoperator import BaseOperator
42
43
from airflow .models .serialized_dag import SerializedDagModel
44
+ from airflow .sdk .api .client import Client
43
45
from airflow .sdk .execution_time .task_runner import CommsDecoder
44
46
from airflow .utils import timezone
45
47
from airflow .utils .session import create_session
@@ -67,6 +69,15 @@ def disable_load_example():
67
69
yield
68
70
69
71
72
+ @pytest .fixture
73
+ def inprocess_client ():
74
+ """Provides an in-process Client backed by a single API server."""
75
+ api = InProcessExecutionAPI ()
76
+ client = Client (base_url = None , token = "" , dry_run = True , transport = api .transport )
77
+ client .base_url = "http://in-process.invalid/" # type: ignore[assignment]
78
+ return client
79
+
80
+
70
81
@pytest .mark .usefixtures ("disable_load_example" )
71
82
class TestDagFileProcessor :
72
83
def _process_file (
@@ -130,7 +141,7 @@ def fake_collect_dags(dagbag: DagBag, *args, **kwargs):
130
141
assert "a.py" in resp .import_errors
131
142
132
143
def test_top_level_variable_access (
133
- self , spy_agency : SpyAgency , tmp_path : pathlib .Path , monkeypatch : pytest .MonkeyPatch
144
+ self , spy_agency : SpyAgency , tmp_path : pathlib .Path , monkeypatch : pytest .MonkeyPatch , inprocess_client
134
145
):
135
146
logger_filehandle = MagicMock ()
136
147
@@ -144,7 +155,12 @@ def dag_in_a_fn():
144
155
145
156
monkeypatch .setenv ("AIRFLOW_VAR_MYVAR" , "abc" )
146
157
proc = DagFileProcessorProcess .start (
147
- id = 1 , path = path , bundle_path = tmp_path , callbacks = [], logger_filehandle = logger_filehandle
158
+ id = 1 ,
159
+ path = path ,
160
+ bundle_path = tmp_path ,
161
+ callbacks = [],
162
+ logger_filehandle = logger_filehandle ,
163
+ client = inprocess_client ,
148
164
)
149
165
150
166
while not proc .is_ready :
@@ -156,7 +172,7 @@ def dag_in_a_fn():
156
172
assert result .serialized_dags [0 ].dag_id == "test_abc"
157
173
158
174
def test_top_level_variable_access_not_found (
159
- self , spy_agency : SpyAgency , tmp_path : pathlib .Path , monkeypatch : pytest .MonkeyPatch
175
+ self , spy_agency : SpyAgency , tmp_path : pathlib .Path , monkeypatch : pytest .MonkeyPatch , inprocess_client
160
176
):
161
177
logger_filehandle = MagicMock ()
162
178
@@ -168,7 +184,12 @@ def dag_in_a_fn():
168
184
169
185
path = write_dag_in_a_fn_to_file (dag_in_a_fn , tmp_path )
170
186
proc = DagFileProcessorProcess .start (
171
- id = 1 , path = path , bundle_path = tmp_path , callbacks = [], logger_filehandle = logger_filehandle
187
+ id = 1 ,
188
+ path = path ,
189
+ bundle_path = tmp_path ,
190
+ callbacks = [],
191
+ logger_filehandle = logger_filehandle ,
192
+ client = inprocess_client ,
172
193
)
173
194
174
195
while not proc .is_ready :
@@ -180,7 +201,7 @@ def dag_in_a_fn():
180
201
if result .import_errors :
181
202
assert "VARIABLE_NOT_FOUND" in next (iter (result .import_errors .values ()))
182
203
183
- def test_top_level_variable_set (self , tmp_path : pathlib .Path ):
204
+ def test_top_level_variable_set (self , tmp_path : pathlib .Path , inprocess_client ):
184
205
from airflow .models .variable import Variable as VariableORM
185
206
186
207
logger_filehandle = MagicMock ()
@@ -194,7 +215,12 @@ def dag_in_a_fn():
194
215
195
216
path = write_dag_in_a_fn_to_file (dag_in_a_fn , tmp_path )
196
217
proc = DagFileProcessorProcess .start (
197
- id = 1 , path = path , bundle_path = tmp_path , callbacks = [], logger_filehandle = logger_filehandle
218
+ id = 1 ,
219
+ path = path ,
220
+ bundle_path = tmp_path ,
221
+ callbacks = [],
222
+ logger_filehandle = logger_filehandle ,
223
+ client = inprocess_client ,
198
224
)
199
225
200
226
while not proc .is_ready :
@@ -210,7 +236,7 @@ def dag_in_a_fn():
210
236
assert len (all_vars ) == 1
211
237
assert all_vars [0 ].key == "mykey"
212
238
213
- def test_top_level_variable_delete (self , tmp_path : pathlib .Path ):
239
+ def test_top_level_variable_delete (self , tmp_path : pathlib .Path , inprocess_client ):
214
240
from airflow .models .variable import Variable as VariableORM
215
241
216
242
logger_filehandle = MagicMock ()
@@ -230,7 +256,12 @@ def dag_in_a_fn():
230
256
231
257
path = write_dag_in_a_fn_to_file (dag_in_a_fn , tmp_path )
232
258
proc = DagFileProcessorProcess .start (
233
- id = 1 , path = path , bundle_path = tmp_path , callbacks = [], logger_filehandle = logger_filehandle
259
+ id = 1 ,
260
+ path = path ,
261
+ bundle_path = tmp_path ,
262
+ callbacks = [],
263
+ logger_filehandle = logger_filehandle ,
264
+ client = inprocess_client ,
234
265
)
235
266
236
267
while not proc .is_ready :
@@ -245,7 +276,9 @@ def dag_in_a_fn():
245
276
all_vars = session .query (VariableORM ).all ()
246
277
assert len (all_vars ) == 0
247
278
248
- def test_top_level_connection_access (self , tmp_path : pathlib .Path , monkeypatch : pytest .MonkeyPatch ):
279
+ def test_top_level_connection_access (
280
+ self , tmp_path : pathlib .Path , monkeypatch : pytest .MonkeyPatch , inprocess_client
281
+ ):
249
282
logger_filehandle = MagicMock ()
250
283
251
284
def dag_in_a_fn ():
@@ -259,7 +292,12 @@ def dag_in_a_fn():
259
292
260
293
monkeypatch .setenv ("AIRFLOW_CONN_MY_CONN" , '{"conn_type": "aws"}' )
261
294
proc = DagFileProcessorProcess .start (
262
- id = 1 , path = path , bundle_path = tmp_path , callbacks = [], logger_filehandle = logger_filehandle
295
+ id = 1 ,
296
+ path = path ,
297
+ bundle_path = tmp_path ,
298
+ callbacks = [],
299
+ logger_filehandle = logger_filehandle ,
300
+ client = inprocess_client ,
263
301
)
264
302
265
303
while not proc .is_ready :
@@ -270,7 +308,7 @@ def dag_in_a_fn():
270
308
assert result .import_errors == {}
271
309
assert result .serialized_dags [0 ].dag_id == "test_my_conn"
272
310
273
- def test_top_level_connection_access_not_found (self , tmp_path : pathlib .Path ):
311
+ def test_top_level_connection_access_not_found (self , tmp_path : pathlib .Path , inprocess_client ):
274
312
logger_filehandle = MagicMock ()
275
313
276
314
def dag_in_a_fn ():
@@ -282,7 +320,12 @@ def dag_in_a_fn():
282
320
283
321
path = write_dag_in_a_fn_to_file (dag_in_a_fn , tmp_path )
284
322
proc = DagFileProcessorProcess .start (
285
- id = 1 , path = path , bundle_path = tmp_path , callbacks = [], logger_filehandle = logger_filehandle
323
+ id = 1 ,
324
+ path = path ,
325
+ bundle_path = tmp_path ,
326
+ callbacks = [],
327
+ logger_filehandle = logger_filehandle ,
328
+ client = inprocess_client ,
286
329
)
287
330
288
331
while not proc .is_ready :
@@ -294,7 +337,7 @@ def dag_in_a_fn():
294
337
if result .import_errors :
295
338
assert "CONNECTION_NOT_FOUND" in next (iter (result .import_errors .values ()))
296
339
297
- def test_import_module_in_bundle_root (self , tmp_path : pathlib .Path ):
340
+ def test_import_module_in_bundle_root (self , tmp_path : pathlib .Path , inprocess_client ):
298
341
tmp_path .joinpath ("util.py" ).write_text ("NAME = 'dag_name'" )
299
342
300
343
dag1_path = tmp_path .joinpath ("dag1.py" )
@@ -314,6 +357,7 @@ def test_import_module_in_bundle_root(self, tmp_path: pathlib.Path):
314
357
bundle_path = tmp_path ,
315
358
callbacks = [],
316
359
logger_filehandle = MagicMock (),
360
+ client = inprocess_client ,
317
361
)
318
362
while not proc .is_ready :
319
363
proc ._service_subprocess (0.1 )
0 commit comments