Skip to content

Commit a041a2a

Browse files
Enable Serde for Pydantic BaseModel and Subclasses (#51059)
This adds serialization and deserialization support for arbitrary Pydantic objects while still maintaining security.

---------

Co-authored-by: Tzu-ping Chung <[email protected]>
1 parent 7f694ad commit a041a2a

File tree

13 files changed

+271
-56
lines changed

13 files changed

+271
-56
lines changed

airflow-core/src/airflow/serialization/serde.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
import airflow.serialization.serializers
3434
from airflow.configuration import conf
35+
from airflow.serialization.typing import is_pydantic_model
3536
from airflow.stats import Stats
3637
from airflow.utils.module_loading import import_string, iter_namespace, qualname
3738

@@ -52,6 +53,7 @@
5253
OLD_SOURCE = "__source"
5354
OLD_DATA = "__var"
5455
OLD_DICT = "dict"
56+
PYDANTIC_MODEL_QUALNAME = "pydantic.main.BaseModel"
5557

5658
DEFAULT_VERSION = 0
5759

@@ -145,6 +147,12 @@ def serialize(o: object, depth: int = 0) -> U | None:
145147
qn = "builtins.tuple"
146148
classname = qn
147149

150+
if is_pydantic_model(o):
151+
# to match the generic Pydantic serializer and deserializer in _serializers and _deserializers
152+
qn = PYDANTIC_MODEL_QUALNAME
153+
# the actual Pydantic model class to encode
154+
classname = qualname(o)
155+
148156
# if there is a builtin serializer available use that
149157
if qn in _serializers:
150158
data, serialized_classname, version, is_serialized = _serializers[qn].serialize(o)
@@ -256,7 +264,10 @@ def deserialize(o: T | None, full=True, type_hint: Any = None) -> object:
256264

257265
# registered deserializer
258266
if classname in _deserializers:
259-
return _deserializers[classname].deserialize(classname, version, deserialize(value))
267+
return _deserializers[classname].deserialize(cls, version, deserialize(value))
268+
if is_pydantic_model(cls):
269+
if PYDANTIC_MODEL_QUALNAME in _deserializers:
270+
return _deserializers[PYDANTIC_MODEL_QUALNAME].deserialize(cls, version, deserialize(value))
260271

261272
# class has deserialization function
262273
if hasattr(cls, "deserialize"):

airflow-core/src/airflow/serialization/serializers/bignum.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,13 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
4747
return float(o), name, __version__, True
4848

4949

50-
def deserialize(classname: str, version: int, data: object) -> decimal.Decimal:
50+
def deserialize(cls: type, version: int, data: object) -> decimal.Decimal:
5151
from decimal import Decimal
5252

5353
if version > __version__:
54-
raise TypeError(f"serialized {version} of {classname} > {__version__}")
54+
raise TypeError(f"serialized {version} of {qualname(cls)} > {__version__}")
5555

56-
if classname != qualname(Decimal):
57-
raise TypeError(f"{classname} != {qualname(Decimal)}")
56+
if cls is not Decimal:
57+
raise TypeError(f"do not know how to deserialize {qualname(cls)}")
5858

5959
return Decimal(str(data))

airflow-core/src/airflow/serialization/serializers/builtin.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,20 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
3535
return list(cast("list", o)), qualname(o), __version__, True
3636

3737

38-
def deserialize(classname: str, version: int, data: list) -> tuple | set | frozenset:
38+
def deserialize(cls: type, version: int, data: list) -> tuple | set | frozenset:
3939
if version > __version__:
40-
raise TypeError("serialized version is newer than class version")
40+
raise TypeError(f"serialized version {version} is newer than class version {__version__}")
4141

42-
if classname == qualname(tuple):
42+
if cls is tuple:
4343
return tuple(data)
4444

45-
if classname == qualname(set):
45+
if cls is set:
4646
return set(data)
4747

48-
if classname == qualname(frozenset):
48+
if cls is frozenset:
4949
return frozenset(data)
5050

51-
raise TypeError(f"do not know how to deserialize {classname}")
51+
raise TypeError(f"do not know how to deserialize {qualname(cls)}")
5252

5353

5454
def stringify(classname: str, version: int, data: list) -> str:

airflow-core/src/airflow/serialization/serializers/datetime.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
5959
return "", "", 0, False
6060

6161

62-
def deserialize(classname: str, version: int, data: dict | str) -> datetime.date | datetime.timedelta:
62+
def deserialize(cls: type, version: int, data: dict | str) -> datetime.date | datetime.timedelta:
6363
import datetime
6464

6565
from pendulum import DateTime
@@ -86,16 +86,16 @@ def deserialize(classname: str, version: int, data: dict | str) -> datetime.date
8686
else None
8787
)
8888

