Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix non-lazy conditional evaluation #4533

Merged
merged 17 commits into from
Jun 29, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ protected ClickHouseDataProvider(string name, ClickHouseProvider provider)
// as emulation doesn't work properly due to missing rowcount functionality
SqlProviderFlags.IsInsertOrUpdateSupported = true;

SqlProviderFlags.IsUpdateFromSupported = false;
SqlProviderFlags.IsCommonTableExpressionsSupported = true;
SqlProviderFlags.IsSubQueryOrderBySupported = true;
SqlProviderFlags.DoesNotSupportCorrelatedSubquery = true;
SqlProviderFlags.IsAllSetOperationsSupported = true;
SqlProviderFlags.IsNestedJoinsSupported = false;
SqlProviderFlags.IsUpdateFromSupported = false;
SqlProviderFlags.IsCommonTableExpressionsSupported = true;
SqlProviderFlags.IsSubQueryOrderBySupported = true;
SqlProviderFlags.DoesNotSupportCorrelatedSubquery = true;
SqlProviderFlags.IsAllSetOperationsSupported = true;
SqlProviderFlags.IsNestedJoinsSupported = false;
SqlProviderFlags.IsSupportedSimpleCorrelatedSubqueries = true;

// unconfigured flags
// 1. ClickHouse doesn't support correlated subqueries at all so this flag's value doesn't make difference
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,16 +139,18 @@ protected override ISqlExpression ConvertConversion(SqlCastExpression cast)
return base.ConvertConversion(cast);
}

protected override ISqlExpression WrapBooleanExpression(ISqlExpression expr)
protected override ISqlExpression WrapColumnExpression(ISqlExpression expr)
{
var newExpr = base.WrapBooleanExpression(expr);
if (!ReferenceEquals(newExpr, expr))
var columnExpression = base.WrapColumnExpression(expr);

if (SqlProviderFlags != null && columnExpression.SystemType == typeof(bool) && QueryHelper.UnwrapNullablity(columnExpression) is not (SqlCastExpression or SqlColumn or SqlField))
sdanyliv marked this conversation as resolved.
Show resolved Hide resolved
{
return new SqlCastExpression(newExpr, new DbDataType(expr.SystemType ?? typeof(bool), DataType.Boolean), null, isMandatory : true);
columnExpression = new SqlCastExpression(columnExpression, new DbDataType(columnExpression.SystemType!, DataType.Boolean), null, isMandatory: true);
}

return newExpr;
return columnExpression;
}


