Skip to content

Commit

Permalink
[release/9.0] Fix Contains on ImmutableArray (#35251)
Browse files Browse the repository at this point in the history
  • Loading branch information
cincuranet authored Dec 2, 2024
1 parent 507152b commit 6489581
Show file tree
Hide file tree
Showing 9 changed files with 196 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ public class QueryableMethodNormalizingExpressionVisitor : ExpressionVisitor
private readonly SelectManyVerifyingExpressionVisitor _selectManyVerifyingExpressionVisitor = new();
private readonly GroupJoinConvertingExpressionVisitor _groupJoinConvertingExpressionVisitor = new();

private static readonly bool UseOldBehavior35102 =
AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue35102", out var enabled35102) && enabled35102;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down Expand Up @@ -489,12 +492,16 @@ private Expression TryConvertCollectionContainsToQueryableContains(MethodCallExp

var sourceType = methodCallExpression.Method.DeclaringType!.GetGenericArguments()[0];

var objectExpression = methodCallExpression.Object!.Type.IsValueType && !UseOldBehavior35102
? Expression.Convert(methodCallExpression.Object!, typeof(IEnumerable<>).MakeGenericType(sourceType))
: methodCallExpression.Object!;

return VisitMethodCall(
Expression.Call(
QueryableMethods.Contains.MakeGenericMethod(sourceType),
Expression.Call(
QueryableMethods.AsQueryable.MakeGenericMethod(sourceType),
methodCallExpression.Object!),
objectExpression),
methodCallExpression.Arguments[0]));
}

Expand Down
19 changes: 18 additions & 1 deletion src/EFCore/Query/QueryRootProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ public class QueryRootProcessor : ExpressionVisitor
{
private readonly QueryCompilationContext _queryCompilationContext;

private static readonly bool UseOldBehavior35102 =
AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue35102", out var enabled35102) && enabled35102;

/// <summary>
/// Creates a new instance of the <see cref="QueryRootProcessor" /> class with associated query provider.
/// </summary>
Expand Down Expand Up @@ -85,7 +88,21 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp

private Expression VisitQueryRootCandidate(Expression expression, Type elementClrType)
{
switch (expression)
var candidateExpression = expression;

if (!UseOldBehavior35102)
{
// In case the collection was value type, in order to call methods like AsQueryable,
// we need to convert it to IEnumerable<T> which requires boxing.
// We do that with Convert expression which we need to unwrap here.
if (expression is UnaryExpression { NodeType: ExpressionType.Convert } convertExpression
&& convertExpression.Type.GetGenericTypeDefinition() == typeof(IEnumerable<>))
{
candidateExpression = convertExpression.Operand;
}
}

switch (candidateExpression)
{
// An array containing only constants is represented as a ConstantExpression with the array as the value.
// Convert that into a NewArrayExpression for use with InlineQueryRootExpression
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,30 @@ WHERE ARRAY_CONTAINS(@__ints_0, c["Int"])
"""
@__ints_0='[10,999]'
SELECT VALUE c
FROM root c
WHERE NOT(ARRAY_CONTAINS(@__ints_0, c["Int"]))
""");
});

public override Task Parameter_collection_ImmutableArray_of_ints_Contains_int(bool async)
=> CosmosTestHelpers.Instance.NoSyncTest(
async, async a =>
{
await base.Parameter_collection_ImmutableArray_of_ints_Contains_int(a);

AssertSql(
"""
@__ints_0='[10,999]'
SELECT VALUE c
FROM root c
WHERE ARRAY_CONTAINS(@__ints_0, c["Int"])
""",
//
"""
@__ints_0='[10,999]'
SELECT VALUE c
FROM root c
WHERE NOT(ARRAY_CONTAINS(@__ints_0, c["Int"]))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Immutable;

namespace Microsoft.EntityFrameworkCore.Query;

public abstract class PrimitiveCollectionsQueryTestBase<TFixture>(TFixture fixture) : QueryTestBase<TFixture>(fixture)
Expand Down Expand Up @@ -363,6 +365,20 @@ await AssertQuery(
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => !ints.Contains(c.Int)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Parameter_collection_ImmutableArray_of_ints_Contains_int(bool async)
{
var ints = ImmutableArray.Create([10, 999]);

await AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => ints.Contains(c.Int)));
await AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => !ints.Contains(c.Int)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Parameter_collection_of_ints_Contains_nullable_int(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,24 @@ WHERE [p].[Int] NOT IN (10, 999)
""");
}

public override async Task Parameter_collection_ImmutableArray_of_ints_Contains_int(bool async)
{
await base.Parameter_collection_ImmutableArray_of_ints_Contains_int(async);

AssertSql(
"""
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] IN (10, 999)
""",
//
"""
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] NOT IN (10, 999)
""");
}

public override async Task Parameter_collection_of_ints_Contains_nullable_int(bool async)
{
await base.Parameter_collection_of_ints_Contains_nullable_int(async);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,34 @@ FROM OPENJSON(@__ints_0) WITH ([value] int '$') AS [i]
"""
@__ints_0='[10,999]' (Size = 4000)
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] NOT IN (
SELECT [i].[value]
FROM OPENJSON(@__ints_0) WITH ([value] int '$') AS [i]
)
""");
}