89-
if classname == qualname(datetime.datetime) and isinstance(data, dict):
89+
if cls is datetime.datetime and isinstance(data, dict):
9090
return datetime.datetime.fromtimestamp(float(data[TIMESTAMP]), tz=tz)
9191

92-
if classname == qualname(DateTime) and isinstance(data, dict):
92+
if cls is DateTime and isinstance(data, dict):
9393
return DateTime.fromtimestamp(float(data[TIMESTAMP]), tz=tz)
9494

95-
if classname == qualname(datetime.timedelta) and isinstance(data, (str, float)):
95+
if cls is datetime.timedelta and isinstance(data, (str, float)):
9696
return datetime.timedelta(seconds=float(data))
9797

98-
if classname == qualname(datetime.date) and isinstance(data, str):
98+
if cls is datetime.date and isinstance(data, str):
9999
return datetime.date.fromisoformat(data)
100100

101-
raise TypeError(f"unknown date/time format {classname}")
101+
raise TypeError(f"unknown date/time format {qualname(cls)}")

airflow-core/src/airflow/serialization/serializers/deltalake.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,15 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
5555
return data, qualname(o), __version__, True
5656

5757

58-
def deserialize(classname: str, version: int, data: dict):
58+
def deserialize(cls: type, version: int, data: dict):
5959
from deltalake.table import DeltaTable
6060

6161
from airflow.models.crypto import get_fernet
6262

6363
if version > __version__:
6464
raise TypeError("serialized version is newer than class version")
6565

66-
if classname == qualname(DeltaTable):
66+
if cls is DeltaTable:
6767
fernet = get_fernet()
6868
properties = {}
6969
for k, v in data["storage_options"].items():
@@ -76,4 +76,4 @@ def deserialize(classname: str, version: int, data: dict):
7676

7777
return DeltaTable(data["table_uri"], version=data["version"], storage_options=storage_options)
7878

79-
raise TypeError(f"do not know how to deserialize {classname}")
79+
raise TypeError(f"do not know how to deserialize {qualname(cls)}")

airflow-core/src/airflow/serialization/serializers/iceberg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
5555
return data, qualname(o), __version__, True
5656

5757

58-
def deserialize(classname: str, version: int, data: dict):
58+
def deserialize(cls: type, version: int, data: dict):
5959
from pyiceberg.catalog import load_catalog
6060
from pyiceberg.table import Table
6161

@@ -64,7 +64,7 @@ def deserialize(classname: str, version: int, data: dict):
6464
if version > __version__:
6565
raise TypeError("serialized version is newer than class version")
6666

67-
if classname == qualname(Table):
67+
if cls is Table:
6868
fernet = get_fernet()
6969
properties = {}
7070
for k, v in data["catalog_properties"].items():
@@ -73,4 +73,4 @@ def deserialize(classname: str, version: int, data: dict):
7373
catalog = load_catalog(data["identifier"][0], **properties)
7474
return catalog.load_table((data["identifier"][1], data["identifier"][2]))
7575

76-
raise TypeError(f"do not know how to deserialize {classname}")
76+
raise TypeError(f"do not know how to deserialize {qualname(cls)}")

airflow-core/src/airflow/serialization/serializers/numpy.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,13 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
8080
return "", "", 0, False
8181

8282

83-
def deserialize(classname: str, version: int, data: str) -> Any:
83+
def deserialize(cls: type, version: int, data: str) -> Any:
8484
if version > __version__:
8585
raise TypeError("serialized version is newer than class version")
8686

87-
if classname not in deserializers:
88-
raise TypeError(f"unsupported {classname} found for numpy deserialization")
87+
allowed_deserialize_classes = [import_string(classname) for classname in deserializers]
8988

90-
return import_string(classname)(data)
89+
if cls not in allowed_deserialize_classes:
90+
raise TypeError(f"unsupported {qualname(cls)} found for numpy deserialization")
91+
92+
return cls(data)

airflow-core/src/airflow/serialization/serializers/pandas.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,17 +53,22 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
5353
return buf.getvalue().hex().decode("utf-8"), qualname(o), __version__, True
5454