sdanyliv marked this conversation as resolved.
Show resolved Hide resolved
}
}
2 changes: 1 addition & 1 deletion Source/LinqToDB/Linq/Builder/DefaultIfEmptyBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ protected override BuildSequenceResult BuildMethodCall(ExpressionBuilder builder

if (!buildInfo.IsSubQuery)
{
if (!SequenceHelper.IsSupportedSubquery(resultSelectContext, resultSelectContext, out var errorMessage))
if (!builder.IsSupportedSubquery(resultSelectContext, resultSelectContext, out var errorMessage))
return BuildSequenceResult.Error(methodCall, errorMessage);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ public Expression TryCreateAssociation(Expression expression, ContextRefExpressi

if (!flags.IsTest())
{
if (!SequenceHelper.IsSupportedSubquery(rootContext.BuildContext, sequence, out var errorMessage))
if (!IsSupportedSubquery(rootContext.BuildContext, sequence, out var errorMessage))
return new SqlErrorExpression(null, expression, errorMessage, expression.Type, true);
}

Expand Down
12 changes: 11 additions & 1 deletion Source/LinqToDB/Linq/Builder/ExpressionBuilder.EagerLoad.cs
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,11 @@ Expression ProcessEagerLoadingExpression(
_associations = saveAssociationsCache;
}

if (resultExpression is SqlErrorExpression errorExpression)
{
return errorExpression.WithType(eagerLoad.Type);
}

resultExpression = SqlAdjustTypeExpression.AdjustType(resultExpression, eagerLoad.Type, MappingSchema);

return resultExpression;
Expand Down Expand Up @@ -433,7 +438,8 @@ Expression BuildPreambleQueryAttached<TKey, T>(

query.Init(sequence, _parametersContext.CurrentSqlParameters);

BuildQuery(query, sequence, queryParameter, ref preambles!, previousKeys);
if (!BuildQuery(query, sequence, queryParameter, ref preambles!, previousKeys))
return query.ErrorExpression!;

var idx = preambles.Count;
var preamble = new Preamble<TKey, T>(query);
Expand Down Expand Up @@ -493,6 +499,10 @@ Expression CompleteEagerLoadingExpressions(
{
if (e.NodeType == ExpressionType.Extension && e is SqlEagerLoadExpression eagerLoad)
{
// Do not process eager loading fast mode
if (!_validateSubqueries)
return SqlErrorExpression.EnsureError(eagerLoad.SequenceExpression, e.Type);

eagerLoadingCache ??= new Dictionary<Expression, Expression>(ExpressionEqualityComparer.Instance);
if (!eagerLoadingCache.TryGetValue(eagerLoad.SequenceExpression, out var preambleExpression))
{
Expand Down
63 changes: 32 additions & 31 deletions Source/LinqToDB/Linq/Builder/ExpressionBuilder.Expressions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -860,51 +860,52 @@ public override Expression VisitPlaceholderExpression(PlaceholderExpression node

protected override Expression VisitConditional(ConditionalExpression node)
{
if (_flags.IsSql())
var saveFlags = _flags;

_flags |= ProjectFlags.ForceOuterAssociation;
try
{
var translated = TranslateExpression(node);
var translated = TranslateExpression(node, useSql:true);
sdanyliv marked this conversation as resolved.
Show resolved Hide resolved

if (translated is SqlPlaceholderExpression)
return translated;
}

if (IsForcedToConvert(node))
{
return TranslateExpression(node);
}
if (IsForcedToConvert(node))
{
return TranslateExpression(node);
}

var saveFlags = _flags;
var saveDescriptor = _columnDescriptor;
_columnDescriptor = null;
var test = Visit(node.Test);
_columnDescriptor = saveDescriptor;

_flags |= ProjectFlags.ForceOuterAssociation;
var ifTrue = Visit(node.IfTrue);
var ifFalse = Visit(node.IfFalse);

var saveDescriptor = _columnDescriptor;
_columnDescriptor = null;
var test = Visit(node.Test);
_columnDescriptor = saveDescriptor;
if (test is ConstantExpression { Value: bool boolValue })
{
return boolValue ? ifTrue : ifFalse;
}

var ifTrue = Visit(node.IfTrue);
var ifFalse = Visit(node.IfFalse);
if (ifTrue is SqlGenericConstructorExpression && ifFalse is SqlPlaceholderExpression)
ifFalse = node.IfFalse;
else if (ifFalse is SqlGenericConstructorExpression && ifTrue is SqlPlaceholderExpression)
ifTrue = node.IfTrue;

_flags = saveFlags;
if (test is SqlPlaceholderExpression &&
ifTrue is SqlPlaceholderExpression &&
ifFalse is SqlPlaceholderExpression)
{
return TranslateExpression(node, useSql: true);
}

if (test is ConstantExpression { Value: bool boolValue })
{
return boolValue ? ifTrue : ifFalse;
return node.Update(test, ifTrue, ifFalse);
}

if (ifTrue is SqlGenericConstructorExpression && ifFalse is SqlPlaceholderExpression)
ifFalse = node.IfFalse;
else if (ifFalse is SqlGenericConstructorExpression && ifTrue is SqlPlaceholderExpression)
ifTrue = node.IfTrue;

if (test is SqlPlaceholderExpression &&
ifTrue is SqlPlaceholderExpression &&
ifFalse is SqlPlaceholderExpression)
finally
{
return TranslateExpression(node, useSql : true);
_flags = saveFlags;
}

return node.Update(test, ifTrue, ifFalse);
}

public override Expression VisitSqlDefaultIfEmptyExpression(SqlDefaultIfEmptyExpression node)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,16 @@ Expression FinalizeProjection<T>(

// process eager loading queries
var correctedEager = CompleteEagerLoadingExpressions(postProcessed, context, queryParameter, ref preambles, previousKeys);

if (SequenceHelper.HasError(correctedEager))
return correctedEager;

if (!ExpressionEqualityComparer.Instance.Equals(correctedEager, postProcessed))
{
// convert all missed references
postProcessed = FinalizeConstructors(context, correctedEager, false);
}

SequenceHelper.EnsureNoErrors(postProcessed);

var withColumns = ToColumns(context, postProcessed);
return withColumns;
}
Expand Down
103 changes: 83 additions & 20 deletions Source/LinqToDB/Linq/Builder/ExpressionBuilder.SqlBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ namespace LinqToDB.Linq.Builder
using Common.Internal;
using Data;
using Extensions;
using Linq.Translation;
using Translation;
using LinqToDB.Expressions;
using Mapping;
using Reflection;
using SqlQuery;
using DataProvider;

partial class ExpressionBuilder
{
Expand Down Expand Up @@ -150,6 +151,74 @@ public void BuildSkip(IBuildContext sequence, ISqlExpression expr)

#region SubQueryToSql

/// <summary>
/// Checks that provider can handle limitation inside subquery. This function is tightly coupled with <see cref="SelectQueryOptimizerVisitor.OptimizeApply"/>
/// </summary>
/// <param name="context"></param>
/// <returns></returns>
public bool IsSupportedSubquery(IBuildContext parent, IBuildContext context, out string? errorMessage)
{
errorMessage = null;

if (!_validateSubqueries)
return true;

// No check during recursion. Cloning may fail
if (parent.Builder.IsRecursiveBuild)
return true;

if (!context.Builder.DataContext.SqlProviderFlags.IsApplyJoinSupported)
{
// We are trying to simulate what will be with query after optimizer's work
//
var cloningContext = new CloningContext();

var clonedParentContext = cloningContext.CloneContext(parent);
var clonedContext = cloningContext.CloneContext(context);

cloningContext.UpdateContextParents();

var expr = parent.Builder.MakeExpression(clonedContext, new ContextRefExpression(clonedContext.ElementType, clonedContext), ProjectFlags.SQL);

expr = parent.Builder.ToColumns(clonedParentContext, expr);

SqlJoinedTable? fakeJoin = null;

// add fake join there is no still reference
if (null == clonedParentContext.SelectQuery.Find(e => e is SelectQuery sc && sc == clonedContext.SelectQuery))
{
fakeJoin = clonedContext.SelectQuery.OuterApply().JoinedTable;

clonedParentContext.SelectQuery.From.Tables[0].Joins.Add(fakeJoin);
}

using var visitor = QueryHelper.SelectOptimizer.Allocate();

#if DEBUG

var sqlText = clonedParentContext.SelectQuery.ToDebugString();

#endif

var optimizedQuery = (SelectQuery)visitor.Value.Optimize(
root : clonedParentContext.SelectQuery,
rootElement : clonedParentContext.SelectQuery,
providerFlags : parent.Builder.DataContext.SqlProviderFlags,
removeWeakJoins : false,
dataOptions : parent.Builder.DataOptions,
mappingSchema: context.MappingSchema,
evaluationContext : new EvaluationContext()
sdanyliv marked this conversation as resolved.
Show resolved Hide resolved
);

if (!SqlProviderHelper.IsValidQuery(optimizedQuery, parentQuery: null, fakeJoin: fakeJoin, forColumn: false, parent.Builder.DataContext.SqlProviderFlags, out errorMessage))
sdanyliv marked this conversation as resolved.
Show resolved Hide resolved
{
return false;
}
}

return true;
}

int _gettingSubquery;

public IBuildContext? GetSubQuery(IBuildContext context, Expression expr, ProjectFlags flags, out bool isSequence, out string? errorMessage)
Expand All @@ -174,7 +243,10 @@ public void BuildSkip(IBuildContext sequence, ISqlExpression expr)
{
if (_gettingSubquery == 0)
{
if (!SequenceHelper.IsSupportedSubquery(context, buildResult.BuildContext, out errorMessage))
++_gettingSubquery;
var isSupported = IsSupportedSubquery(context, buildResult.BuildContext, out errorMessage);
--_gettingSubquery;
if (!isSupported)
return null;
}
}
Expand Down Expand Up @@ -856,6 +928,9 @@ Expression ConvertToSqlInternal(IBuildContext? context, Expression expression, P
if (e.Method == null && (e.IsLifted || e.Type == typeof(object)))
return placeholder;

if (e.Method == null && operandExpr is not SqlPlaceholderExpression)
return e;
Comment on lines 933 to +937
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would prefer to group these two checks as sub-checks of a single if (e.Method is null)


if (e.Type == typeof(bool) && e.Operand.Type == typeof(SqlBoolean))
return placeholder;

Expand Down Expand Up @@ -890,7 +965,7 @@ Expression ConvertToSqlInternal(IBuildContext? context, Expression expression, P
{
var e = (ConditionalExpression)expression;

var testExpr = ConvertToSqlExpr(context, e.Test, flags.TestFlag(), columnDescriptor : columnDescriptor, isPureExpression : isPureExpression);
var testExpr = ConvertToSqlExpr(context, e.Test, flags.TestFlag(), columnDescriptor : null, isPureExpression : isPureExpression);
var trueExpr = ConvertToSqlExpr(context, e.IfTrue, flags.TestFlag(), columnDescriptor : columnDescriptor, isPureExpression : isPureExpression);
var falseExpr = ConvertToSqlExpr(context, e.IfFalse, flags.TestFlag(), columnDescriptor : columnDescriptor, isPureExpression : isPureExpression);
sdanyliv marked this conversation as resolved.
Show resolved Hide resolved

Expand Down Expand Up @@ -2090,6 +2165,8 @@ Expression GenerateConstructorComparison(SqlGenericConstructorExpression leftCon

if (l is SqlValue lv && lv.Value == null || left.IsNullValue())
{
rightExpr = BuildSqlExpression(context, rightExpr, flags);

if (rightExpr is ConditionalExpression { Test: SqlPlaceholderExpression { Sql: SqlSearchCondition rightSearchCond } } && rightSearchCond.Predicates.Count == 1)
{
var rightPredicate = rightSearchCond.Predicates[0];
Expand All @@ -2104,13 +2181,13 @@ Expression GenerateConstructorComparison(SqlGenericConstructorExpression leftCon
}
}

rightExpr = BuildSqlExpression(context, rightExpr, flags);

return GenerateNullComparison(rightExpr, isNot);
}

if (r is SqlValue rv && rv.Value == null || right.IsNullValue())
{
leftExpr = BuildSqlExpression(context, leftExpr, flags);

if (leftExpr is ConditionalExpression { Test: SqlPlaceholderExpression { Sql: SqlSearchCondition leftSearchCond } } && leftSearchCond.Predicates.Count == 1)
{
var leftPredicate = leftSearchCond.Predicates[0];
Expand All @@ -2125,8 +2202,6 @@ Expression GenerateConstructorComparison(SqlGenericConstructorExpression leftCon
}
}

leftExpr = BuildSqlExpression(context, leftExpr, flags);

return GenerateNullComparison(leftExpr, isNot);
}

Expand Down Expand Up @@ -2380,18 +2455,6 @@ public bool CollectNullCompareExpressions(IBuildContext context, Expression expr
{
switch (expression.NodeType)
{
case ExpressionType.Conditional:
{
var cond = (ConditionalExpression)expression;

if (!CollectNullCompareExpressions(context, cond.IfTrue, result))
return false;
if (!CollectNullCompareExpressions(context, cond.IfFalse, result))
return false;

return true;
}

case ExpressionType.Constant:
case ExpressionType.Default:
{
Expand Down Expand Up @@ -4266,7 +4329,7 @@ public Expression Project(IBuildContext context, Expression? path, List<Expressi
return new DefaultValueExpression(MappingSchema, truePath.Type);
}

var falsePath = new DefaultValueExpression(MappingSchema, truePath.Type);
var falsePath = Expression.Constant(null, truePath.Type);

var conditional = Expression.Condition(isPredicate, truePath, falsePath);

Expand Down
Loading
Loading