Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,12 @@ public static void RegisterFunctions(IFunctionsRegister functionsRegister)
string methodName = nameof(Substring);
if (scalarFunction.Options != null && scalarFunction.Options.TryGetValue(NegativeStart, out var negativeStart))
{
if (negativeStart == WrapFromEnd)
var negativeStartValue = negativeStart.First();
if (negativeStartValue == WrapFromEnd)
{
methodName = nameof(SubstringWrapFromEnd);
}
else if (negativeStart == LeftOfBeginning)
else if (negativeStartValue == LeftOfBeginning)
{
methodName = nameof(SubstringLeftOfBeginning);
}
Expand All @@ -97,7 +98,7 @@ public static void RegisterFunctions(IFunctionsRegister functionsRegister)
{
bool ignoreNulls = false;

if(func.Options != null && func.Options.TryGetValue(NullHandling, out var nullHandling) && nullHandling == IgnoreNulls)
if(func.Options != null && func.Options.TryGetValue(NullHandling, out var nullHandling) && nullHandling.First() == IgnoreNulls)
{
ignoreNulls = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,11 @@ private static void RegisterMethod(string extensionUri, string extensionName, IF
throw new NotSupportedException("Only one option is supported at this time.");
}

methodInformation = methodsInformation.FirstOrDefault(x => x.option == func.Options.Keys[0] && x.optionValue == func.Options.Values[0]);
methodInformation = methodsInformation.FirstOrDefault(x =>
{
var firstOption = func.Options.First();
return x.option == firstOption.Key && x.optionValue == firstOption.Value.First();
});
}
if (methodInformation == null)
{
Expand Down
2 changes: 1 addition & 1 deletion src/FlowtideDotNet.Substrait/Expressions/ScalarFunction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public sealed class ScalarFunction : Expression, IEquatable<ScalarFunction>

public required List<Expression> Arguments { get; set; }

public SortedList<string, string>? Options { get; set; }
public IReadOnlyDictionary<string, IReadOnlyList<string>>? Options { get; set; }

public override TOutput Accept<TOutput, TState>(ExpressionVisitor<TOutput, TState> visitor, TState state)
{
Expand Down
2 changes: 1 addition & 1 deletion src/FlowtideDotNet.Substrait/Sql/ISqlFunctionRegister.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public record AggregateResponse(AggregateFunction AggregateFunction, SubstraitBa

public interface ISqlFunctionRegister
{
void RegisterScalarFunction(string name, Func<SqlParser.Ast.Expression.Function, SqlExpressionVisitor, EmitData, ScalarResponse> mapFunc);
void RegisterScalarFunction(string name, Func<SqlParser.Ast.Expression.Function, IReadOnlyDictionary<string, string>, SqlExpressionVisitor, EmitData, ScalarResponse> mapFunc);

void RegisterAggregateFunction(string name, Func<SqlParser.Ast.Expression.Function, SqlExpressionVisitor, EmitData, AggregateResponse> mapFunc);

Expand Down
113 changes: 69 additions & 44 deletions src/FlowtideDotNet.Substrait/Sql/Internal/BuiltInSqlFunctions.cs

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ protected override EmitData VisitSelect(Select select, object? state)
int outputCounter = 0;
foreach (var s in selects)
{
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister);
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister, new Dictionary<string, SortedDictionary<string, string>>());
if (s is SelectItem.ExpressionWithAlias exprAlias)
{
SubstraitBaseType returnType = new AnyType();
Expand Down
9 changes: 9 additions & 0 deletions src/FlowtideDotNet.Substrait/Sql/Internal/SqlBaseVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ public virtual TReturn Visit(Statement statement, TState state)
{
return VisitBeginSubStream(beginSubStream);
}
if (statement is Statement.SetVariable setVariable)
{
return VisitSetVariable(setVariable);
}
throw new NotImplementedException();
}

protected virtual TReturn VisitSetVariable(Statement.SetVariable setVariable)
{
throw new NotImplementedException();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,23 @@ internal enum FunctionType

internal class SqlFunctionRegister : ISqlFunctionRegister
{
private readonly Dictionary<string, Func<SqlParser.Ast.Expression.Function, SqlExpressionVisitor, EmitData, ScalarResponse>> _scalarFunctions;
private readonly Dictionary<string, Func<SqlParser.Ast.Expression.Function, IReadOnlyDictionary<string, string>, SqlExpressionVisitor, EmitData, ScalarResponse>> _scalarFunctions;
private readonly Dictionary<string, Func<SqlParser.Ast.Expression.Function, SqlExpressionVisitor, EmitData, AggregateResponse>> _aggregateFunctions;
private readonly Dictionary<string, Func<SqlTableFunctionArgument, TableFunction>> _tableFunctions;

public SqlFunctionRegister()
{
_scalarFunctions = new Dictionary<string, Func<SqlParser.Ast.Expression.Function, SqlExpressionVisitor, EmitData, ScalarResponse>>(StringComparer.OrdinalIgnoreCase);
_scalarFunctions = new Dictionary<string, Func<SqlParser.Ast.Expression.Function, IReadOnlyDictionary<string, string>, SqlExpressionVisitor, EmitData, ScalarResponse>>(StringComparer.OrdinalIgnoreCase);
_aggregateFunctions = new Dictionary<string, Func<SqlParser.Ast.Expression.Function, SqlExpressionVisitor, EmitData, AggregateResponse>>(StringComparer.OrdinalIgnoreCase);
_tableFunctions = new Dictionary<string, Func<SqlTableFunctionArgument, TableFunction>>(StringComparer.OrdinalIgnoreCase);
}

public void RegisterScalarFunction(string name, Func<SqlParser.Ast.Expression.Function, SqlExpressionVisitor, EmitData, ScalarResponse> mapFunc)
public void RegisterScalarFunction(string name, Func<SqlParser.Ast.Expression.Function, IReadOnlyDictionary<string, string>, SqlExpressionVisitor, EmitData, ScalarResponse> mapFunc)
{
_scalarFunctions.Add(name, mapFunc);
}

public Func<SqlParser.Ast.Expression.Function, SqlExpressionVisitor, EmitData, ScalarResponse> GetScalarMapper(string name)
public Func<SqlParser.Ast.Expression.Function, IReadOnlyDictionary<string, string>, SqlExpressionVisitor, EmitData, ScalarResponse> GetScalarMapper(string name)
{
return _scalarFunctions[name];
}
Expand Down
68 changes: 53 additions & 15 deletions src/FlowtideDotNet.Substrait/Sql/Internal/SqlSubstraitVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ internal class SqlSubstraitVisitor : SqlBaseVisitor<RelationData?, object?>
private string? subStreamName;
private int exchangeTargetIdCounter;
private readonly List<Relation> subRelations;
private Dictionary<string, SortedDictionary<string, string>> globalFunctionOptions;

public SqlSubstraitVisitor(SqlPlanBuilder sqlPlanBuilder, SqlFunctionRegister sqlFunctionRegister)
{
Expand All @@ -44,6 +45,7 @@ public SqlSubstraitVisitor(SqlPlanBuilder sqlPlanBuilder, SqlFunctionRegister sq
exchangeRelations = new Dictionary<string, ExchangeContainer>(StringComparer.OrdinalIgnoreCase);
viewRelations = new Dictionary<string, ViewContainer>(StringComparer.OrdinalIgnoreCase);
subRelations = new List<Relation>();
globalFunctionOptions = new Dictionary<string, SortedDictionary<string, string>>(StringComparer.OrdinalIgnoreCase);
}

public List<Relation> GetRelations(Sequence<Statement> statements)
Expand Down Expand Up @@ -177,7 +179,7 @@ public List<Relation> GetRelations(Sequence<Statement> statements)
{
if (kvOption.Value is Expression.Identifier scatterIdentifier)
{
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister);
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister, globalFunctionOptions);
// Do a lookup on the partition by field
var exprData = exprVisitor.Visit(scatterIdentifier, relationData.EmitData);
if (exprData.Expr is Expressions.FieldReference fieldReference)
Expand Down Expand Up @@ -353,7 +355,7 @@ public List<Relation> GetRelations(Sequence<Statement> statements)
}
if (query.OrderBy != null && query.OrderBy.Expressions != null)
{
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister);
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister, globalFunctionOptions);
List<Expressions.SortField> sortFields = new List<Expressions.SortField>();
foreach (var o in query.OrderBy.Expressions)
{
Expand Down Expand Up @@ -476,7 +478,7 @@ private static Expressions.SortDirection GetSortDirection(OrderByExpression o)

if (select.Selection != null)
{
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister);
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister, globalFunctionOptions);
var expr = exprVisitor.Visit(select.Selection, outNode.EmitData);
outNode = new RelationData(new FilterRelation()
{
Expand Down Expand Up @@ -592,7 +594,7 @@ private RelationData VisitSelectAggregate(Select select, ContainsAggregateVisito
{
foreach (var group in groupByExpressions.ColumnNames)
{
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister);
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister, globalFunctionOptions);
var result = exprVisitor.Visit(group, parent.EmitData);
grouping.GroupingExpressions.Add(result.Expr);
aggEmitData.Add(group, emitcount, result.Name, result.Type);
Expand All @@ -608,7 +610,7 @@ private RelationData VisitSelectAggregate(Select select, ContainsAggregateVisito
foreach (var foundMeasure in containsAggregateVisitor.AggregateFunctions)
{
var mapper = sqlFunctionRegister.GetAggregateMapper(foundMeasure.Name);
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister);
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister, globalFunctionOptions);

var aggregateResponse = mapper(foundMeasure, exprVisitor, parent.EmitData);
aggRel.Measures.Add(new AggregateMeasure()
Expand All @@ -621,7 +623,7 @@ private RelationData VisitSelectAggregate(Select select, ContainsAggregateVisito

if (select.Having != null)
{
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister);
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister, globalFunctionOptions);
outputRelation = new FilterRelation()
{
Condition = exprVisitor.Visit(select.Having, aggEmitData).Expr,
Expand All @@ -640,7 +642,7 @@ private RelationData VisitSelectAggregate(Select select, ContainsAggregateVisito
int outputCounter = 0;
foreach (var s in selects)
{
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister);
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister, globalFunctionOptions);
if (s is SelectItem.ExpressionWithAlias exprAlias)
{
var condition = exprVisitor.Visit(exprAlias.Expression, emitData);
Expand Down Expand Up @@ -700,7 +702,7 @@ private RelationData VisitSelectAggregate(Select select, ContainsAggregateVisito
int outputCounter = 0;
foreach (var s in selects)
{
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister);
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister, globalFunctionOptions);
if (s is SelectItem.ExpressionWithAlias exprAlias)
{
var condition = exprVisitor.Visit(exprAlias.Expression, parent.EmitData);
Expand Down Expand Up @@ -817,7 +819,7 @@ private RelationData VisitTableFunctionRoot(TableFactor tableFactor)
{
GetTableFunctionNameAndArgs(tableFactor, out var name, out var args);
var tableFunctionMapper = sqlFunctionRegister.GetTableMapper(name);
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister);
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister, globalFunctionOptions);

var tableFunction = tableFunctionMapper(
new SqlTableFunctionArgument(args, tableFactor.Alias?.Name.Value, exprVisitor, new EmitData())
Expand Down Expand Up @@ -848,7 +850,7 @@ private RelationData VisitTableFunctionJoin(Join join, RelationData parent)
GetTableFunctionNameAndArgs(join.Relation, out var name, out var args);

var tableFunctionMapper = sqlFunctionRegister.GetTableMapper(name);
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister);
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister, globalFunctionOptions);

var tableFunction = tableFunctionMapper(
new SqlTableFunctionArgument(args, join.Relation?.Alias?.Name.Value, exprVisitor, parent.EmitData)
Expand Down Expand Up @@ -1099,7 +1101,7 @@ literalValue.Value is Value.Number number &&

if (leftOuter.JoinConstraint is JoinConstraint.On on)
{
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister);
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister, globalFunctionOptions);
var condition = exprVisitor.Visit(on.Expression, joinEmitData);
joinRelation.Expression = condition.Expr;
}
Expand All @@ -1113,7 +1115,7 @@ literalValue.Value is Value.Number number &&
joinRelation.Type = JoinType.Inner;
if (inner.JoinConstraint is JoinConstraint.On on)
{
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister);
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister, globalFunctionOptions);
var condition = exprVisitor.Visit(on.Expression, joinEmitData);
joinRelation.Expression = condition.Expr;
}
Expand All @@ -1127,7 +1129,7 @@ literalValue.Value is Value.Number number &&
joinRelation.Type = JoinType.Right;
if (rightJoin.JoinConstraint is JoinConstraint.On on)
{
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister);
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister, globalFunctionOptions);
var condition = exprVisitor.Visit(on.Expression, joinEmitData);
joinRelation.Expression = condition.Expr;
}
Expand All @@ -1141,7 +1143,7 @@ literalValue.Value is Value.Number number &&
joinRelation.Type = JoinType.Outer;
if (fullOuterJoin.JoinConstraint is JoinConstraint.On on)
{
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister);
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister, globalFunctionOptions);
var condition = exprVisitor.Visit(on.Expression, joinEmitData);
joinRelation.Expression = condition.Expr;
}
Expand Down Expand Up @@ -1286,7 +1288,7 @@ protected override RelationData VisitValuesExpression(SetExpression.ValuesExpres
foreach (var row in valuesExpression.Values.Rows)
{
List<Expressions.Expression> expressions = new List<Expressions.Expression>();
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister);
var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister, globalFunctionOptions);
foreach(var expr in row)
{
var condition = exprVisitor.Visit(expr, emitData);
Expand Down Expand Up @@ -1320,5 +1322,41 @@ protected override RelationData VisitValuesExpression(SetExpression.ValuesExpres
};
return new RelationData(relation, projectEmitData);
}

protected override RelationData VisitSetVariable(Statement.SetVariable setVariable)
{
var variableString = setVariable.Variables.ToString();

if (variableString.StartsWith("function."))
{

var dotIndex = variableString.IndexOf('.', 9);

if (dotIndex > 0)
{
var funcName = variableString.Substring(9, dotIndex - 9);
var optionName = variableString.Substring(dotIndex + 1);

var value = string.Join(".", setVariable.Value!.Select(x => x.ToSql()));

if (!globalFunctionOptions.TryGetValue(funcName, out var functionOptions))
{
functionOptions = new SortedDictionary<string, string>(StringComparer.OrdinalIgnoreCase);
globalFunctionOptions.Add(funcName, functionOptions);
}
functionOptions[optionName] = value;
return default!;
}
else
{
throw new NotSupportedException("Invalid function option format.");
}

}
else
{
throw new NotSupportedException($"{variableString} is not supported");
}
}
}
}
Loading
Loading