5555

56-
def deserialize(classname: str, version: int, data: object) -> pd.DataFrame:
56+
def deserialize(cls: type, version: int, data: object) -> pd.DataFrame:
5757
if version > __version__:
58-
raise TypeError(f"serialized {version} of {classname} > {__version__}")
58+
raise TypeError(f"serialized {version} of {qualname(cls)} > {__version__}")
5959

60-
from pyarrow import parquet as pq
60+
import pandas as pd
61+
62+
if cls is not pd.DataFrame:
63+
raise TypeError(f"do not know how to deserialize {qualname(cls)}")
6164

6265
if not isinstance(data, str):
63-
raise TypeError(f"serialized {classname} has wrong data type {type(data)}")
66+
raise TypeError(f"serialized {qualname(cls)} has wrong data type {type(data)}")
6467

6568
from io import BytesIO
6669

70+
from pyarrow import parquet as pq
71+
6772
with BytesIO(bytes.fromhex(data)) as buf:
6873
df = pq.read_table(buf).to_pandas()
6974

airflow-core/src/airflow/serialization/serializers/pydantic.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
from __future__ import annotations
19+
20+
from typing import TYPE_CHECKING, cast
21+
22+
from airflow.serialization.typing import is_pydantic_model
23+
from airflow.utils.module_loading import qualname
24+
25+
if TYPE_CHECKING:
26+
from pydantic import BaseModel
27+
28+
from airflow.serialization.serde import U
29+
30+
serializers = [
31+
"pydantic.main.BaseModel",
32+
]
33+
deserializers = serializers
34+
35+
__version__ = 1
36+
37+
38+
def serialize(o: object) -> tuple[U, str, int, bool]:
39+
"""
40+
Serialize a Pydantic BaseModel instance into a dict of built-in types.
41+
42+
Returns a tuple of:
43+
- serialized data (as built-in types)
44+
- qualified name of the actual model class (as returned by qualname(o))
45+
- version number
46+
- is_serialized flag (True if handled)
47+
"""
48+
if not is_pydantic_model(o):
49+
return "", "", 0, False
50+
51+
model = cast("BaseModel", o) # for mypy
52+
data = model.model_dump()
53+
54+
return data, qualname(o), __version__, True
55+
56+
57+
def deserialize(cls: type, version: int, data: dict):
58+
"""
59+
Deserialize a Pydantic class.
60+
61+
Pydantic models can be serialized into a Python dictionary via `pydantic.main.BaseModel.model_dump`
62+
and the dictionary can be deserialized through `pydantic.main.BaseModel.model_validate`. This function
63+
can deserialize arbitrary Pydantic models that are in `allowed_deserialization_classes`.
64+
65+
:param cls: The actual model class
66+
:param version: Serialization version (must not exceed __version__)
67+
:param data: Dictionary with built-in types, typically from model_dump()
68+
:return: An instance of the actual Pydantic model
69+
"""
70+
if version > __version__:
71+
raise TypeError(f"Serialized version {version} is newer than the supported version {__version__}")
72+
73+
if not is_pydantic_model(cls):
74+
# no deserializer available
75+
raise TypeError(f"No deserializer found for {qualname(cls)}")
76+
77+
# Perform validation-based reconstruction
78+
model = cast("BaseModel", cls) # for mypy
79+
return model.model_validate(data)

airflow-core/src/airflow/serialization/serializers/timezone.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,16 +67,16 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
6767
return "", "", 0, False
6868

6969

70-
def deserialize(classname: str, version: int, data: object) -> Any:
70+
def deserialize(cls: type, version: int, data: object) -> Any:
7171
from airflow.utils.timezone import parse_timezone
7272

7373
if not isinstance(data, (str, int)):
7474
raise TypeError(f"{data} is not of type int or str but of {type(data)}")
7575

7676
if version > __version__:
77-
raise TypeError(f"serialized {version} of {classname} > {__version__}")
77+
raise TypeError(f"serialized {version} of {qualname(cls)} > {__version__}")
7878

79-
if classname == "backports.zoneinfo.ZoneInfo" and isinstance(data, str):
79+
if qualname(cls) == "backports.zoneinfo.ZoneInfo" and isinstance(data, str):
8080
from zoneinfo import ZoneInfo
8181

8282
return ZoneInfo(data)

0 commit comments

Comments
 (0)