Skip to content

Commit bffc4a2

Browse files
authored
Handle extension property reinference in object initializers (#79739)
1 parent a2075fd commit bffc4a2

File tree

4 files changed

+319
-16
lines changed

4 files changed

+319
-16
lines changed

src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.cs

Lines changed: 88 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3561,8 +3561,6 @@ private void VisitLocalFunctionUse(LocalFunctionSymbol symbol)
35613561
SetResult(withExpr, resultState, resultType);
35623562
VisitObjectCreationInitializer(resultSlot, resultType.Type, withExpr.InitializerExpression, delayCompletionForType: false);
35633563

3564-
// Note: this does not account for the scenario where `Clone()` returns maybe-null and the with-expression has no initializers.
3565-
// Tracking in https://github.com/dotnet/roslyn/issues/44759
35663564
return null;
35673565
}
35683566

@@ -4408,18 +4406,55 @@ void setAnalyzedNullabilityAsContinuation(
44084406
};
44094407
}
44104408

4411-
static Symbol? getTargetMember(TypeSymbol containingType, BoundObjectInitializerMember objectInitializer)
4409+
Symbol? getTargetMember(TypeSymbol containingType, BoundObjectInitializerMember objectInitializer)
44124410
{
44134411
var symbol = objectInitializer.MemberSymbol;
4412+
if (symbol == null)
4413+
return null;
44144414

4415-
// https://github.com/dotnet/roslyn/issues/78828: adjust extension member based on new containing type
4416-
if (symbol != null && !symbol.GetIsNewExtensionMember())
4415+
if (!symbol.GetIsNewExtensionMember())
44174416
{
44184417
Debug.Assert(TypeSymbol.Equals(objectInitializer.Type, GetTypeOrReturnType(symbol), TypeCompareKind.IgnoreNullableModifiersForReferenceTypes));
4419-
symbol = AsMemberOfType(containingType, symbol);
4418+
return AsMemberOfType(containingType, symbol);
44204419
}
44214420

4422-
return symbol;
4421+
var extension = symbol.OriginalDefinition.ContainingType;
4422+
if (extension.Arity == 0)
4423+
{
4424+
return symbol;
4425+
}
4426+
4427+
if (symbol is not PropertySymbol { IsStatic: false } property
4428+
|| extension.ExtensionParameter is not { } extensionParameter)
4429+
{
4430+
Debug.Assert(objectInitializer.HasErrors);
4431+
return symbol;
4432+
}
4433+
4434+
var discardedUseSiteInfo = CompoundUseSiteInfo<AssemblySymbol>.Discarded;
4435+
4436+
// This may be incorrect for extension indexers.
4437+
// At that point we would maybe just want to gather arguments and visit the memberInitializer as an invocation of the indexer setter.
4438+
var inferenceResult = MethodTypeInferrer.Infer(
4439+
_binder,
4440+
_conversions,
4441+
extension.TypeParameters,
4442+
extension,
4443+
formalParameterTypes: [extensionParameter.TypeWithAnnotations],
4444+
formalParameterRefKinds: [extensionParameter.RefKind],
4445+
[new BoundExpressionWithNullability(objectInitializer.Syntax, objectInitializer, nullableAnnotation: NullableAnnotation.NotAnnotated, containingType)],
4446+
ref discardedUseSiteInfo,
4447+
new MethodInferenceExtensions(this),
4448+
ordinals: null);
4449+
4450+
if (inferenceResult.Success)
4451+
{
4452+
extension = extension.Construct(inferenceResult.InferredTypeArguments);
4453+
property = property.OriginalDefinition.AsMember(extension);
4454+
SetUpdatedSymbol(objectInitializer, symbol, property);
4455+
}
4456+
4457+
return property;
44234458
}
44244459

44254460
int getOrCreateSlot(int containingSlot, Symbol symbol)
@@ -8332,6 +8367,52 @@ private TMember InferMemberTypeArguments<TMember>(
83328367
return (TMember)(object)definition.ConstructIncludingExtension(result.InferredTypeArguments);
83338368
}
83348369