public override async Task Parameter_collection_ImmutableArray_of_ints_Contains_int(bool async)
{
await base.Parameter_collection_ImmutableArray_of_ints_Contains_int(async);

AssertSql(
"""
@__ints_0='[10,999]' (Nullable = false) (Size = 4000)
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] IN (
SELECT [i].[value]
FROM OPENJSON(@__ints_0) WITH ([value] int '$') AS [i]
)
""",
//
"""
@__ints_0='[10,999]' (Nullable = false) (Size = 4000)
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] NOT IN (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,34 @@ FROM OPENJSON(@__ints_0) WITH ([value] int '$') AS [i]
"""
@__ints_0='[10,999]' (Size = 4000)
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] NOT IN (
SELECT [i].[value]
FROM OPENJSON(@__ints_0) WITH ([value] int '$') AS [i]
)
""");
}

public override async Task Parameter_collection_ImmutableArray_of_ints_Contains_int(bool async)
{
await base.Parameter_collection_ImmutableArray_of_ints_Contains_int(async);

AssertSql(
"""
@__ints_0='[10,999]' (Nullable = false) (Size = 4000)
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] IN (
SELECT [i].[value]
FROM OPENJSON(@__ints_0) WITH ([value] int '$') AS [i]
)
""",
//
"""
@__ints_0='[10,999]' (Nullable = false) (Size = 4000)
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] NOT IN (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,34 @@ FROM OPENJSON(@__ints_0) WITH ([value] int '$') AS [i]
"""
@__ints_0='[10,999]' (Size = 4000)
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] NOT IN (
SELECT [i].[value]
FROM OPENJSON(@__ints_0) WITH ([value] int '$') AS [i]
)
""");
}

public override async Task Parameter_collection_ImmutableArray_of_ints_Contains_int(bool async)
{
await base.Parameter_collection_ImmutableArray_of_ints_Contains_int(async);

AssertSql(
"""
@__ints_0='[10,999]' (Nullable = false) (Size = 4000)
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] IN (
SELECT [i].[value]
FROM OPENJSON(@__ints_0) WITH ([value] int '$') AS [i]
)
""",
//
"""
@__ints_0='[10,999]' (Nullable = false) (Size = 4000)
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] NOT IN (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,34 @@ FROM json_each(@__ints_0) AS "i"
"""
@__ints_0='[10,999]' (Size = 8)
SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."NullableString", "p"."NullableStrings", "p"."String", "p"."Strings"
FROM "PrimitiveCollectionsEntity" AS "p"
WHERE "p"."Int" NOT IN (
SELECT "i"."value"
FROM json_each(@__ints_0) AS "i"
)
""");
}

public override async Task Parameter_collection_ImmutableArray_of_ints_Contains_int(bool async)
{
await base.Parameter_collection_ImmutableArray_of_ints_Contains_int(async);

AssertSql(
"""
@__ints_0='[10,999]' (Nullable = false) (Size = 8)
SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."NullableString", "p"."NullableStrings", "p"."String", "p"."Strings"
FROM "PrimitiveCollectionsEntity" AS "p"
WHERE "p"."Int" IN (
SELECT "i"."value"
FROM json_each(@__ints_0) AS "i"
)
""",
//
"""
@__ints_0='[10,999]' (Nullable = false) (Size = 8)
SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."NullableString", "p"."NullableStrings", "p"."String", "p"."Strings"
FROM "PrimitiveCollectionsEntity" AS "p"
WHERE "p"."Int" NOT IN (
Expand Down

0 comments on commit 6489581

Please sign in to comment.