Skip to content

Commit fe70d58

Browse files
committed
Implement support for SQL Server vector search
Closes #34659
1 parent 121e69a commit fe70d58

File tree

17 files changed

+681
-205
lines changed

17 files changed

+681
-205
lines changed

src/EFCore.SqlServer/EFCore.SqlServer.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050

5151
<ItemGroup>
5252
<PackageReference Include="Microsoft.Data.SqlClient" />
53+
<PackageReference Include="Microsoft.SqlServer.Types" />
5354
</ItemGroup>
5455

5556
<ItemGroup>

src/EFCore.SqlServer/Extensions/SqlServerDbFunctionsExtensions.cs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
// ReSharper disable once CheckNamespace
55

6+
using Microsoft.Data.SqlTypes;
7+
68
namespace Microsoft.EntityFrameworkCore;
79

810
/// <summary>
@@ -2452,4 +2454,31 @@ public static long PatIndex(
24522454
=> throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VariancePopulation)));
24532455

24542456
#endregion Population variance
2457+
2458+
#region Vector functions
2459+
2460+
/// <summary>
///     Calculates the distance between two vectors using a specified distance metric.
/// </summary>
/// <typeparam name="T">The unmanaged element type of the two vectors.</typeparam>
/// <param name="_">The <see cref="DbFunctions" /> instance.</param>
/// <param name="distanceMetric">
///     A string with the name of the distance metric to use to calculate the distance between the two given vectors. The following distance metrics are supported: <c>cosine</c>, <c>euclidean</c> or <c>dot</c>.
/// </param>
/// <param name="vector1">The first vector.</param>
/// <param name="vector2">The second vector.</param>
/// <returns>
///     The distance between <paramref name="vector1" /> and <paramref name="vector2" />, as computed by the database.
/// </returns>
/// <exception cref="InvalidOperationException">
///     Always thrown when this method is invoked directly: it has no client-side implementation and only acts as a marker for query translation.
/// </exception>
/// <remarks>
///     Vector distance is always exact and doesn't use any vector index, even if available.
///     TODO: In order to use a vector index and thus perform an approximate vector search, you must use the VECTOR_SEARCH function.
/// </remarks>
/// <seealso href="https://learn.microsoft.com/sql/t-sql/functions/vector-distance-transact-sql">SQL Server documentation for <c>VECTOR_DISTANCE</c>.</seealso>
/// <seealso href="https://learn.microsoft.com/sql/relational-databases/vectors/vectors-sql-server">Vectors in the SQL Database Engine.</seealso>
public static double VectorDistance<T>(
    this DbFunctions _,
    [NotParameterized] string distanceMetric,
    SqlVector<T> vector1, // TODO: Support other types, e.g. float[]?
    SqlVector<T> vector2)
    where T : unmanaged
    => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(VectorDistance)));
2482+
2483+
#endregion Vector functions
24552484
}

src/EFCore.SqlServer/Infrastructure/Internal/SqlServerModelValidator.cs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System.Text;
5+
using Microsoft.Data.SqlTypes;
56
using Microsoft.EntityFrameworkCore.Metadata.Internal;
67
using Microsoft.EntityFrameworkCore.SqlServer.Extensions.Internal;
78
using Microsoft.EntityFrameworkCore.SqlServer.Internal;
89
using Microsoft.EntityFrameworkCore.SqlServer.Metadata.Internal;
10+
using Microsoft.EntityFrameworkCore.SqlServer.Storage.Internal;
911

1012
namespace Microsoft.EntityFrameworkCore.SqlServer.Infrastructure.Internal;
1113

@@ -43,6 +45,7 @@ public override void Validate(IModel model, IDiagnosticsLogger<DbLoggerCategory.
4345
base.Validate(model, logger);
4446

4547
ValidateDecimalColumns(model, logger);
48+
ValidateVectorColumns(model, logger);
4649
ValidateByteIdentityMapping(model, logger);
4750
ValidateTemporalTables(model, logger);
4851
ValidateUseOfJsonType(model, logger);
@@ -110,6 +113,32 @@ protected virtual void ValidateDecimalColumns(
110113
}
111114
}
112115

