Skip to content

Commit 0e35aa9

Browse files
[v3-0-test] Ensure trigger kwargs are properly deserialized during trigger execution (#52693) (#52721)
(cherry picked from commit 64a5919) Co-authored-by: Amogh Desai <[email protected]>
1 parent a982c48 commit 0e35aa9

File tree

3 files changed

+68
-14
lines changed

3 files changed

+68
-14
lines changed

airflow-core/src/airflow/jobs/triggerer_job_runner.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -875,8 +875,16 @@ async def create_triggers(self):
875875
await asyncio.sleep(0)
876876

877877
try:
878-
kwargs = Trigger._decrypt_kwargs(workload.encrypted_kwargs)
879-
trigger_instance = trigger_class(**kwargs)
878+
from airflow.serialization.serialized_objects import smart_decode_trigger_kwargs
879+
880+
# Decrypt and clean trigger kwargs before execution
881+
# Note: We only clean up serialization artifacts (__var, __type keys) here,
882+
# not in `_decrypt_kwargs` because it is used during hash comparison in
883+
# add_asset_trigger_references and could lead to adverse effects like hash mismatches
884+
# that could cause None values in collections.
885+
kw = Trigger._decrypt_kwargs(workload.encrypted_kwargs)
886+
deserialised_kwargs = {k: smart_decode_trigger_kwargs(v) for k, v in kw.items()}
887+
trigger_instance = trigger_class(**deserialised_kwargs)
880888
except TypeError as err:
881889
self.log.error("Trigger failed to inflate", error=err)
882890
self.failed_triggers.append((trigger_id, err))

airflow-core/src/airflow/serialization/serialized_objects.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -337,19 +337,20 @@ def decode_asset_condition(var: dict[str, Any]) -> BaseAsset:
337337
raise ValueError(f"deserialization not implemented for DAT {dat!r}")
338338

339339

340-
def decode_asset(var: dict[str, Any]):
341-
def _smart_decode_trigger_kwargs(d):
342-
"""
343-
Slightly clean up kwargs for display.
340+
def smart_decode_trigger_kwargs(d):
341+
"""
342+
Slightly clean up kwargs for display or execution.
344343
345-
This detects one level of BaseSerialization and tries to deserialize the
346-
content, removing some __type __var ugliness when the value is displayed
347-
in UI to the user.
348-
"""
349-
if not isinstance(d, dict) or Encoding.TYPE not in d:
350-
return d
351-
return BaseSerialization.deserialize(d)
344+
This detects one level of BaseSerialization and tries to deserialize the
345+
content, removing some __type __var ugliness when the value is displayed
346+
in the UI to the user and/or during execution.
347+
"""
348+
if not isinstance(d, dict) or Encoding.TYPE not in d:
349+
return d
350+
return BaseSerialization.deserialize(d)
352351

352+
353+
def decode_asset(var: dict[str, Any]):
353354
watchers = var.get("watchers", [])
354355
return Asset(
355356
name=var["name"],
@@ -361,7 +362,7 @@ def _smart_decode_trigger_kwargs(d):
361362
name=watcher["name"],
362363
trigger={
363364
"classpath": watcher["trigger"]["classpath"],
364-
"kwargs": _smart_decode_trigger_kwargs(watcher["trigger"]["kwargs"]),
365+
"kwargs": smart_decode_trigger_kwargs(watcher["trigger"]["kwargs"]),
365366
},
366367
)
367368
for watcher in watchers

airflow-core/tests/unit/jobs/test_triggerer_job.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,51 @@ async def test_invalid_trigger(self, supervisor_builder):
326326
assert trigger_id == 1
327327
assert traceback[-1] == "ModuleNotFoundError: No module named 'fake'\n"
328328

329+
@pytest.mark.asyncio
330+
async def test_trigger_kwargs_serialization_cleanup(self, session):
331+
"""
332+
Test that trigger kwargs are properly cleaned of serialization artifacts
333+
(__var, __type keys).
334+
"""
335+
from airflow.serialization.serialized_objects import BaseSerialization
336+
337+
kw = {"simple": "test", "tuple": (), "dict": {}, "list": []}
338+
339+
serialized_kwargs = BaseSerialization.serialize(kw)
340+
341+
trigger_orm = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs=serialized_kwargs)
342+
session.add(trigger_orm)
343+
session.commit()
344+
345+
stored_kwargs = trigger_orm.kwargs
346+
assert stored_kwargs == {
347+
"Encoding.TYPE": "dict",
348+
"Encoding.VAR": {
349+
"dict": {"Encoding.TYPE": "dict", "Encoding.VAR": {}},
350+
"list": [],
351+
"simple": "test",
352+
"tuple": {"Encoding.TYPE": "tuple", "Encoding.VAR": []},
353+
},
354+
}
355+
356+
runner = TriggerRunner()
357+
runner.to_create.append(
358+
workloads.RunTrigger.model_construct(
359+
id=trigger_orm.id,
360+
ti=None,
361+
classpath=trigger_orm.classpath,
362+
encrypted_kwargs=trigger_orm.encrypted_kwargs,
363+
)
364+
)
365+
366+
await runner.create_triggers()
367+
assert trigger_orm.id in runner.triggers
368+
trigger_instance = runner.triggers[trigger_orm.id]["task"]
369+
370+
# The test passes if no exceptions were raised during trigger creation
371+
trigger_instance.cancel()
372+
await runner.cleanup_finished_triggers()
373+
329374

330375
@pytest.mark.asyncio
331376
async def test_trigger_create_race_condition_38599(session, supervisor_builder):

0 commit comments

Comments
 (0)