8370+
/// <remarks>
8371+
/// <para>This type assists with nullable reinference of generic types used in expressions.</para>
8372+
///
8373+
/// <para>
8374+
/// "Type argument inference" is the step we do during initial binding,
8375+
/// when producing the bound tree used as input to nullable analysis, lowering, and a bunch of other steps.
8376+
/// This inference is flow-independent, so, an expression used at any point in a method,
8377+
/// would get the same type arguments for the calls in it, assuming the stuff the expression is using is still in scope.
8378+
/// </para>
8379+
///
8380+
/// <para>
8381+
/// "Reinference" is done during nullable analysis, in order to enrich the results of
8382+
/// the initial type argument inference, based on the flow state of expressions at the particular point of usage.
8383+
/// </para>
8384+
///
8385+
/// <para>What it comes down to is scenarios like the following:</para>
8386+
///
8387+
/// <code>
8388+
/// var str = GetStringOrNull();
8389+
///
8390+
/// var arr1 = ImmutableArray.Create(str);
8391+
/// arr1[0].ToString(); // warning: possible null dereference
8392+
///
8393+
/// if (str == null)
8394+
/// return;
8395+
///
8396+
/// var arr2 = ImmutableArray.Create(str);
8397+
/// arr2[0].ToString(); // ok
8398+
/// </code>
8399+
///
8400+
/// <para>
8401+
/// For both calls to ImmutableArray.Create, initial binding will do a flow-independent type argument inference,
8402+
/// and both will receive type argument `string` (oblivious).
8403+
/// </para>
8404+
///
8405+
/// <para>
8406+
/// During nullable analysis, we will do a reinference of the call. The first call will get type argument `string?`.
8407+
/// The second call will get type argument `string` (non-nullable), based on the flow state of `str` at that point.
8408+
/// That needs to propagate to the next points in the control flow and be surfaced in public API for the types of the expressions and so on.
8409+
/// </para>
8410+
///
8411+
/// <para>
8412+
/// Reinference needs to be done on pretty much any expression which can represent a usage of either a generic method or a member of a generic type.
8413+
/// That ends up including calls (obviously) but also things like binary/unary operators, compound assignments, foreach statements, await-exprs, collection expression elements and so on.
8414+
/// </para>
8415+
/// </remarks>
83358416
private sealed class MethodInferenceExtensions : MethodTypeInferrer.Extensions
83368417
{
83378418
private readonly NullableWalker _walker;

src/Compilers/CSharp/Test/Emit3/Semantics/ExtensionTests.cs

Lines changed: 200 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48650,11 +48650,10 @@ public int Property { set { } }
4865048650
var comp = CreateCompilation(src);
4865148651
comp.VerifyEmitDiagnostics();
4865248652

48653-
// Tracked by https://github.com/dotnet/roslyn/issues/78828 : incorrect nullability
4865448653
var tree = comp.SyntaxTrees.Single();
4865548654
var model = comp.GetSemanticModel(tree);
4865648655
var assignment = GetSyntax<AssignmentExpressionSyntax>(tree, "Property = 42");
48657-
AssertEx.Equal("System.Int32 E.extension<System.Object>(System.Object).Property { set; }",
48656+
AssertEx.Equal("System.Int32 E.extension<System.Object!>(System.Object!).Property { set; }",
4865848657
model.GetSymbolInfo(assignment.Left).Symbol.ToTestDisplayString(includeNonNullable: true));
4865948658
}
4866048659

@@ -48679,20 +48678,141 @@ public T Property { set { } }
4867948678
}
4868048679
""";
4868148680
var comp = CreateCompilation(src);
48682-
comp.VerifyEmitDiagnostics();
48681+
comp.VerifyEmitDiagnostics(
48682+
// (4,31): warning CS8601: Possible null reference assignment.
48683+
// _ = new object() { Property = oNull };
48684+
Diagnostic(ErrorCode.WRN_NullReferenceAssignment, "oNull").WithLocation(4, 31));
4868348685

48684-
// Tracked by https://github.com/dotnet/roslyn/issues/78828 : incorrect nullability analysis for object initializer scenario with extension property
4868548686
var tree = comp.SyntaxTrees.Single();
4868648687
var model = comp.GetSemanticModel(tree);
4868748688
var assignment = GetSyntax<AssignmentExpressionSyntax>(tree, "Property = oNull");
48688-
AssertEx.Equal("System.Object E.extension<System.Object>(System.Object).Property { set; }",
48689+
AssertEx.Equal("System.Object! E.extension<System.Object!>(System.Object!).Property { set; }",
4868948690
model.GetSymbolInfo(assignment.Left).Symbol.ToTestDisplayString(includeNonNullable: true));
4869048691

4869148692
assignment = GetSyntax<AssignmentExpressionSyntax>(tree, "Property = oNotNull");
48692-
AssertEx.Equal("System.Object E.extension<System.Object>(System.Object).Property { set; }",
48693+
AssertEx.Equal("System.Object! E.extension<System.Object!>(System.Object!).Property { set; }",
4869348694
model.GetSymbolInfo(assignment.Left).Symbol.ToTestDisplayString(includeNonNullable: true));
4869448695
}
4869548696

48697+
[Fact]
48698+
public void Nullability_ObjectInitializer_04()
48699+
{
48700+
var src = """
48701+
#nullable enable
48702+
48703+
var s = "a";
48704+
Use(s, (new(s) { Property = null })/*T:C<string!>!*/); // 1
48705+
Use(s, (new(s) { Property = "a" })/*T:C<string!>!*/);
48706+
48707+
Use("a", (new(s) { Property = null })/*T:C<string!>!*/); // 2
48708+
Use("a", (new(s) { Property = "a" })/*T:C<string!>!*/);
48709+
48710+
Use(s, (new("a") { Property = null })/*T:C<string!>!*/); // 3
48711+
Use(s, (new("a") { Property = "a" })/*T:C<string!>!*/);
48712+
48713+
Use("a", (new("a") { Property = null })/*T:C<string!>!*/); // 4
48714+
Use("a", (new("a") { Property = "a" })/*T:C<string!>!*/);
48715+
48716+
if (s != null)
48717+
return;
48718+
48719+
Use(s, (new(s) { Property = null })/*T:C<string?>!*/);
48720+
Use(s, (new(s) { Property = "a" })/*T:C<string?>!*/);
48721+
48722+
Use("a", (new(s) { Property = null })/*T:C<string!>!*/); // 5, 6
48723+
Use("a", (new(s) { Property = "a" })/*T:C<string!>!*/);
48724+
48725+
Use(s, (new("a") { Property = null })/*T:C<string!>!*/); // 7
48726+
Use(s, (new("a") { Property = "a" })/*T:C<string!>!*/);
48727+
48728+
Use("a", (new("a") { Property = null })/*T:C<string!>!*/); // 8
48729+
Use("a", (new("a") { Property = "a" })/*T:C<string!>!*/);
48730+
48731+
void Use<T>(T value, C<T> c) => throw null!;
48732+
48733+
record C<T>(T Value) { }
48734+
48735+
static class E
48736+
{
48737+
extension<T>(C<T> c)
48738+
{
48739+
public T Property { set { } }
48740+
}
48741+
}
48742+
""";
48743+
48744+
var comp = CreateCompilation([src, IsExternalInitTypeDefinition]);
48745+
comp.VerifyTypes(comp.SyntaxTrees[0]);
48746+
comp.VerifyEmitDiagnostics(
48747+
// (4,29): warning CS8625: Cannot convert null literal to non-nullable reference type.
48748+
// Use(s, (new(s) { Property = null })/*T:C<string!>!*/); // 1
48749+
Diagnostic(ErrorCode.WRN_NullAsNonNullable, "null").WithLocation(4, 29),
48750+
// (7,31): warning CS8625: Cannot convert null literal to non-nullable reference type.
48751+
// Use("a", (new(s) { Property = null })/*T:C<string!>!*/); // 2
48752+
Diagnostic(ErrorCode.WRN_NullAsNonNullable, "null").WithLocation(7, 31),
48753+
// (10,31): warning CS8625: Cannot convert null literal to non-nullable reference type.
48754+
// Use(s, (new("a") { Property = null })/*T:C<string!>!*/); // 3
48755+
Diagnostic(ErrorCode.WRN_NullAsNonNullable, "null").WithLocation(10, 31),
48756+
// (13,33): warning CS8625: Cannot convert null literal to non-nullable reference type.
48757+
// Use("a", (new("a") { Property = null })/*T:C<string!>!*/); // 4
48758+
Diagnostic(ErrorCode.WRN_NullAsNonNullable, "null").WithLocation(13, 33),
48759+
// (22,15): warning CS8604: Possible null reference argument for parameter 'Value' in 'C<string>.C(string Value)'.
48760+
// Use("a", (new(s) { Property = null })/*T:C<string!>!*/); // 5, 6
48761+
Diagnostic(ErrorCode.WRN_NullReferenceArgument, "s").WithArguments("Value", "C<string>.C(string Value)").WithLocation(22, 15),
48762+
// (22,31): warning CS8625: Cannot convert null literal to non-nullable reference type.
48763+
// Use("a", (new(s) { Property = null })/*T:C<string!>!*/); // 5, 6
48764+
Diagnostic(ErrorCode.WRN_NullAsNonNullable, "null").WithLocation(22, 31),
48765+
// (25,31): warning CS8625: Cannot convert null literal to non-nullable reference type.
48766+
// Use(s, (new("a") { Property = null })/*T:C<string!>!*/); // 7
48767+
Diagnostic(ErrorCode.WRN_NullAsNonNullable, "null").WithLocation(25, 31),
48768+
// (28,33): warning CS8625: Cannot convert null literal to non-nullable reference type.
48769+
// Use("a", (new("a") { Property = null })/*T:C<string!>!*/); // 8
48770+
Diagnostic(ErrorCode.WRN_NullAsNonNullable, "null").WithLocation(28, 33));
48771+
48772+
}
48773+
48774+
[Fact]
48775+
public void Nullability_ObjectInitializer_05()
48776+
{
48777+
var src = """
48778+
#nullable enable
48779+
48780+
var s = "a";
48781+
48782+
Create(s).Use(new() { Property = null }); // 1
48783+
Create(s).Use(new() { Property = "a" });
48784+
48785+
if (s != null)
48786+
return;
48787+
48788+
Create(s).Use(new() { Property = null });
48789+
Create(s).Use(new() { Property = "a" });
48790+
48791+
Consumer<T> Create<T>(T value) => throw null!;
48792+
48793+
class Consumer<T>
48794+
{
48795+
public void Use(C<T> c) => throw null!;
48796+
}
48797+
48798+
record C<T> { }
48799+
48800+
static class E
48801+
{
48802+
extension<T>(C<T> c)
48803+
{
48804+
public T Property { set { } }
48805+
}
48806+
}
48807+
""";
48808+
48809+
var comp = CreateCompilation([src, IsExternalInitTypeDefinition]);
48810+
comp.VerifyEmitDiagnostics(
48811+
// (5,34): warning CS8625: Cannot convert null literal to non-nullable reference type.
48812+
// Create(s).Use(new() { Property = null }); // 1
48813+
Diagnostic(ErrorCode.WRN_NullAsNonNullable, "null").WithLocation(5, 34));
48814+
}
48815+
4869648816
[Fact]
4869748817
public void Nullability_With_01()
4869848818
{
@@ -48776,7 +48896,6 @@ public int Property { set { } }
4877648896
}
4877748897
""";
4877848898

