Skip to content

Commit 27be097

Browse files
dwblaikiezygoloid
andauthored
Vtable support for generics (#5793)
Some specific features: * Use `SpecificFunction` for vtable entries for generic classes. * Create specific constants for vtable entries in classes derived from generic classes to reference the appropriate specific of the function in the context of such a derived class. * Create specific constants for vtable_ptrs for uses of specific generic classes. --------- Co-authored-by: Richard Smith <[email protected]>
1 parent 8cd1307 commit 27be097

File tree

12 files changed

+1044
-123
lines changed

12 files changed

+1044
-123
lines changed

toolchain/check/class.cpp

Lines changed: 83 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,6 @@
1414

1515
namespace Carbon::Check {
1616

17-
auto TryGetAsClass(Context& context, SemIR::TypeId type_id) -> SemIR::Class* {
18-
auto class_type = context.types().TryGetAs<SemIR::ClassType>(type_id);
19-
if (!class_type) {
20-
return nullptr;
21-
}
22-
return &context.classes().Get(class_type->class_id);
23-
}
24-
2517
auto SetClassSelfType(Context& context, SemIR::ClassId class_id) -> void {
2618
auto& class_info = context.classes().Get(class_id);
2719
auto specific_id = context.generics().GetSelfSpecific(class_info.generic_id);
@@ -131,10 +123,51 @@ static auto AddStructTypeFields(
131123

132124
// Builds and returns a vtable for the current class. Assumes that the virtual
133125
// functions for the class are listed as the top element of the `vtable_stack`.
134-
static auto BuildVtable(Context& context, SemIR::ClassId class_id,
135-
SemIR::VtableId base_vtable_id,
126+
static auto BuildVtable(Context& context, Parse::ClassDefinitionId node_id,
127+
SemIR::ClassId class_id,
128+
std::optional<SemIR::ClassType> base_class_type,
136129
llvm::ArrayRef<SemIR::InstId> vtable_contents)
137130
-> SemIR::VtableId {
131+
auto base_vtable_id = SemIR::VtableId::None;
132+
auto base_class_specific_id = SemIR::SpecificId::None;
133+
134+
// Get some base class/type/specific info.
135+
if (base_class_type) {
136+
auto& base_class_info = context.classes().Get(base_class_type->class_id);
137+
auto base_vtable_ptr_inst_id = base_class_info.vtable_ptr_id;
138+
if (base_vtable_ptr_inst_id.has_value()) {
139+
LoadImportRef(context, base_vtable_ptr_inst_id);
140+
auto canonical_base_vtable_inst_id =
141+
context.constant_values().GetConstantInstId(base_vtable_ptr_inst_id);
142+
const auto& base_vtable_ptr_inst =
143+
context.insts().GetAs<SemIR::VtablePtr>(
144+
canonical_base_vtable_inst_id);
145+
base_vtable_id = base_vtable_ptr_inst.vtable_id;
146+
base_class_specific_id = base_class_type->specific_id;
147+
}
148+
}
149+
150+
const auto& class_info = context.classes().Get(class_id);
151+
auto class_generic_id = class_info.generic_id;
152+
153+
// Wrap vtable entries in SpecificFunctions as needed/in generic classes.
154+
auto build_specific_function =
155+
[&](SemIR::InstId fn_decl_id) -> SemIR::InstId {
156+
if (!class_generic_id.has_value()) {
157+
return fn_decl_id;
158+
}
159+
const auto& fn_decl =
160+
context.insts().GetAs<SemIR::FunctionDecl>(fn_decl_id);
161+
const auto& function = context.functions().Get(fn_decl.function_id);
162+
return GetOrAddInst<SemIR::SpecificFunction>(
163+
context, node_id,
164+
{.type_id =
165+
GetSingletonType(context, SemIR::SpecificFunctionType::TypeInstId),
166+
.callee_id = fn_decl_id,
167+
.specific_id =
168+
context.generics().GetSelfSpecific(function.generic_id)});
169+
};
170+
138171
llvm::SmallVector<SemIR::InstId> vtable;
139172
if (base_vtable_id.has_value()) {
140173
auto base_vtable_inst_block = context.inst_blocks().Get(
@@ -144,25 +177,36 @@ static auto BuildVtable(Context& context, SemIR::ClassId class_id,
144177
for (auto fn_decl_id : base_vtable_inst_block) {
145178
auto fn_decl = GetCalleeFunction(context.sem_ir(), fn_decl_id);
146179
const auto& fn = context.functions().Get(fn_decl.function_id);
147-
for (auto override_fn_decl_id : vtable_contents) {
148-
auto override_fn_decl =
149-
context.insts().GetAs<SemIR::FunctionDecl>(override_fn_decl_id);
150-
auto& override_fn =
151-
context.functions().Get(override_fn_decl.function_id);
152-
if (override_fn.virtual_modifier ==
153-
SemIR::FunctionFields::VirtualModifier::Impl &&
154-
override_fn.name_id == fn.name_id) {
155-
// TODO: Support generic base classes, rather than passing
156-
// `SpecificId::None`.
157-
CheckFunctionTypeMatches(context, override_fn, fn,
158-
SemIR::SpecificId::None,
159-
/*check_syntax=*/false,
160-
/*check_self=*/false);
161-
fn_decl_id = override_fn_decl_id;
162-
override_fn.virtual_index = vtable.size();
163-
CARBON_CHECK(override_fn.virtual_index == fn.virtual_index);
164-
break;
165-
}
180+
const auto* i = llvm::find_if(
181+
vtable_contents, [&](SemIR::InstId override_fn_decl_id) -> bool {
182+
const auto& override_fn = context.functions().Get(
183+
context.insts()
184+
.GetAs<SemIR::FunctionDecl>(override_fn_decl_id)
185+
.function_id);
186+
return override_fn.virtual_modifier ==
187+
SemIR::FunctionFields::VirtualModifier::Impl &&
188+
override_fn.name_id == fn.name_id;
189+
});
190+
if (i != vtable_contents.end()) {
191+
auto& override_fn = context.functions().Get(
192+
context.insts().GetAs<SemIR::FunctionDecl>(*i).function_id);
193+
// TODO: Support generic base classes, rather than passing
194+
// `SpecificId::None`. This'll need to `GetConstantValueInSpecific` for
195+
// the base function, then extract the specific from that for use here.
196+
CheckFunctionTypeMatches(context, override_fn, fn,
197+
SemIR::SpecificId::None,
198+
/*check_syntax=*/false,
199+
/*check_self=*/false);
200+
fn_decl_id = build_specific_function(*i);
201+
override_fn.virtual_index = vtable.size();
202+
CARBON_CHECK(override_fn.virtual_index == fn.virtual_index);
203+
} else {
204+
// Remap the base's vtable entry to the appropriate constant usable in
205+
// the context of the derived class (for the specific for the base
206+
// class, for instance)..
207+
fn_decl_id = context.sem_ir().constant_values().GetInstId(
208+
GetConstantValueInSpecific(context.sem_ir(), base_class_specific_id,
209+
fn_decl_id));
166210
}
167211
vtable.push_back(fn_decl_id);
168212
}
@@ -173,7 +217,7 @@ static auto BuildVtable(Context& context, SemIR::ClassId class_id,
173217
auto& fn = context.functions().Get(fn_decl.function_id);
174218
if (fn.virtual_modifier != SemIR::FunctionFields::VirtualModifier::Impl) {
175219
fn.virtual_index = vtable.size();
176-
vtable.push_back(inst_id);
220+
vtable.push_back(build_specific_function(inst_id));
177221
}
178222
}
179223

@@ -200,12 +244,13 @@ static auto CheckCompleteClassType(
200244
class_info.GetBaseType(context.sem_ir(), SemIR::SpecificId::None);
201245
// TODO: Use InstId from base declaration.
202246
auto base_type_inst_id = context.types().GetInstId(base_type_id);
203-
SemIR::Class* base_class_info = nullptr;
247+
std::optional<SemIR::ClassType> base_class_type;
204248
if (base_type_id.has_value()) {
205249
// TODO: If the base class is template dependent, we will need to decide
206250
// whether to add a vptr as part of instantiation.
207-
base_class_info = TryGetAsClass(context, base_type_id);
208-
if (base_class_info && base_class_info->is_dynamic) {
251+
base_class_type = context.types().TryGetAs<SemIR::ClassType>(base_type_id);
252+
if (base_class_type &&
253+
context.classes().Get(base_class_type->class_id).is_dynamic) {
209254
defining_vptr = false;
210255
}
211256
}
@@ -229,35 +274,21 @@ static auto CheckCompleteClassType(
229274
}
230275

231276
if (class_info.is_dynamic) {
232-
SemIR::VtableId base_vtable_id = SemIR::VtableId::None;
233-
if (base_class_info) {
234-
auto base_vtable_ptr_inst_id = base_class_info->vtable_ptr_id;
235-
if (base_vtable_ptr_inst_id.has_value()) {
236-
LoadImportRef(context, base_vtable_ptr_inst_id);
237-
auto canonical_base_vtable_inst_id =
238-
context.constant_values().GetConstantInstId(
239-
base_vtable_ptr_inst_id);
240-
const auto& base_vtable_ptr_inst =
241-
context.insts().GetAs<SemIR::VtablePtr>(
242-
canonical_base_vtable_inst_id);
243-
base_vtable_id = base_vtable_ptr_inst.vtable_id;
244-
// TODO: Retrieve the specific_id from the base_vtable_ptr_inst here,
245-
// for use in BuildVtable.
246-
}
247-
}
248-
auto vtable_id =
249-
BuildVtable(context, class_id, base_vtable_id, vtable_contents);
277+
auto vtable_id = BuildVtable(context, node_id, class_id, base_class_type,
278+
vtable_contents);
250279

251280
auto vptr_type_id = GetPointerType(context, SemIR::VtableType::TypeInstId);
252281
// TODO: Handle specifics here, probably passing
253282
// `context.generics().GetSelfSpecific(class_info.generic_id)` as the
254283
// specific_id here (but more work involved to get this all plumbed in and
255284
// tested).
285+
auto generic_id = class_info.generic_id;
286+
auto self_specific_id = context.generics().GetSelfSpecific(generic_id);
256287
class_info.vtable_ptr_id =
257288
AddInst<SemIR::VtablePtr>(context, node_id,
258289
{.type_id = vptr_type_id,
259290
.vtable_id = vtable_id,
260-
.specific_id = SemIR::SpecificId::None});
291+
.specific_id = self_specific_id});
261292
}
262293

263294
auto struct_type_inst_id = AddTypeInst<SemIR::StructType>(

toolchain/check/class.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,6 @@
99

1010
namespace Carbon::Check {
1111

12-
// If `type_id` is a class type, get its corresponding `SemIR::Class` object.
13-
// Otherwise returns `nullptr`.
14-
auto TryGetAsClass(Context& context, SemIR::TypeId type_id) -> SemIR::Class*;
15-
1612
// Sets the `Self` type for the class.
1713
auto SetClassSelfType(Context& context, SemIR::ClassId class_id) -> void;
1814

toolchain/check/convert.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,12 @@ static auto ConvertStructToClass(
589589

590590
if (!dest_vtable_ptr_inst_id.has_value()) {
591591
dest_vtable_ptr_inst_id = dest_class_info.vtable_ptr_id;
592+
if (dest_type.specific_id.has_value() &&
593+
dest_vtable_ptr_inst_id.has_value()) {
594+
dest_vtable_ptr_inst_id = context.constant_values().GetInstId(
595+
GetConstantValueInSpecific(context.sem_ir(), dest_type.specific_id,
596+
dest_vtable_ptr_inst_id));
597+
}
592598
}
593599

594600
if (dest_vtable_ptr_inst_id.has_value()) {

toolchain/check/handle_class.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -459,25 +459,28 @@ static auto CheckBaseType(Context& context, Parse::NodeId node_id,
459459
return BaseInfo::Error;
460460
}
461461

462-
auto* base_class_info = TryGetAsClass(context, base_type_id);
462+
auto class_type = context.types().TryGetAs<SemIR::ClassType>(base_type_id);
463463

464464
// The base must not be a final class.
465-
if (!base_class_info) {
465+
if (!class_type) {
466466
// For now, we treat all types that aren't introduced by a `class`
467467
// declaration as being final classes.
468468
// TODO: Once we have a better idea of which types are considered to be
469469
// classes, produce a better diagnostic for deriving from a non-class type.
470470
DiagnoseBaseIsFinal(context, node_id, base_type_inst_id);
471471
return BaseInfo::Error;
472472
}
473-
if (base_class_info->inheritance_kind == SemIR::Class::Final) {
473+
474+
const auto& base_class_info = context.classes().Get(class_type->class_id);
475+
476+
if (base_class_info.inheritance_kind == SemIR::Class::Final) {
474477
DiagnoseBaseIsFinal(context, node_id, base_type_inst_id);
475478
}
476479

477-
CARBON_CHECK(base_class_info->scope_id.has_value(),
480+
CARBON_CHECK(base_class_info.scope_id.has_value(),
478481
"Complete class should have a scope");
479482
return {.type_id = base_type_id,
480-
.scope_id = base_class_info->scope_id,
483+
.scope_id = base_class_info.scope_id,
481484
.inst_id = base_type_inst_id};
482485
}
483486

0 commit comments

Comments
 (0)