Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
add support for udf_format_number and length for binary
  • Loading branch information
chenghao-intel committed Jul 16, 2015
commit 52274f73a69d37151876d49e0f709ad3b45ba8c3
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,12 @@ object FunctionRegistry {
expression[Base64]("base64"),
expression[Encode]("encode"),
expression[Decode]("decode"),
expression[StringInstr]("instr"),
expression[FormatNumber]("format_number"),
expression[Lower]("lcase"),
expression[Lower]("lower"),
expression[StringLength]("length"),
expression[Length]("length"),
expression[Levenshtein]("levenshtein"),
expression[StringInstr]("instr"),
expression[StringLocate]("locate"),
expression[StringLPad]("lpad"),
expression[StringTrimLeft]("ltrim"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@

package org.apache.spark.sql.catalyst.expressions

import java.text.DecimalFormat
import java.util.Locale
import java.util.regex.Pattern

import org.apache.commons.lang3.StringUtils

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.expressions.codegen._
Expand Down Expand Up @@ -553,17 +552,23 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
}

/**
* A function that return the length of the given string expression.
* A function that returns the length of the given string or binary expression.
*/
case class StringLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
case class Length(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def dataType: DataType = IntegerType
override def inputTypes: Seq[DataType] = Seq(StringType)
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType))

protected override def nullSafeEval(string: Any): Any =
string.asInstanceOf[UTF8String].numChars
protected override def nullSafeEval(value: Any): Any = child.dataType match {
case StringType => value.asInstanceOf[UTF8String].numChars
case BinaryType => value.asInstanceOf[Array[Byte]].length
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, c => s"($c).numChars()")
child.dataType match {
case StringType => defineCodeGen(ctx, ev, c => s"($c).numChars()")
case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length")
case NullType => defineCodeGen(ctx, ev, c => s"-1")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't need to support NullType here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will cause an exception in StringFunctionsSuite, as we do not run the Analyzer at all there.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can just remove that test case, can't you?

checkEvaluation(Length(Literal.create(null, NullType)), null, create_row(null))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, yes, we can do that now since you've handled the NullType in a single place.

}
}

override def prettyName: String = "length"
Expand Down Expand Up @@ -668,3 +673,74 @@ case class Encode(value: Expression, charset: Expression)
}
}

/**
* Formats the number X to a format like '#,###,###.##', rounded to D decimal places,
* and returns the result as a string. If D is 0, the result has no decimal point or
* fractional part.
*/
case class FormatNumber(x: Expression, d: Expression)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

override prettyName

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: this is done

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, yes, it's done, but in the end of this class code.

extends BinaryExpression with ExpectsInputTypes {

override def left: Expression = x
override def right: Expression = d
override def dataType: DataType = StringType
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)
override def foldable: Boolean = x.foldable && d.foldable
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don't think u need to define foldable and nullable

override def nullable: Boolean = x.nullable || d.nullable

// Number of decimal places that `numberFormat` was last configured for. eval() compares the
// incoming value against this and only rebuilds the pattern when they differ. The initial
// -100 can never equal an accepted value, because eval() rejects negative values.
// NOTE(review): as a @transient field this is restored as 0 (not -100) on a deserialized
// copy, so a first value of 0 there would skip the pattern rebuild -- confirm and handle
// in eval().
@transient
private var lastDValue: Int = -100

// Scratch buffer in which eval() assembles the DecimalFormat pattern text.
// Declared lazy on purpose: a plain @transient val is skipped by Java serialization and its
// initializer is not re-run on deserialization, which would leave this field null (and make
// eval() throw a NullPointerException) on a deserialized copy of the expression. A
// @transient lazy val is re-created on first access instead.
@transient
private lazy val pattern: StringBuffer = new StringBuffer()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be a lazy transient val so scala will initialize pattern on the executors after serialization?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, @rxin, I am not sure what you mean. Do you mean that pattern would probably be null on the executors after serialization?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pattern depends on the input rows, so we don't need to serialize it from the driver side; it is built at runtime instead.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what i'm worried about is whether we would get a null value on the executors, because "new StringBuffer()" is not called.


// Cached formatter; eval() re-applies a pattern to it whenever the requested number of
// decimal places changes.
// Lazy because a plain @transient val is not re-initialized after Java deserialization and
// would be null on a deserialized copy of the expression.
// The symbols are pinned to Locale.US so the output always uses ',' for grouping and '.' as
// the decimal separator, independent of the JVM's default locale.
@transient
private lazy val numberFormat: DecimalFormat =
  new DecimalFormat("", new java.text.DecimalFormatSymbols(java.util.Locale.US))