116+
/// <summary>
///     This is an internal API that supports the Entity Framework Core infrastructure and not subject to
///     the same compatibility standards as public APIs. It may be changed or removed without notice in
///     any release. You should only use it directly in your code with extreme caution and knowing that
///     doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected virtual void ValidateVectorColumns(
    IModel model,
    IDiagnosticsLogger<DbLoggerCategory.Model.Validation> logger)
{
    // Only properties declared directly on entity types whose CLR type is SqlVector<float>
    // (optionally wrapped in Nullable<>) are inspected.
    var vectorProperties = model.GetEntityTypes()
        .SelectMany(t => t.GetDeclaredProperties())
        .Where(p => p.ClrType.UnwrapNullableType() == typeof(SqlVector<float>));

    foreach (var property in vectorProperties)
    {
        // A vector property must resolve to the vector type mapping with an explicit size,
        // i.e. the number of dimensions.
        if (property.GetTypeMapping() is not SqlServerVectorTypeMapping { Size: not null })
        {
            throw new InvalidOperationException(
                SqlServerStrings.VectorDimensionsMissing(property.DeclaringType.DisplayName(), property.Name));
        }

        if (property.DeclaringType.IsMappedToJson())
        {
            // The resource text reads "Vector property '{propertyName}' is on '{structuralType}'...",
            // so the arguments are bound by name to keep the property and type names in their
            // intended slots regardless of the generated parameter order.
            throw new InvalidOperationException(
                SqlServerStrings.VectorPropertiesNotSupportedInJson(
                    propertyName: property.Name,
                    structuralType: property.DeclaringType.DisplayName()));
        }
    }
}
141+
113142
/// <summary>
114143
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
115144
/// the same compatibility standards as public APIs. It may be changed or removed without notice in

src/EFCore.SqlServer/Properties/SqlServerStrings.Designer.cs

Lines changed: 22 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/EFCore.SqlServer/Properties/SqlServerStrings.resx

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,4 +369,13 @@
369369
<data name="TransientExceptionDetected" xml:space="preserve">
370370
<value>An exception has been raised that is likely due to a transient failure. Consider enabling transient error resiliency by adding 'EnableRetryOnFailure' to the 'UseSqlServer' call.</value>
371371
</data>
372+
<data name="VectorDimensionsInvalid" xml:space="preserve">
373+
<value>Vector properties require a positive size (number of dimensions).</value>
374+
</data>
375+
<data name="VectorDimensionsMissing" xml:space="preserve">
376+
<value>Vector property '{structuralType}.{propertyName}' was not configured with the number of dimensions. Set the column type to 'vector(x)' with the desired number of dimensions, or use the 'MaxLength' APIs.</value>
377+
</data>
378+
<data name="VectorPropertiesNotSupportedInJson" xml:space="preserve">
379+
<value>Vector property '{propertyName}' is on '{structuralType}' which is mapped to JSON. Vector properties are not supported within JSON documents.</value>
380+
</data>
372381
</root>

src/EFCore.SqlServer/Query/Internal/SqlServerMemberTranslatorProvider.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ public SqlServerMemberTranslatorProvider(
3030
new SqlServerDateTimeMemberTranslator(sqlExpressionFactory, typeMappingSource),
3131
new SqlServerStringMemberTranslator(sqlExpressionFactory),
3232
new SqlServerTimeSpanMemberTranslator(sqlExpressionFactory),
33-
new SqlServerTimeOnlyMemberTranslator(sqlExpressionFactory)
33+
new SqlServerTimeOnlyMemberTranslator(sqlExpressionFactory),
34+
new SqlServerVectorTranslator(sqlExpressionFactory, typeMappingSource)
3435
]);
3536
}
3637
}

