Skip to content

Commit bac97eb

Browse files
[Feature] JSON Serialize Method with C++ reflection (#335)
This PR provides functionality similar to #277, but leverages C++ reflection to eliminate the need for manually defining serialization/deserialization rules for each class. It builds on top of #334, and should be merged after #334 is approved. ### What's new in this PR: 1. Introduced `AutoJSONSerialize` and `AutoJSONDeserialize` to enable reflection-based JSON serialization. 2. Refactored `tokenizer_info.cc`: moved the definition of `TokenizerInfo::Impl` into a separate file to support serialization. 3. Added Python bindings for JSON serialization (via `JSONSerializer`) and corresponding tests.
1 parent e750120 commit bac97eb

28 files changed

+1460
-277
lines changed

3rdparty/picojson/picojson.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
#include <iostream>
5151
#include <iterator>
5252
#include <limits>
53-
#include <map>
5453
#include <stdexcept>
5554
#include <string>
5655
#include <unordered_map>
@@ -267,6 +266,7 @@ class object_with_ordered_keys : private std::unordered_map<std::string, value>
267266
using std::unordered_map<std::string, value>::at;
268267
using std::unordered_map<std::string, value>::count;
269268
using std::unordered_map<std::string, value>::find;
269+
using std::unordered_map<std::string, value>::reserve;
270270

271271
value& operator[](const std::string& key) {
272272
if (count(key) == 0) {

cpp/compiled_grammar_data_structure.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
// matcher_data_structure.h is included to use StackElement
1919
#include "persistent_stack.h"
2020
#include "support/dynamic_bitset.h"
21+
#include "support/reflection.h"
2122
#include "support/utils.h"
23+
#include "xgrammar/compiler.h"
2224

2325
namespace xgrammar {
2426

@@ -74,6 +76,15 @@ struct AdaptiveTokenMask {
7476
friend std::size_t MemorySize(const AdaptiveTokenMask& mask);
7577
};
7678

79+
XGRAMMAR_MEMBER_ARRAY(
80+
AdaptiveTokenMask,
81+
&AdaptiveTokenMask::store_type,
82+
&AdaptiveTokenMask::accepted_indices,
83+
&AdaptiveTokenMask::rejected_indices,
84+
&AdaptiveTokenMask::accepted_bitset,
85+
&AdaptiveTokenMask::uncertain_indices
86+
);
87+
7788
/*!
7889
* \brief All information that we need to match tokens in the tokenizer to the specified grammar.
7990
* It is the result of preprocessing.
@@ -119,8 +130,20 @@ class CompiledGrammar::Impl {
119130
TokenizerInfo GetTokenizerInfo() const { return tokenizer_info; }
120131

121132
std::size_t MemorySize() const;
133+
134+
friend struct member_trait<Impl>;
122135
};
123136

137+
XGRAMMAR_MEMBER_TABLE(
138+
CompiledGrammar::Impl,
139+
"grammar",
140+
&CompiledGrammar::Impl::grammar,
141+
"tokenizer_info",
142+
&CompiledGrammar::Impl::tokenizer_info,
143+
"adaptive_token_mask_cache",
144+
&CompiledGrammar::Impl::adaptive_token_mask_cache
145+
);
146+
124147
} // namespace xgrammar
125148

126149
#endif // XGRAMMAR_COMPILED_GRAMMAR_DATA_STRUCTURE_H_

cpp/fsm.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <cstring>
1313
#include <iostream>
1414
#include <list>
15+
#include <memory>
1516
#include <queue>
1617
#include <set>
1718
#include <string>
@@ -21,8 +22,11 @@
2122
#include <utility>
2223
#include <vector>
2324

25+
#include "picojson.h"
2426
#include "support/encoding.h"
27+
#include "support/json.h"
2528
#include "support/logging.h"
29+
#include "support/reflection.h"
2630
#include "support/union_find_set.h"
2731

2832
namespace xgrammar {
@@ -65,6 +69,7 @@ class FSMImplBase {
6569

6670
protected:
6771
ContainerType edges_;
72+
friend struct member_trait<CompactFSM::Impl>;
6873
};
6974

7075
template <typename ContainerType>
@@ -157,6 +162,8 @@ void FSMImplBase<ContainerType>::GetReachableStates(
157162

158163
class FSM::Impl : public FSMImplBase<std::vector<std::vector<FSMEdge>>> {
159164
public:
165+
Impl() = default;
166+
160167
Impl(int num_states = 0) { edges_.resize(num_states); }
161168

162169
using FSMImplBase<std::vector<std::vector<FSMEdge>>>::FSMImplBase;
@@ -372,6 +379,8 @@ class CompactFSM::Impl : public FSMImplBase<CSRArray<FSMEdge>> {
372379
friend std::size_t MemorySize(const Impl& self) { return MemorySize(self.edges_); }
373380
};
374381

382+
XGRAMMAR_MEMBER_ARRAY(CompactFSM::Impl, &CompactFSM::Impl::edges_);
383+
375384
int CompactFSM::Impl::GetNextState(int from, int16_t character) const {
376385
for (const auto& edge : edges_[from]) {
377386
if (edge.min == -1) {
@@ -1410,4 +1419,13 @@ FSMWithStartEnd CompactFSMWithStartEnd::ToFSM() const {
14101419
return FSMWithStartEnd(fsm_.ToFSM(), start_, ends_);
14111420
}
14121421

1422+
picojson::value CompactFSM::SerializeJSONValue() const { return AutoSerializeJSONValue(**this); }
1423+
1424+
void DeserializeJSONValue(CompactFSM& fsm, const picojson::value& v) {
1425+
if (!fsm.pimpl_) {
1426+
fsm.pimpl_ = std::make_unique<CompactFSM::Impl>();
1427+
}
1428+
return AutoDeserializeJSONValue(*fsm, v);
1429+
}
1430+
14131431
} // namespace xgrammar

cpp/fsm.h

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616
#include <string>
1717
#include <unordered_map>
1818
#include <unordered_set>
19-
#include <variant>
2019
#include <vector>
2120

22-
#include "../cpp/support/csr_array.h"
21+
#include "picojson.h"
22+
#include "support/csr_array.h"
23+
#include "support/reflection.h"
2324

2425
namespace xgrammar {
2526

@@ -45,6 +46,9 @@ struct FSMEdge {
4546
<< "Invalid FSMEdge: min > max. min=" << min << ", max=" << max;
4647
}
4748

49+
// for serialization only
50+
FSMEdge() = default;
51+
4852
/*!
4953
* \brief Compare the edges. Used to sort the edges in the FSM.
5054
*/
@@ -84,6 +88,8 @@ struct FSMEdge {
8488
int GetRefRuleId() const { return IsRuleRef() ? max : -1; }
8589
};
8690

91+
XGRAMMAR_MEMBER_ARRAY(FSMEdge, &FSMEdge::min, &FSMEdge::max, &FSMEdge::target);
92+
8793
} // namespace xgrammar
8894

8995
namespace std {
@@ -268,6 +274,9 @@ class FSM {
268274
*/
269275
class CompactFSM {
270276
public:
277+
// for serialization only
278+
CompactFSM() = default;
279+
271280
CompactFSM(const CSRArray<FSMEdge>& edges);
272281

273282
CompactFSM(CSRArray<FSMEdge>&& edges);
@@ -346,9 +355,14 @@ class CompactFSM {
346355
*/
347356
FSM ToFSM() const;
348357

358+
picojson::value SerializeJSONValue() const;
359+
friend void DeserializeJSONValue(CompactFSM& fsm, const picojson::value& v);
360+
349361
XGRAMMAR_DEFINE_PIMPL_METHODS(CompactFSM);
350362
};
351363

364+
class CompactFSMWithStartEnd;
365+
352366
/*!
353367
* \brief The base class for FSMWithStartEnd and CompactFSMWithStartEnd. It defines the
354368
* common constructor and visitor methods.
@@ -361,6 +375,9 @@ class FSMWithStartEndBase {
361375
);
362376

363377
public:
378+
// for serialization only
379+
FSMWithStartEndBase() = default;
380+
364381
/*! \brief Constructs an FSMWithStartEnd with a given FSM, start state, and end states. */
365382
FSMWithStartEndBase(
366383
const FSMType& fsm, int start, const std::unordered_set<int>& ends, bool is_dfa = false
@@ -456,9 +473,9 @@ class FSMWithStartEndBase {
456473
std::unordered_set<int> ends_;
457474
/*! \brief Whether this FSM is a deterministic finite automaton. */
458475
bool is_dfa_ = false;
459-
};
460476

461-
class CompactFSMWithStartEnd;
477+
friend struct member_trait<CompactFSMWithStartEnd>;
478+
};
462479

463480
/*!
464481
* \brief FSMWithStartEnd represents a FSM with start and end states.
@@ -591,6 +608,9 @@ class CompactFSMWithStartEnd : public FSMWithStartEndBase<CompactFSM> {
591608
public:
592609
using FSMWithStartEndBase<CompactFSM>::FSMWithStartEndBase;
593610

611+
// for serialization only
612+
CompactFSMWithStartEnd() = default;
613+
594614
/*!
595615
* \brief Print the FSM.
596616
* \return The string representation of the FSM.
@@ -613,6 +633,14 @@ class CompactFSMWithStartEnd : public FSMWithStartEndBase<CompactFSM> {
613633
FSMWithStartEnd ToFSM() const;
614634
};
615635

636+
XGRAMMAR_MEMBER_ARRAY(
637+
CompactFSMWithStartEnd,
638+
&CompactFSMWithStartEnd::fsm_,
639+
&CompactFSMWithStartEnd::start_,
640+
&CompactFSMWithStartEnd::ends_,
641+
&CompactFSMWithStartEnd::is_dfa_
642+
);
643+
616644
/****************** FSMWithStartEndBase Template Implementation ******************/
617645

618646
template <typename FSMType>

cpp/grammar_data_structure.h

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "fsm.h"
1717
#include "support/logging.h"
18+
#include "support/reflection.h"
1819
#include "xgrammar/grammar.h"
1920

2021
namespace xgrammar {
@@ -175,14 +176,38 @@ class Grammar::Impl {
175176
std::vector<int32_t> allow_empty_rule_ids;
176177

177178
friend class GrammarBuilder;
178-
friend class GrammarSerializer;
179-
friend class GrammarDeserializer;
180179
friend class GrammarCompiler;
181180

182181
std::size_t MemorySize() const;
183182
friend std::size_t MemorySize(const Impl& impl);
183+
friend struct member_trait<Impl>;
184184
};
185185

186+
XGRAMMAR_MEMBER_ARRAY(
187+
Grammar::Impl::Rule,
188+
&Grammar::Impl::Rule::name,
189+
&Grammar::Impl::Rule::body_expr_id,
190+
&Grammar::Impl::Rule::lookahead_assertion_id
191+
);
192+
193+
XGRAMMAR_MEMBER_TABLE(
194+
Grammar::Impl,
195+
"rules_",
196+
&Grammar::Impl::rules_,
197+
"rule_expr_data_",
198+
&Grammar::Impl::rule_expr_data_,
199+
"rule_expr_indptr_",
200+
&Grammar::Impl::rule_expr_indptr_,
201+
"root_rule_id_",
202+
&Grammar::Impl::root_rule_id_,
203+
"root_tag_dispatch_fsm",
204+
&Grammar::Impl::root_tag_dispatch_fsm,
205+
"tag_dispatch_end_node_to_rule_id",
206+
&Grammar::Impl::tag_dispatch_end_node_to_rule_id,
207+
"allow_empty_rule_ids",
208+
&Grammar::Impl::allow_empty_rule_ids
209+
);
210+
186211
} // namespace xgrammar
187212

188213
#endif // XGRAMMAR_GRAMMAR_DATA_STRUCTURE_H_

cpp/grammar_functor.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -701,7 +701,7 @@ class GrammarUnionFunctorImpl : public SubGrammarCombiner {
701701
}
702702

703703
// Avoid hiding the original Apply(const Grammar&)
704-
Grammar Apply(const Grammar& grammar) final { XGRAMMAR_LOG(FATAL) << "Should not be called"; }
704+
Grammar Apply(const Grammar&) final { XGRAMMAR_LOG(FATAL) << "Should not be called"; }
705705
};
706706

707707
/*!
@@ -735,7 +735,7 @@ class GrammarConcatFunctorImpl : public SubGrammarCombiner {
735735
}
736736

737737
// Avoid hiding the original Apply(const Grammar&)
738-
Grammar Apply(const Grammar& grammar) final { XGRAMMAR_LOG(FATAL) << "Should not be called"; }
738+
Grammar Apply(const Grammar&) final { XGRAMMAR_LOG(FATAL) << "Should not be called"; }
739739
};
740740

741741
/*!
@@ -955,7 +955,7 @@ class StructuralTagGrammarCreatorImpl : public SubGrammarCombiner {
955955
}
956956

957957
// Avoid hiding the original Apply(const Grammar&)
958-
Grammar Apply(const Grammar& grammar) final { XGRAMMAR_LOG(FATAL) << "Should not be called"; }
958+
Grammar Apply(const Grammar&) final { XGRAMMAR_LOG(FATAL) << "Should not be called"; }
959959
};
960960

961961
/*************************** Forward grammar functors to their impl ***************************/

cpp/grammar_serializer.cc

Lines changed: 0 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -140,82 +140,4 @@ std::string GrammarPrinter::ToString() {
140140
return result;
141141
}
142142

143-
std::string GrammarSerializer::Serialize() {
144-
picojson::object grammar_json_obj;
145-
146-
picojson::array rules_json;
147-
for (const auto& rule : grammar_->rules_) {
148-
picojson::object rule_json;
149-
rule_json["name"] = picojson::value(rule.name);
150-
rule_json["body_expr_id"] = picojson::value(static_cast<int64_t>(rule.body_expr_id));
151-
rules_json.push_back(picojson::value(rule_json));
152-
}
153-
grammar_json_obj["rules"] = picojson::value(rules_json);
154-
155-
picojson::array rule_expr_data_json;
156-
for (const auto& data : grammar_->rule_expr_data_) {
157-
rule_expr_data_json.push_back(picojson::value(static_cast<int64_t>(data)));
158-
}
159-
grammar_json_obj["rule_expr_data"] = picojson::value(rule_expr_data_json);
160-
picojson::array rule_expr_indptr_json;
161-
for (const auto& index_ptr : grammar_->rule_expr_indptr_) {
162-
rule_expr_indptr_json.push_back(picojson::value(static_cast<int64_t>(index_ptr)));
163-
}
164-
grammar_json_obj["rule_expr_indptr"] = picojson::value(rule_expr_indptr_json);
165-
166-
auto grammar_json = picojson::value(grammar_json_obj);
167-
return grammar_json.serialize(prettify_);
168-
}
169-
170-
Grammar GrammarDeserializer::Deserialize(std::string json_string) {
171-
auto node = std::make_shared<Grammar::Impl>();
172-
173-
auto checker = [&](bool condition) {
174-
XGRAMMAR_CHECK(condition) << "Failed to deserialize XGrammar object: " << json_string;
175-
};
176-
177-
picojson::value serialized_value;
178-
std::string err = picojson::parse(serialized_value, json_string);
179-
180-
checker(err.empty() && serialized_value.is<picojson::object>());
181-
auto serialized_obj = serialized_value.get<picojson::object>();
182-
183-
// rules
184-
checker(serialized_obj.count("rules") && serialized_obj["rules"].is<picojson::array>());
185-
auto rules_array = serialized_obj["rules"].get<picojson::array>();
186-
187-
checker(rules_array.size() > 0);
188-
for (const auto& rule_value : rules_array) {
189-
checker(rule_value.is<picojson::object>());
190-
auto rule_obj = rule_value.get<picojson::object>();
191-
checker(rule_obj.count("name") && rule_obj["name"].is<std::string>());
192-
auto name = rule_obj["name"].get<std::string>();
193-
checker(rule_obj.count("body_expr_id") && rule_obj["body_expr_id"].is<int64_t>());
194-
auto rule_expr = static_cast<int32_t>(rule_obj["body_expr_id"].get<int64_t>());
195-
node->rules_.push_back(Grammar::Impl::Rule({name, rule_expr}));
196-
}
197-
198-
// rule_expr_data
199-
checker(
200-
serialized_obj.count("rule_expr_data") &&
201-
serialized_obj["rule_expr_data"].is<picojson::array>()
202-
);
203-
auto rule_expr_data_array = serialized_obj["rule_expr_data"].get<picojson::array>();
204-
for (const auto& data_json : rule_expr_data_array) {
205-
node->rule_expr_data_.push_back(static_cast<int32_t>(data_json.get<int64_t>()));
206-
}
207-
208-
// rule_expr_indptr
209-
checker(
210-
serialized_obj.count("rule_expr_indptr") &&
211-
serialized_obj["rule_expr_indptr"].is<picojson::array>()
212-
);
213-
auto rule_expr_indptr_array = serialized_obj["rule_expr_indptr"].get<picojson::array>();
214-
for (const auto& index_ptr_json : rule_expr_indptr_array) {
215-
node->rule_expr_indptr_.push_back(static_cast<int32_t>(index_ptr_json.get<int64_t>()));
216-
}
217-
218-
return Grammar(node);
219-
}
220-
221143
} // namespace xgrammar

0 commit comments

Comments
 (0)