override def eval(input: InternalRow): Any = {
val xObject = x.eval(input)
if (xObject == null) {
return null
}

val dObject = d.eval(input)

if (dObject == null || dObject.asInstanceOf[Int] < 0) {
throw new IllegalArgumentException(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does hive also throw an exception? we could also just return null ...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic is copied from Hive. Let's keep the same behavior?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just find it really weird to throw a runtime illegal argument exception... imagine you have some large dataset, and the very last record has d < 0...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, will update that.

s"Argument 2 of function FORMAT_NUMBER must be >= 0, but $dObject was found")
}
val dValue = dObject.asInstanceOf[Int]

if (dValue != lastDValue) {
// construct a new DecimalFormat only if a new dValue
pattern.delete(0, pattern.length())
pattern.append("#,###,###,###,###,###,##0")

// decimal place
if (dValue > 0) {
pattern.append(".")

var i = 0
while (i < dValue) {
i += 1
pattern.append("0")
}
}
val dFormat = new DecimalFormat(pattern.toString())
lastDValue = dValue;
numberFormat.applyPattern(dFormat.toPattern())
}

x.dataType match {
case ByteType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Byte]))
case ShortType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Short]))
case FloatType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Float]))
case IntegerType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Int]))
case LongType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Long]))
case DoubleType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Double]))
case _: DecimalType =>
UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Decimal].toJavaBigDecimal))
}
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType}
import org.apache.spark.sql.types._


class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
Expand Down Expand Up @@ -216,15 +216,6 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}

test("length for string") {
  // Column 0 of the input row, read as a string.
  val strCol = 'a.string.at(0)
  // A fresh row per check, exactly as if it were written inline.
  def sampleRow = create_row("abdef")

  // Literal arguments: the result does not depend on the row contents.
  checkEvaluation(StringLength(Literal("abc")), 3, sampleRow)
  checkEvaluation(StringLength(Literal.create(null, StringType)), null, sampleRow)

  // Bound column: the result follows whatever the row holds, including empty and null.
  checkEvaluation(StringLength(strCol), 5, sampleRow)
  checkEvaluation(StringLength(strCol), 0, create_row(""))
  checkEvaluation(StringLength(strCol), null, create_row(null))
}

test("ascii for string") {
val a = 'a.string.at(0)
checkEvaluation(Ascii(Literal("efg")), 101, create_row("abdef"))
Expand Down Expand Up @@ -426,4 +417,47 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(
StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1)
}

test("length for string / binary") {
  // Column 0 of the input row, read either as a string or as a binary value.
  val strCol = 'a.string.at(0)
  val binCol = 'b.binary.at(0)
  val sampleString = "abdef"
  val sampleBytes = Array[Byte](1, 2, 3, 1, 2)

  // Literal arguments.
  // scalastyle:off
  // non ascii characters are not allowed in the source code, so we disable the scalastyle.
  checkEvaluation(Length(Literal("a花花c")), 4, create_row(sampleString))
  // scalastyle:on
  checkEvaluation(Length(Literal(sampleBytes)), 5, create_row(Array[Byte]()))

  // String column: populated, empty and null input.
  checkEvaluation(Length(strCol), 5, create_row(sampleString))
  checkEvaluation(Length(strCol), 0, create_row(""))
  checkEvaluation(Length(strCol), null, create_row(null))

  // Binary column: populated, empty and null input.
  checkEvaluation(Length(binCol), 5, create_row(sampleBytes))
  checkEvaluation(Length(binCol), 0, create_row(Array[Byte]()))
  checkEvaluation(Length(binCol), null, create_row(null))

  // Typed null literals evaluate to null for every input type.
  checkEvaluation(Length(Literal.create(null, StringType)), null, create_row(sampleString))
  checkEvaluation(Length(Literal.create(null, BinaryType)), null, create_row(sampleBytes))
  checkEvaluation(Length(Literal.create(null, NullType)), null, create_row(null))
}