48779-
// Tracked by https://github.com/dotnet/roslyn/issues/78828 : incorrect nullability analysis for with-expression with extension property (unexpected warning)
4878048899
var comp = CreateCompilation(src);
4878148900
comp.VerifyEmitDiagnostics(
4878248901
// (4,5): warning CS8602: Dereference of a possibly null reference.
@@ -48806,7 +48925,7 @@ public int Property { set { } }
4880648925
}
4880748926
}
4880848927
""";
48809-
// Tracked by https://github.com/dotnet/roslyn/issues/78828 : unexpected nullability warning
48928+
4881048929
var comp = CreateCompilation(src);
4881148930
comp.VerifyEmitDiagnostics(
4881248931
// (4,5): warning CS8602: Dereference of a possibly null reference.
@@ -48844,6 +48963,79 @@ public int Property { set { } }
4884448963
Diagnostic(ErrorCode.WRN_NullReferenceReceiver, "cNull").WithLocation(4, 5));
4884548964
}
4884648965

48966+
[Fact]
48967+
public void Nullability_With_06()
48968+
{
48969+
var src = """
48970+
#nullable enable
48971+
48972+
C? cNull = null;
48973+
_ = cNull with { Property = null };
48974+
48975+
C cNotNull = new C();
48976+
_ = cNotNull with { Property = null };
48977+
48978+
record C { }
48979+
48980+
static class E
48981+
{
48982+
extension(object? o)
48983+
{
48984+
public string Property { set { } }
48985+
}
48986+
}
48987+
""";
48988+
48989+
var comp = CreateCompilation(src);
48990+
comp.VerifyEmitDiagnostics(
48991+
// (4,5): warning CS8602: Dereference of a possibly null reference.
48992+
// _ = cNull with { Property = null };
48993+
Diagnostic(ErrorCode.WRN_NullReferenceReceiver, "cNull").WithLocation(4, 5),
48994+
// (4,29): warning CS8625: Cannot convert null literal to non-nullable reference type.
48995+
// _ = cNull with { Property = null };
48996+
Diagnostic(ErrorCode.WRN_NullAsNonNullable, "null").WithLocation(4, 29),
48997+
// (7,32): warning CS8625: Cannot convert null literal to non-nullable reference type.
48998+
// _ = cNotNull with { Property = null };
48999+
Diagnostic(ErrorCode.WRN_NullAsNonNullable, "null").WithLocation(7, 32));
49000+
}
49001+
49002+
[Fact]
49003+
public void Nullability_With_07()
49004+
{
49005+
var src = """
49006+
#nullable enable
49007+
49008+
var s = "a";
49009+
var c1 = Create(s);
49010+
c1 = c1 with { Property = null }; // 1
49011+
c1 = c1 with { Property = "a" };
49012+
49013+
if (s != null)
49014+
return;
49015+
49016+
var c2 = Create(s);
49017+
c2 = c2 with { Property = null }; // ok
49018+
c2 = c2 with { Property = "a" };
49019+
C<T> Create<T>(T value) => throw null!;
49020+
49021+
record C<T> { }
49022+
49023+
static class E
49024+
{
49025+
extension<T>(C<T> c)
49026+
{
49027+
public T Property { set { } }
49028+
}
49029+
}
49030+
""";
49031+
49032+
var comp = CreateCompilation(src);
49033+
comp.VerifyEmitDiagnostics(
49034+
// (5,27): warning CS8625: Cannot convert null literal to non-nullable reference type.
49035+
// c1 = c1 with { Property = null }; // 1
49036+
Diagnostic(ErrorCode.WRN_NullAsNonNullable, "null").WithLocation(5, 27));
49037+
}
49038+
4884749039
[Fact]
4884849040
public void Nullability_Fixed_01()
4884949041
{

0 commit comments

Comments
 (0)