src/EFCore.SqlServer/Query/Internal/SqlServerMethodCallTranslatorProvider.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ public SqlServerMethodCallTranslatorProvider(
4242
new SqlServerNewGuidTranslator(sqlExpressionFactory),
4343
new SqlServerObjectToStringTranslator(sqlExpressionFactory, typeMappingSource),
4444
new SqlServerStringMethodTranslator(sqlExpressionFactory, sqlServerSingletonOptions),
45-
new SqlServerTimeOnlyMethodTranslator(sqlExpressionFactory)
45+
new SqlServerTimeOnlyMethodTranslator(sqlExpressionFactory),
46+
new SqlServerVectorTranslator(sqlExpressionFactory, typeMappingSource)
4647
]);
4748
}
4849
}
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using Microsoft.Data.SqlTypes;
5+
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
6+
7+
// ReSharper disable once CheckNamespace
8+
namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal;
9+
10+
/// <summary>
///     This is an internal API that supports the Entity Framework Core infrastructure and not subject to
///     the same compatibility standards as public APIs. It may be changed or removed without notice in
///     any release. You should only use it directly in your code with extreme caution and knowing that
///     doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public class SqlServerVectorTranslator(
    ISqlExpressionFactory sqlExpressionFactory,
    IRelationalTypeMappingSource typeMappingSource)
    : IMethodCallTranslator, IMemberTranslator
{
    /// <summary>
    ///     This is an internal API that supports the Entity Framework Core infrastructure and not subject to
    ///     the same compatibility standards as public APIs. It may be changed or removed without notice in
    ///     any release. You should only use it directly in your code with extreme caution and knowing that
    ///     doing so can result in application failures when updating to a new Entity Framework Core release.
    /// </summary>
    public SqlExpression? Translate(
        SqlExpression? instance,
        MethodInfo method,
        IReadOnlyList<SqlExpression> arguments,
        IDiagnosticsLogger<DbLoggerCategory.Query> logger)
    {
        // Only the four-argument VectorDistance method declared on SqlServerDbFunctionsExtensions
        // is translated here; anything else is left to other translators.
        if (method.DeclaringType != typeof(SqlServerDbFunctionsExtensions)
            || method.Name != nameof(SqlServerDbFunctionsExtensions.VectorDistance)
            || arguments is not [_, var metric, var left, var right])
        {
            return null;
        }

        // Whichever operand already carries a type mapping supplies the mapping applied to both.
        var sharedVectorMapping = left.TypeMapping ?? right.TypeMapping
            ?? throw new InvalidOperationException(
                "One of the arguments to EF.Functions.VectorDistance must be a vector column.");

        var typedMetric = sqlExpressionFactory.ApplyTypeMapping(metric, typeMappingSource.FindMapping("varchar(max)"));
        var typedLeft = sqlExpressionFactory.ApplyTypeMapping(left, sharedVectorMapping);
        var typedRight = sqlExpressionFactory.ApplyTypeMapping(right, sharedVectorMapping);

        return sqlExpressionFactory.Function(
            "VECTOR_DISTANCE",
            [typedMetric, typedLeft, typedRight],
            nullable: true,
            argumentsPropagateNullability: [true, true, true],
            typeof(double),
            typeMappingSource.FindMapping(typeof(double)));
    }

    /// <summary>
    ///     This is an internal API that supports the Entity Framework Core infrastructure and not subject to
    ///     the same compatibility standards as public APIs. It may be changed or removed without notice in
    ///     any release. You should only use it directly in your code with extreme caution and knowing that
    ///     doing so can result in application failures when updating to a new Entity Framework Core release.
    /// </summary>
    public SqlExpression? Translate(
        SqlExpression? instance,
        MemberInfo member,
        Type returnType,
        IDiagnosticsLogger<DbLoggerCategory.Query> logger)
    {
        // Only instance access to SqlVector<float>.Length is translated.
        if (instance is null
            || member.DeclaringType != typeof(SqlVector<float>)
            || member.Name != nameof(SqlVector<>.Length))
        {
            return null;
        }

        var propertyName = sqlExpressionFactory.Constant("Dimensions", typeMappingSource.FindMapping("varchar(max)"));

        return sqlExpressionFactory.Function(
            "VECTORPROPERTY",
            [instance, propertyName],
            nullable: true,
            argumentsPropagateNullability: [true, true],
            typeof(int),
            typeMappingSource.FindMapping(typeof(int)));
    }
}
97+

src/EFCore.SqlServer/Scaffolding/Internal/SqlServerDatabaseModelFactory.cs

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ FROM [sys].[types] AS [t]
395395
var precision = reader.GetValueOrDefault<int>("precision");
396396
var scale = reader.GetValueOrDefault<int>("scale");
397397