test("number format") {
  val bigDecimalInput = Decimal(123123324123L) * Decimal(123123.21234d)

  // (number to format, number of decimal places, expected text), one entry per numeric type.
  val cases = Seq(
    (Literal(4.toByte), 3, "4.000"),
    (Literal(4.toShort), 3, "4.000"),
    (Literal(4.0f), 3, "4.000"),
    (Literal(4), 3, "4.000"),
    (Literal(12831273.23481d), 3, "12,831,273.235"),
    (Literal(12831273.83421d), 0, "12,831,274"),
    (Literal(123123324123L), 3, "123,123,324,123.000"),
    (Literal(bigDecimalInput), 4, "15,159,339,180,002,773.2778"))

  cases.foreach { case (number, places, expected) =>
    checkEvaluation(FormatNumber(number, Literal(places)), expected)
  }

  // A null number yields null, whatever its declared type.
  Seq(IntegerType, NullType).foreach { nullType =>
    checkEvaluation(FormatNumber(Literal.create(null, nullType), Literal(3)), null)
  }
}
}
30 changes: 26 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1685,20 +1685,42 @@ object functions {
//////////////////////////////////////////////////////////////////////////////////////////////

/**
* Computes the length of a given string value.
* Computes the length of a given string / binary value
*
* @group string_funcs
* @since 1.5.0
*/
def strlen(e: Column): Column = StringLength(e.expr)
def length(e: Column): Column = Length(e.expr)

/**
* Computes the length of a given string column.
* Computes the length of a given string / binary column
*
* @group string_funcs
* @since 1.5.0
*/
def strlen(columnName: String): Column = strlen(Column(columnName))
def length(columnName: String): Column = length(Column(columnName))

/**
* Formats the number X to a format like '#,###,###.##', rounded to D decimal places,
* and returns the result as a string. If D is 0, the result has no decimal point or
* fractional part.
*
* @group string_funcs
* @since 1.5.0
*/
def formatNumber(x: Column, d: Column): Column = FormatNumber(x.expr, d.expr)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's use format_number to be consistent with sql.


/**
* Formats the number X to a format like '#,###,###.##', rounded to D decimal places,
* and returns the result as a string. If D is 0, the result has no decimal point or
* fractional part.
*
* @group string_funcs
* @since 1.5.0
*/
def formatNumber(columnXName: String, columnDName: String): Column = {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The second argument will probably be a constant integer in most real-world cases. Let's leave it for further discussion, as we have lots of similar cases. @rxin, maybe you can change them all in a single PR after all of the expression JIRA issues are resolved.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd just support a constant number, without column name here. basically two functions

format_number(e: Column, d: Int)

format_number(columnName: String, d: Int)

formatNumber(Column(columnXName), Column(columnDName))
}

/**
* Computes the Levenshtein distance of the two given strings.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,17 +208,6 @@ class DataFrameFunctionsSuite extends QueryTest {
Row(2743272264L, 2180413220L))
}

test("string length function") {
  val df = Seq(("abc", "")).toDF("a", "b")
  // Both entry points must agree on the lengths of "abc" and "".
  val expected = Row(3, 0)

  // DataFrame function API, by Column and by column name.
  checkAnswer(df.select(strlen($"a"), strlen("b")), expected)

  // SQL expression API.
  checkAnswer(df.selectExpr("length(a)", "length(b)"), expected)
}

test("Levenshtein distance") {
val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r")
checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1)))
Expand Down Expand Up @@ -433,11 +422,91 @@ class DataFrameFunctionsSuite extends QueryTest {
val doubleData = Seq((7.2, 4.1)).toDF("a", "b")
checkAnswer(
doubleData.select(pmod('a, 'b)),
Seq(Row(3.1000000000000005)) // same as hive
Seq(Row(3.1000000000000005)) // same as hive
)
checkAnswer(
doubleData.select(pmod(lit(2), lit(Int.MaxValue))),
Seq(Row(2))
)
}

test("string / binary length function") {
  // a: string, b: binary, c: integer
  val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c")

  // DataFrame function API, by Column and by column name.
  checkAnswer(
    df.select(length($"a"), length("a"), length($"b"), length("b")),
    Row(3, 3, 4, 4))

  // SQL expression API.
  checkAnswer(
    df.selectExpr("length(a)", "length(b)"),
    Row(3, 4))

  // An integer argument is unacceptable and must be rejected during analysis. The query is
  // run directly instead of through checkAnswer: there is no valid answer to compare with,
  // and the assertion should not depend on how checkAnswer treats analysis failures.
  intercept[AnalysisException] {
    df.selectExpr("length(c)").collect()
  }
}

test("number format function") {
  val tuple =
    ("aa", 1.toByte, 2.toShort,
      3.13223f, 4, 5L, 6.48173d, Decimal(7.128381))
  val df =
    Seq(tuple)
      .toDF(
        "a", // string "aa"
        "b", // byte 1
        "c", // short 2
        "d", // float 3.13223f
        "e", // integer 4
        "f", // long 5L
        "g", // double 6.48173d
        "h") // decimal 7.128381

  // DataFrame function API, by Column and by column name.
  checkAnswer(
    df.select(
      formatNumber($"f", $"e"),
      formatNumber("f", "e")),
    Row("5.0000", "5.0000"))

  // SQL expression API: every numeric type is accepted as the 1st argument. Column e
  // (value 4) supplies the number of decimal places.
  Seq(
    "b" -> "1.0000", // byte
    "c" -> "2.0000", // short
    "d" -> "3.1322", // float
    "e" -> "4.0000", // integer
    "f" -> "5.0000", // long
    "g" -> "6.4817", // double
    "h" -> "7.1284" // decimal
  ).foreach { case (column, expected) =>
    checkAnswer(df.selectExpr(s"format_number($column, e)"), Row(expected))
  }

  // Invalid argument types must be rejected during analysis. The queries are run directly
  // instead of through checkAnswer, because there is no valid answer to compare against.
  intercept[AnalysisException] {
    // string type of the 1st argument is unacceptable
    df.selectExpr("format_number(a, e)").collect()
  }

  intercept[AnalysisException] {
    // double type of the 2nd argument is unacceptable
    df.selectExpr("format_number(e, g)").collect()
  }
}
}