diff --git a/global.json b/global.json index 20f482a..f15a959 100644 --- a/global.json +++ b/global.json @@ -1,5 +1,6 @@ { "sdk": { - "version": "9.0.100" + "version": "9.0.100", + "rollForward": "latestMinor" } } \ No newline at end of file diff --git a/src/EntityFrameworkCore.Projectables/Extensions/ExpressionExtensions.cs b/src/EntityFrameworkCore.Projectables/Extensions/ExpressionExtensions.cs index 49da220..43116a5 100644 --- a/src/EntityFrameworkCore.Projectables/Extensions/ExpressionExtensions.cs +++ b/src/EntityFrameworkCore.Projectables/Extensions/ExpressionExtensions.cs @@ -18,6 +18,6 @@ public static Expression ExpandQuaryables(this Expression expression) /// Replaces all calls to properties and methods that are marked with the Projectable attribute with their respective expression tree /// public static Expression ExpandProjectables(this Expression expression) - => new ProjectableExpressionReplacer(new ProjectionExpressionResolver()).Replace(expression); + => new ProjectableExpressionReplacer(new ProjectionExpressionResolver(), false).Replace(expression); } } diff --git a/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/CustomQueryCompiler.cs b/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/CustomQueryCompiler.cs index 5f66799..01bd0d0 100644 --- a/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/CustomQueryCompiler.cs +++ b/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/CustomQueryCompiler.cs @@ -26,15 +26,34 @@ public sealed class CustomQueryCompiler : QueryCompiler readonly IQueryCompiler _decoratedQueryCompiler; readonly ProjectableExpressionReplacer _projectableExpressionReplacer; - public CustomQueryCompiler(IQueryCompiler decoratedQueryCompiler, IQueryContextFactory queryContextFactory, ICompiledQueryCache compiledQueryCache, ICompiledQueryCacheKeyGenerator compiledQueryCacheKeyGenerator, IDatabase database, IDiagnosticsLogger logger, ICurrentDbContext currentContext, IEvaluatableExpressionFilter evaluatableExpressionFilter, IModel model) : base(queryContextFactory, compiledQueryCache, compiledQueryCacheKeyGenerator, database, logger, currentContext, evaluatableExpressionFilter, model) + public CustomQueryCompiler(IQueryCompiler decoratedQueryCompiler, + IQueryContextFactory queryContextFactory, + ICompiledQueryCache compiledQueryCache, + ICompiledQueryCacheKeyGenerator compiledQueryCacheKeyGenerator, + IDatabase database, + IDbContextOptions contextOptions, + IDiagnosticsLogger logger, + ICurrentDbContext currentContext, + IEvaluatableExpressionFilter evaluatableExpressionFilter, + IModel model) : base(queryContextFactory, + compiledQueryCache, + compiledQueryCacheKeyGenerator, + database, + logger, + currentContext, + evaluatableExpressionFilter, + model) { _decoratedQueryCompiler = decoratedQueryCompiler; - _projectableExpressionReplacer = new ProjectableExpressionReplacer(new ProjectionExpressionResolver()); + var trackingByDefault = (contextOptions.FindExtension()?.QueryTrackingBehavior ?? QueryTrackingBehavior.TrackAll) == + QueryTrackingBehavior.TrackAll; + + _projectableExpressionReplacer = new ProjectableExpressionReplacer(new ProjectionExpressionResolver(), trackingByDefault); } - public override Func CreateCompiledAsyncQuery(Expression query) + public override Func CreateCompiledAsyncQuery(Expression query) => _decoratedQueryCompiler.CreateCompiledAsyncQuery(Expand(query)); - public override Func CreateCompiledQuery(Expression query) + public override Func CreateCompiledQuery(Expression query) => _decoratedQueryCompiler.CreateCompiledQuery(Expand(query)); public override TResult Execute(Expression query) => _decoratedQueryCompiler.Execute(Expand(query)); diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs index c334b4e..8f9f4a1 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs @@ -4,6 +4,7 @@ using System.Reflection; using System.Runtime.CompilerServices; using EntityFrameworkCore.Projectables.Extensions; +using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Query; @@ -15,14 +16,16 @@ public sealed class ProjectableExpressionReplacer : ExpressionVisitor private readonly ExpressionArgumentReplacer _expressionArgumentReplacer = new(); private readonly Dictionary _projectableMemberCache = new(); private IQueryProvider? _currentQueryProvider; - private bool _disableRootRewrite; + private bool _disableRootRewrite = false; + private readonly bool _trackingByDefault; private IEntityType? _entityType; private readonly MethodInfo _select; private readonly MethodInfo _where; - public ProjectableExpressionReplacer(IProjectionExpressionResolver projectionExpressionResolver) + public ProjectableExpressionReplacer(IProjectionExpressionResolver projectionExpressionResolver, bool trackByDefault = false) { + _trackingByDefault = trackByDefault; _resolver = projectionExpressionResolver; _select = typeof(Queryable).GetMethods(BindingFlags.Static | BindingFlags.Public) .Where(x => x.Name == nameof(Queryable.Select)) @@ -59,7 +62,7 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La [return: NotNullIfNotNull(nameof(node))] public Expression? Replace(Expression? node) { - _disableRootRewrite = false; + _disableRootRewrite = _trackingByDefault; _currentQueryProvider = null; _entityType = null; @@ -163,6 +166,15 @@ protected override Expression VisitMethodCall(MethodCallExpression node) _disableRootRewrite = true; } + if (methodInfo.Name == nameof(EntityFrameworkQueryableExtensions.AsTracking)) + { + _disableRootRewrite = true; + } + if (methodInfo.Name is nameof(EntityFrameworkQueryableExtensions.AsNoTracking) or nameof(EntityFrameworkQueryableExtensions.AsNoTrackingWithIdentityResolution)) + { + _disableRootRewrite = false; + } + if (TryGetReflectedExpression(methodInfo, out var reflectedExpression)) { for (var parameterIndex = 0; parameterIndex < reflectedExpression.Parameters.Count; parameterIndex++) diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/ChangeTrackerTests.cs b/tests/EntityFrameworkCore.Projectables.FunctionalTests/ChangeTrackerTests.cs new file mode 100644 index 0000000..41bddf4 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/ChangeTrackerTests.cs @@ -0,0 +1,76 @@ +using System.Linq; +using System.Threading.Tasks; +using EntityFrameworkCore.Projectables.FunctionalTests.Helpers; +using EntityFrameworkCore.Projectables.Infrastructure; +using Microsoft.EntityFrameworkCore; +using Xunit; + +namespace EntityFrameworkCore.Projectables.FunctionalTests; + +public class ChangeTrackerTests +{ + public class SqliteSampleDbContext : DbContext + where TEntity : class + { + protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) + { + optionsBuilder.UseSqlite("Data Source=test.sqlite"); + optionsBuilder.UseProjectables(); + } + + protected override void OnModelCreating(ModelBuilder modelBuilder) + { + modelBuilder.Entity(); + } + } + + public record Entity + { + private static int _nextId = 1; + public const int Computed1DefaultValue = -1; + public int Id { get; set; } = _nextId++; + public string? Name { get; set; } + + [Projectable(UseMemberBody = nameof(InternalComputed1))] + public int Computed1 { get; set; } = Computed1DefaultValue; + private int InternalComputed1 => Id; + + [Projectable] + public int Computed2 => Id * 2; + } + + [Fact] + public async Task CanQueryAndChangeTrackedEntities() + { + using var dbContext = new SqliteSampleDbContext(); + await dbContext.Database.EnsureDeletedAsync(); + await dbContext.Database.EnsureCreatedAsync(); + dbContext.Add(new Entity()); + await dbContext.SaveChangesAsync(); + dbContext.ChangeTracker.Clear(); + + var entity = await dbContext.Set().AsTracking().FirstAsync(); + var entityEntry = dbContext.ChangeTracker.Entries().Single(); + Assert.Same(entityEntry.Entity, entity); + dbContext.Set().Remove(entity); + await dbContext.SaveChangesAsync(); + } + + [Fact] + public async Task CanSaveChanges() + { + using var dbContext = new SqliteSampleDbContext(); + await dbContext.Database.EnsureDeletedAsync(); + await dbContext.Database.EnsureCreatedAsync(); + dbContext.Add(new Entity()); + await dbContext.SaveChangesAsync(); + dbContext.ChangeTracker.Clear(); + + var entity = await dbContext.Set().AsTracking().FirstAsync(); + entity.Name = "test"; + await dbContext.SaveChangesAsync(); + dbContext.ChangeTracker.Clear(); + var entity2 = await dbContext.Set().FirstAsync(); + Assert.Equal("test", entity2.Name); + } +} \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/EntityFrameworkCore.Projectables.FunctionalTests.csproj b/tests/EntityFrameworkCore.Projectables.FunctionalTests/EntityFrameworkCore.Projectables.FunctionalTests.csproj index 96c4d7a..801ccf7 100644 --- a/tests/EntityFrameworkCore.Projectables.FunctionalTests/EntityFrameworkCore.Projectables.FunctionalTests.csproj +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/EntityFrameworkCore.Projectables.FunctionalTests.csproj @@ -11,7 +11,8 @@ runtime; build; native; contentfiles; analyzers; buildtransitive - + + diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/Helpers/SampleDbContext.cs b/tests/EntityFrameworkCore.Projectables.FunctionalTests/Helpers/SampleDbContext.cs index 94908ed..f17e81b 100644 --- a/tests/EntityFrameworkCore.Projectables.FunctionalTests/Helpers/SampleDbContext.cs +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/Helpers/SampleDbContext.cs @@ -14,10 +14,12 @@ public class SampleDbContext : DbContext where TEntity : class { readonly CompatibilityMode _compatibilityMode; + readonly QueryTrackingBehavior _queryTrackingBehavior; - public SampleDbContext(CompatibilityMode compatibilityMode = CompatibilityMode.Full) + public SampleDbContext(CompatibilityMode compatibilityMode = CompatibilityMode.Full, QueryTrackingBehavior queryTrackingBehavior = QueryTrackingBehavior.TrackAll) { _compatibilityMode = compatibilityMode; + _queryTrackingBehavior = queryTrackingBehavior; } protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) @@ -26,6 +28,7 @@ protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) optionsBuilder.UseProjectables(options => { options.CompatibilityMode(_compatibilityMode); // Needed by our ComplexModelTests }); + optionsBuilder.UseQueryTrackingBehavior(_queryTrackingBehavior); } protected override void OnModelCreating(ModelBuilder modelBuilder) diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.AsNoTrackingQueryRootExpression.DotNet9_0.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.AsNoTrackingQueryRootExpression.DotNet9_0.verified.txt new file mode 100644 index 0000000..3b8b9ae --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.AsNoTrackingQueryRootExpression.DotNet9_0.verified.txt @@ -0,0 +1,2 @@ +SELECT [e].[Id], [e].[Id] * 5 +FROM [Entity] AS [e] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.AsNoTrackingQueryRootExpression.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.AsNoTrackingQueryRootExpression.verified.txt new file mode 100644 index 0000000..3b8b9ae --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.AsNoTrackingQueryRootExpression.verified.txt @@ -0,0 +1,2 @@ +SELECT [e].[Id], [e].[Id] * 5 +FROM [Entity] AS [e] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.AsTrackingQueryRootExpression.DotNet9_0.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.AsTrackingQueryRootExpression.DotNet9_0.verified.txt new file mode 100644 index 0000000..b1c3b32 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.AsTrackingQueryRootExpression.DotNet9_0.verified.txt @@ -0,0 +1,2 @@ +SELECT [e].[Id] +FROM [Entity] AS [e] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.AsTrackingQueryRootExpression.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.AsTrackingQueryRootExpression.verified.txt new file mode 100644 index 0000000..b1c3b32 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.AsTrackingQueryRootExpression.verified.txt @@ -0,0 +1,2 @@ +SELECT [e].[Id] +FROM [Entity] AS [e] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.DontUseMemberPropertyQueryRootExpression.DotNet9_0.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.DontUseMemberPropertyQueryRootExpression.DotNet9_0.verified.txt new file mode 100644 index 0000000..b1c3b32 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.DontUseMemberPropertyQueryRootExpression.DotNet9_0.verified.txt @@ -0,0 +1,2 @@ +SELECT [e].[Id] +FROM [Entity] AS [e] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.DontUseMemberPropertyQueryRootExpression.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.DontUseMemberPropertyQueryRootExpression.verified.txt new file mode 100644 index 0000000..b1c3b32 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.DontUseMemberPropertyQueryRootExpression.verified.txt @@ -0,0 +1,2 @@ +SELECT [e].[Id] +FROM [Entity] AS [e] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.cs b/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.cs index 3baadfd..f46289e 100644 --- a/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.cs +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.cs @@ -26,7 +26,17 @@ public record Entity [Fact] public Task UseMemberPropertyQueryRootExpression() { - using var dbContext = new SampleDbContext(); + using var dbContext = new SampleDbContext(queryTrackingBehavior: QueryTrackingBehavior.NoTracking); + + var query = dbContext.Set(); + + return Verifier.Verify(query.ToQueryString()); + } + + [Fact] + public Task DontUseMemberPropertyQueryRootExpression() + { + using var dbContext = new SampleDbContext(queryTrackingBehavior: QueryTrackingBehavior.TrackAll); var query = dbContext.Set(); @@ -47,5 +57,25 @@ public Task EntityRootSubqueryExpression() return Verifier.Verify(query.ToQueryString()); } + + [Fact] + public Task AsTrackingQueryRootExpression() + { + using var dbContext = new SampleDbContext(queryTrackingBehavior: QueryTrackingBehavior.NoTracking); + + var query = dbContext.Set().AsTracking(); + + return Verifier.Verify(query.ToQueryString()); + } + + [Fact] + public Task AsNoTrackingQueryRootExpression() + { + using var dbContext = new SampleDbContext(queryTrackingBehavior: QueryTrackingBehavior.TrackAll); + + var query = dbContext.Set().AsNoTracking(); + + return Verifier.Verify(query.ToQueryString()); + } } }