398-
var storeType = GetStoreType(systemType, maxLength, precision, scale);
398+
var storeType = GetStoreType(systemType, maxLength, precision, scale, vectorDimensions: 0);
399399

400400
_logger.TypeAliasFound(DisplayName(schema, userType), storeType);
401401

@@ -472,7 +472,7 @@ FROM [sys].[sequences] AS [s]
472472
storeType = value.storeType;
473473
}
474474

475-
storeType = GetStoreType(storeType, maxLength: 0, precision: precision, scale: scale);
475+
storeType = GetStoreType(storeType, maxLength: 0, precision, scale, vectorDimensions: 0);
476476

477477
_logger.SequenceFound(DisplayName(schema, name), storeType, cyclic, incrementBy, startValue, minValue, maxValue);
478478

@@ -730,6 +730,7 @@ private void GetColumns(
730730
CAST([c].[max_length] AS int) AS [max_length],
731731
CAST([c].[precision] AS int) AS [precision],
732732
CAST([c].[scale] AS int) AS [scale],
733+
{(_compatibilityLevel is >= 170 ? "[c].[vector_dimensions]" : "NULL as [vector_dimensions]")},
733734
[c].[is_nullable],
734735
[c].[is_identity],
735736
[dc].[definition] AS [default_sql],
@@ -801,6 +802,7 @@ FROM [sys].[views] v
801802
var maxLength = dataRecord.GetValueOrDefault<int>("max_length");
802803
var precision = dataRecord.GetValueOrDefault<int>("precision");
803804
var scale = dataRecord.GetValueOrDefault<int>("scale");
805+
var vectorDimensions = dataRecord.GetValueOrDefault<int>("vector_dimensions");
804806
var nullable = dataRecord.GetValueOrDefault<bool>("is_nullable");
805807
var isIdentity = dataRecord.GetValueOrDefault<bool>("is_identity");
806808
var defaultValueSql = dataRecord.GetValueOrDefault<string>("default_sql");
@@ -835,15 +837,19 @@ FROM [sys].[views] v
835837
string storeType;
836838
string systemTypeName;
837839

838-
// Swap store type if type alias is used
839-
if (typeAliases.TryGetValue($"[{dataTypeSchemaName}].[{dataTypeName}]", out var value))
840+
// If the store type is in our loaded aliases dictionary, resolve to the canonical type.
841+
// Note that the vector type is implemented as an alias for varbinary, but we do not want
842+
// to scaffold vectors as varbinary.
843+
var fullQualifiedTypeName = $"[{dataTypeSchemaName}].[{dataTypeName}]";
844+
if (fullQualifiedTypeName is not "[sys].[vector]"
845+
&& typeAliases.TryGetValue(fullQualifiedTypeName, out var value))
840846
{
841847
storeType = value.storeType;
842848
systemTypeName = value.typeName;
843849
}
844850
else
845851
{
846-
storeType = GetStoreType(dataTypeName, maxLength, precision, scale);
852+
storeType = GetStoreType(dataTypeName, maxLength, precision, scale, vectorDimensions);
847853
systemTypeName = dataTypeName;
848854
}
849855

@@ -995,16 +1001,16 @@ void Unwrap()
9951001
}
9961002
}
9971003

998-
private static string GetStoreType(string dataTypeName, int maxLength, int precision, int scale)
1004+
private static string GetStoreType(string dataTypeName, int maxLength, int precision, int scale, int vectorDimensions)
9991005
{
1000-
if (dataTypeName == "timestamp")
1006+
switch (dataTypeName)
10011007
{
1002-
return "rowversion";
1003-
}
1004-
1005-
if (dataTypeName is "decimal" or "numeric")
1006-
{
1007-
return $"{dataTypeName}({precision}, {scale})";
1008+
case "timestamp":
1009+
return "rowversion";
1010+
case "decimal" or "numeric":
1011+
return $"{dataTypeName}({precision}, {scale})";
1012+
case "vector":
1013+
return $"vector({vectorDimensions})";
10081014
}
10091015

10101016
if (DateTimePrecisionTypes.Contains(dataTypeName)

0 commit comments

Comments
 (0)