Skip to content
Closed
Prev Previous commit
Next Next commit
small refactor
  • Loading branch information
stefankandic committed Jun 27, 2024
commit e4a29b90ca9a78b86ac3b1ec9f4342d92cd2b45f
Original file line number Diff line number Diff line change
Expand Up @@ -503,94 +503,94 @@ object DataSourceStrategy
/**
* Wraps a source filter in its collation-aware counterpart when the given data type has a
* non-UTF8-binary collation.
*
* When `SchemaUtils.hasNonUTF8BinaryCollation(dataType)` is false, the filter is passed
* through untouched. Otherwise the equality, ordering, `In`, and string-matching filters are
* rewritten to the matching `Collated*` filter, which carries `dataType` alongside the
* attribute and value(s). Any other filter is passed through as-is.
*
* NOTE(review): this span appears to come from a diff view, so each changed statement is
* listed twice: first the pre-change form (result wrapped in `Some`, return type
* `Option[Filter]`), then the post-change form (bare `Filter`). Confirm which revision is
* intended before treating this text as compilable code.
*
* @param filter the source filter to adapt
* @param dataType data type checked for a non-UTF8-binary collation; the visible callers pass
*                 the matched column expression's `dataType`
*/
private def collationAwareFilter(filter: sources.Filter, dataType: DataType): Option[Filter] = {
private def collationAwareFilter(filter: sources.Filter, dataType: DataType): Filter = {
// Nothing to rewrite when the data type has no non-UTF8-binary collation.
if (!SchemaUtils.hasNonUTF8BinaryCollation(dataType)) {
return Some(filter)
return filter
}

filter match {
case sources.EqualTo(attribute, value) =>
Some(CollatedEqualTo(attribute, value, dataType))
CollatedEqualTo(attribute, value, dataType)
case sources.EqualNullSafe(attribute, value) =>
Some(CollatedEqualNullSafe(attribute, value, dataType))
CollatedEqualNullSafe(attribute, value, dataType)
case sources.GreaterThan(attribute, value) =>
Some(CollatedGreaterThan(attribute, value, dataType))
CollatedGreaterThan(attribute, value, dataType)
case sources.GreaterThanOrEqual(attribute, value) =>
Some(CollatedGreaterThanOrEqual(attribute, value, dataType))
CollatedGreaterThanOrEqual(attribute, value, dataType)
case sources.LessThan(attribute, value) =>
Some(CollatedLessThan(attribute, value, dataType))
CollatedLessThan(attribute, value, dataType)
case sources.LessThanOrEqual(attribute, value) =>
Some(CollatedLessThanOrEqual(attribute, value, dataType))
CollatedLessThanOrEqual(attribute, value, dataType)
case sources.In(attribute, values) =>
Some(CollatedIn(attribute, values, dataType))
CollatedIn(attribute, values, dataType)
case sources.StringStartsWith(attribute, value) =>
Some(CollatedStringStartsWith(attribute, value, dataType))
CollatedStringStartsWith(attribute, value, dataType)
case sources.StringEndsWith(attribute, value) =>
Some(CollatedStringEndsWith(attribute, value, dataType))
CollatedStringEndsWith(attribute, value, dataType)
case sources.StringContains(attribute, value) =>
Some(CollatedStringContains(attribute, value, dataType))
CollatedStringContains(attribute, value, dataType)
// Filters with no collated counterpart listed above are passed through unchanged.
case other =>
Some(other)
other
}
}

private def translateLeafNodeFilter(
predicate: Expression,
pushableColumn: PushableColumnBase): Option[Filter] = predicate match {
case expressions.EqualTo(e @ pushableColumn(name), Literal(v, t)) =>
collationAwareFilter(sources.EqualTo(name, convertToScala(v, t)), e.dataType)
Some(collationAwareFilter(sources.EqualTo(name, convertToScala(v, t)), e.dataType))
case expressions.EqualTo(Literal(v, t), e @ pushableColumn(name)) =>
collationAwareFilter(sources.EqualTo(name, convertToScala(v, t)), e.dataType)
Some(collationAwareFilter(sources.EqualTo(name, convertToScala(v, t)), e.dataType))

case expressions.EqualNullSafe(e @ pushableColumn(name), Literal(v, t)) =>
collationAwareFilter(sources.EqualNullSafe(name, convertToScala(v, t)), e.dataType)
Some(collationAwareFilter(sources.EqualNullSafe(name, convertToScala(v, t)), e.dataType))
case expressions.EqualNullSafe(Literal(v, t), e @ pushableColumn(name)) =>
collationAwareFilter(sources.EqualNullSafe(name, convertToScala(v, t)), e.dataType)
Some(collationAwareFilter(sources.EqualNullSafe(name, convertToScala(v, t)), e.dataType))

case expressions.GreaterThan(e @ pushableColumn(name), Literal(v, t)) =>
collationAwareFilter(sources.GreaterThan(name, convertToScala(v, t)), e.dataType)
Some(collationAwareFilter(sources.GreaterThan(name, convertToScala(v, t)), e.dataType))
case expressions.GreaterThan(Literal(v, t), e @ pushableColumn(name)) =>
collationAwareFilter(sources.LessThan(name, convertToScala(v, t)), e.dataType)
Some(collationAwareFilter(sources.LessThan(name, convertToScala(v, t)), e.dataType))

case expressions.LessThan(e @ pushableColumn(name), Literal(v, t)) =>
collationAwareFilter(sources.LessThan(name, convertToScala(v, t)), e.dataType)
Some(collationAwareFilter(sources.LessThan(name, convertToScala(v, t)), e.dataType))
case expressions.LessThan(Literal(v, t), e @ pushableColumn(name)) =>
collationAwareFilter(sources.GreaterThan(name, convertToScala(v, t)), e.dataType)
Some(collationAwareFilter(sources.GreaterThan(name, convertToScala(v, t)), e.dataType))

case expressions.GreaterThanOrEqual(e @ pushableColumn(name), Literal(v, t)) =>
collationAwareFilter(sources.GreaterThanOrEqual(name, convertToScala(v, t)), e.dataType)
Some(collationAwareFilter(sources.GreaterThanOrEqual(name, convertToScala(v, t)), e.dataType))
case expressions.GreaterThanOrEqual(Literal(v, t), e @ pushableColumn(name)) =>
collationAwareFilter(sources.LessThanOrEqual(name, convertToScala(v, t)), e.dataType)
Some(collationAwareFilter(sources.LessThanOrEqual(name, convertToScala(v, t)), e.dataType))

case expressions.LessThanOrEqual(e @ pushableColumn(name), Literal(v, t)) =>
collationAwareFilter(sources.LessThanOrEqual(name, convertToScala(v, t)), e.dataType)
Some(collationAwareFilter(sources.LessThanOrEqual(name, convertToScala(v, t)), e.dataType))
case expressions.LessThanOrEqual(Literal(v, t), e @ pushableColumn(name)) =>
collationAwareFilter(sources.GreaterThanOrEqual(name, convertToScala(v, t)), e.dataType)
Some(collationAwareFilter(sources.GreaterThanOrEqual(name, convertToScala(v, t)), e.dataType))

case expressions.InSet(e @ pushableColumn(name), set) =>
val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType)
collationAwareFilter(sources.In(name, set.toArray.map(toScala)), e.dataType)
Some(collationAwareFilter(sources.In(name, set.toArray.map(toScala)), e.dataType))

// The Optimizer only converts In to InSet when the list has more than a certain number
// of items, so it is still possible to get an In expression here that needs to be pushed
// down.
case expressions.In(e @ pushableColumn(name), list) if list.forall(_.isInstanceOf[Literal]) =>
val hSet = list.map(_.eval(EmptyRow))
val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType)
collationAwareFilter(sources.In(name, hSet.toArray.map(toScala)), e.dataType)
Some(collationAwareFilter(sources.In(name, hSet.toArray.map(toScala)), e.dataType))

case expressions.IsNull(pushableColumn(name)) =>
Some(sources.IsNull(name))
case expressions.IsNotNull(pushableColumn(name)) =>
Some(sources.IsNotNull(name))
case expressions.StartsWith(e @ pushableColumn(name), Literal(v: UTF8String, StringType)) =>
collationAwareFilter(sources.StringStartsWith(name, v.toString), e.dataType)
Some(collationAwareFilter(sources.StringStartsWith(name, v.toString), e.dataType))

case expressions.EndsWith(e @ pushableColumn(name), Literal(v: UTF8String, StringType)) =>
collationAwareFilter(sources.StringEndsWith(name, v.toString), e.dataType)
Some(collationAwareFilter(sources.StringEndsWith(name, v.toString), e.dataType))

case expressions.Contains(e @ pushableColumn(name), Literal(v: UTF8String, StringType)) =>
collationAwareFilter(sources.StringContains(name, v.toString), e.dataType)
Some(collationAwareFilter(sources.StringContains(name, v.toString), e.dataType))

case expressions.Literal(true, BooleanType) =>
Some(sources.AlwaysTrue)
Expand Down