Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions airflow-core/src/airflow/jobs/triggerer_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,8 +875,16 @@ async def create_triggers(self):
await asyncio.sleep(0)

try:
kwargs = Trigger._decrypt_kwargs(workload.encrypted_kwargs)
trigger_instance = trigger_class(**kwargs)
from airflow.serialization.serialized_objects import smart_decode_trigger_kwargs

# Decrypt and clean trigger kwargs before for execution
# Note: We only clean up serialization artifacts (__var, __type keys) here,
# not in `_decrypt_kwargs` because it is used during hash comparison in
# add_asset_trigger_references and could lead to adverse effects like hash mismatches
# that could cause None values in collections.
kw = Trigger._decrypt_kwargs(workload.encrypted_kwargs)
deserialised_kwargs = {k: smart_decode_trigger_kwargs(v) for k, v in kw.items()}
trigger_instance = trigger_class(**deserialised_kwargs)
except TypeError as err:
self.log.error("Trigger failed to inflate", error=err)
self.failed_triggers.append((trigger_id, err))
Expand Down
25 changes: 13 additions & 12 deletions airflow-core/src/airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,19 +337,20 @@ def decode_asset_condition(var: dict[str, Any]) -> BaseAsset:
raise ValueError(f"deserialization not implemented for DAT {dat!r}")


def decode_asset(var: dict[str, Any]):
def _smart_decode_trigger_kwargs(d):
"""
Slightly clean up kwargs for display.
def smart_decode_trigger_kwargs(d):
"""
Slightly clean up kwargs for display or execution.

This detects one level of BaseSerialization and tries to deserialize the
content, removing some __type __var ugliness when the value is displayed
in UI to the user.
"""
if not isinstance(d, dict) or Encoding.TYPE not in d:
return d
return BaseSerialization.deserialize(d)
This detects one level of BaseSerialization and tries to deserialize the
content, removing some __type __var ugliness when the value is displayed
in UI to the user and/or while execution.
"""
if not isinstance(d, dict) or Encoding.TYPE not in d:
return d
return BaseSerialization.deserialize(d)


def decode_asset(var: dict[str, Any]):
watchers = var.get("watchers", [])
return Asset(
name=var["name"],
Expand All @@ -361,7 +362,7 @@ def _smart_decode_trigger_kwargs(d):
name=watcher["name"],
trigger={
"classpath": watcher["trigger"]["classpath"],
"kwargs": _smart_decode_trigger_kwargs(watcher["trigger"]["kwargs"]),
"kwargs": smart_decode_trigger_kwargs(watcher["trigger"]["kwargs"]),
},
)
for watcher in watchers
Expand Down
45 changes: 45 additions & 0 deletions airflow-core/tests/unit/jobs/test_triggerer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,51 @@ async def test_invalid_trigger(self, supervisor_builder):
assert trigger_id == 1
assert traceback[-1] == "ModuleNotFoundError: No module named 'fake'\n"

@pytest.mark.asyncio
async def test_trigger_kwargs_serialization_cleanup(self, session):
"""
Test that trigger kwargs are properly cleaned of serialization artifacts
(__var, __type keys).
"""
from airflow.serialization.serialized_objects import BaseSerialization

kw = {"simple": "test", "tuple": (), "dict": {}, "list": []}

serialized_kwargs = BaseSerialization.serialize(kw)

trigger_orm = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs=serialized_kwargs)
session.add(trigger_orm)
session.commit()

stored_kwargs = trigger_orm.kwargs
assert stored_kwargs == {
"Encoding.TYPE": "dict",
"Encoding.VAR": {
"dict": {"Encoding.TYPE": "dict", "Encoding.VAR": {}},
"list": [],
"simple": "test",
"tuple": {"Encoding.TYPE": "tuple", "Encoding.VAR": []},
},
}

runner = TriggerRunner()
runner.to_create.append(
workloads.RunTrigger.model_construct(
id=trigger_orm.id,
ti=None,
classpath=trigger_orm.classpath,
encrypted_kwargs=trigger_orm.encrypted_kwargs,
)
)

await runner.create_triggers()
assert trigger_orm.id in runner.triggers
trigger_instance = runner.triggers[trigger_orm.id]["task"]

# The test passes if no exceptions were raised during trigger creation
trigger_instance.cancel()
await runner.cleanup_finished_triggers()


@pytest.mark.asyncio
async def test_trigger_create_race_condition_38599(session, supervisor_builder):
Expand Down
Loading