Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
add ascii/base64/unbase64/encode/decode functions
  • Loading branch information
chenghao-intel committed Jul 3, 2015
commit 491ce7b536c4c9d73b38acc5cd05f2ecfaeeae1f
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,16 @@ object FunctionRegistry {
expression[Sum]("sum"),

// string functions
expression[Ascii]("ascii"),
expression[Base64]("base64"),
expression[Encode]("encode"),
expression[Decode]("decode"),
expression[Lower]("lcase"),
expression[Lower]("lower"),
expression[StringLength]("length"),
expression[Substring]("substr"),
expression[Substring]("substring"),
expression[UnBase64]("unbase64"),
expression[Upper]("ucase"),
expression[UnHex]("unhex"),
expression[Upper]("upper")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import java.util.regex.Pattern
import java.nio.charset._

import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.expressions.codegen._
Expand Down Expand Up @@ -298,3 +299,131 @@ case class StringLength(child: Expression) extends UnaryExpression with ExpectsI

override def prettyName: String = "length"
}

/**
* Returns the numeric value of the first character of str.
*/
case class Ascii(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def dataType: DataType = IntegerType
override def expectedChildTypes: Seq[DataType] = Seq(StringType)

// Evaluates the child and returns: null for a null input, 0 for a string with no bytes,
// otherwise the first byte of the string's byte array converted to Int.
// NOTE(review): assumes UTF8String.getBytes exposes the UTF-8 encoded bytes — confirm.
override def eval(input: InternalRow): Any = {
val string = child.eval(input)
if (string == null) {
null
} else {
val bytes = string.asInstanceOf[UTF8String].getBytes
if (bytes.length > 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what should the behavior be if it is a non-ascii utf8 string?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copied the logic from Hive, Hive doesn't check if it's a utf8 string.

// Byte is signed, so a first byte >= 0x80 converts to a negative Int; under the
// assumption above that means a non-ASCII first character does not yield its code point.
bytes(0).asInstanceOf[Int]
} else {
// Empty string: there is no first byte, so 0 is returned.
0
}
}
}

override def toString: String = s"ascii($child)"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Upper case these all?

}

/**
* Converts the argument from binary to a base64 string.
*/
case class Base64(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def dataType: DataType = StringType
override def expectedChildTypes: Seq[DataType] = Seq(BinaryType)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does hive do any auto cast? if not - maybe we should not use ExpectsInputTypes. (I'm thinking about renaming ExpectsInputTypes to something else like AutoConvertInputTypes)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hive would raise a semantic exception for non-binary type parameter in base64.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok then we shouldn't use expectsInputType. We should just do the explicit type here. cc @cloud-fan

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should override checkInputTypes here, as expectsInputType won't report type check error...


// Base64-encodes the child's binary value; a null input evaluates to null.
override def eval(input: InternalRow): Any = {
  child.eval(input) match {
    case null => null
    case raw =>
      val payload = raw.asInstanceOf[Array[Byte]]
      // Fully qualified because `Base64` alone resolves to this expression class.
      UTF8String.fromBytes(org.apache.commons.codec.binary.Base64.encodeBase64(payload))
  }
}

override def toString: String = s"base64($child)"
}

/**
 * Decodes a base64 string argument into BINARY.
 * A null input evaluates to null.
 */
case class UnBase64(child: Expression) extends UnaryExpression with ExpectsInputTypes {
  override def dataType: DataType = BinaryType
  override def expectedChildTypes: Seq[DataType] = StringType :: Nil

  override def eval(input: InternalRow): Any = {
    child.eval(input) match {
      case null => null
      case encoded =>
        val text = encoded.asInstanceOf[UTF8String].toString
        // Fully qualified to avoid clashing with the Base64 expression class in this file.
        org.apache.commons.codec.binary.Base64.decodeBase64(text)
    }
  }

  override def toString: String = s"unbase64($child)"
}

/**
* Decodes the first argument into a String using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove "As of Hive 0.12.0"

*/
case class Decode(bin: Expression, charset: Expression)
  extends Expression with ExpectsInputTypes {

  override def children: Seq[Expression] = Seq(bin, charset)
  override def foldable: Boolean = bin.foldable && charset.foldable
  override def nullable: Boolean = bin.nullable || charset.nullable
  override def dataType: DataType = StringType
  override def expectedChildTypes: Seq[DataType] = BinaryType :: StringType :: Nil

  // Evaluates `bin` first; `charset` is only evaluated when `bin` is non-null.
  // Either operand being null makes the result null.
  override def eval(input: InternalRow): Any = {
    bin.eval(input) match {
      case null => null
      case payload =>
        charset.eval(input) match {
          case null => null
          case name =>
            val fromCharset = name.asInstanceOf[UTF8String].toString
            val decoded = new String(payload.asInstanceOf[Array[Byte]], fromCharset)
            UTF8String.fromString(decoded)
        }
    }
  }

  override def toString: String = s"decode($bin, $charset)"
}

/**
* Encodes the first argument into a BINARY using the provided character set
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indent

* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null.
*/
case class Encode(value: Expression, charset: Expression)
extends Expression with ExpectsInputTypes {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here extend BinaryExpression too

override def children: Seq[Expression] = value :: charset :: Nil
override def foldable: Boolean = value.foldable && charset.foldable
override def nullable: Boolean = value.nullable || charset.nullable
override def dataType: DataType = BinaryType
override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)

// Evaluates `value` first; `charset` is only evaluated when `value` is non-null.
// Either operand being null makes the result null; otherwise the string is
// converted to bytes in the named character set.
override def eval(input: InternalRow): Any = {
  value.eval(input) match {
    case null => null
    case text =>
      charset.eval(input) match {
        case null => null
        case name =>
          val toCharset = name.asInstanceOf[UTF8String].toString
          text.asInstanceOf[UTF8String].toString.getBytes(toCharset)
      }
  }
}

override def toString: String = s"encode($value, $charset)"
}


Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType}


class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
Expand Down Expand Up @@ -217,11 +217,59 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("length for string") {
  // `a` is a string attribute resolved against position 0 of the input row (via `.at(0)`).
  // The former `regEx` alias was the same reference under a misleading name, and its
  // assertions duplicated the ones below, so only one copy is kept.
  val a = 'a.string.at(0)
  checkEvaluation(StringLength(Literal("abc")), 3, create_row("abdef"))
  checkEvaluation(StringLength(a), 5, create_row("abdef"))
  checkEvaluation(StringLength(a), 0, create_row(""))
  checkEvaluation(StringLength(a), null, create_row(null))
  checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef"))
}

test("ascii for string") {
  val strAttr = 'a.string.at(0)
  val nullString = Literal.create(null, StringType)

  // Literal and attribute inputs: numeric value of the first character.
  checkEvaluation(Ascii(Literal("efg")), 101, create_row("abdef"))
  checkEvaluation(Ascii(strAttr), 97, create_row("abdef"))
  // An empty string evaluates to 0.
  checkEvaluation(Ascii(strAttr), 0, create_row(""))
  // Null inputs evaluate to null.
  checkEvaluation(Ascii(strAttr), null, create_row(null))
  checkEvaluation(Ascii(nullString), null, create_row("abdef"))
}

test("base64/unbase64 for string") {
  val a = 'a.string.at(0)
  val b = 'b.binary.at(0)

  // Literal inputs, including a round trip through UnBase64.
  checkEvaluation(Base64(Literal(Array[Byte](1, 2, 3, 4))), "AQIDBA==", create_row("abdef"))
  checkEvaluation(Base64(UnBase64(Literal("AQIDBA=="))), "AQIDBA==", create_row("abdef"))
  checkEvaluation(Base64(UnBase64(Literal(""))), "", create_row("abdef"))
  checkEvaluation(Base64(UnBase64(Literal.create(null, StringType))), null, create_row("abdef"))
  checkEvaluation(Base64(UnBase64(a)), "AQIDBA==", create_row("AQIDBA=="))

  // Binary attribute inputs.
  checkEvaluation(Base64(b), "AQIDBA==", create_row(Array[Byte](1, 2, 3, 4)))
  checkEvaluation(Base64(b), "", create_row(Array[Byte]()))
  checkEvaluation(Base64(b), null, create_row(null))
  // Base64 declares a BinaryType child, so the null literal is typed as binary
  // (it was previously typed as StringType).
  checkEvaluation(Base64(Literal.create(null, BinaryType)), null, create_row("abdef"))

  checkEvaluation(UnBase64(a), null, create_row(null))
  checkEvaluation(UnBase64(Literal.create(null, StringType)), null, create_row("abdef"))
}

test("encode/decode for string") {
  val a = 'a.string.at(0)
  val b = 'b.binary.at(0)

  // Round trips through Encode and Decode.
  checkEvaluation(
    Decode(Encode(Literal("大千世界"), Literal("UTF-16LE")), Literal("UTF-16LE")), "大千世界")
  checkEvaluation(
    Decode(Encode(a, Literal("utf-8")), Literal("utf-8")), "大千世界", create_row("大千世界"))
  checkEvaluation(
    Decode(Encode(a, Literal("utf-8")), Literal("utf-8")), "", create_row(""))

  // Encode: null value, null literal value, null charset.
  checkEvaluation(Encode(a, Literal("utf-8")), null, create_row(null))
  checkEvaluation(Encode(Literal.create(null, StringType), Literal("utf-8")), null)
  checkEvaluation(Encode(a, Literal.create(null, StringType)), null, create_row(""))

  // Decode: non-null binary attribute, null value, null literal value, null charset.
  checkEvaluation(
    Decode(b, Literal("utf-8")), "大千世界", create_row("大千世界".getBytes("utf-8")))
  checkEvaluation(Decode(b, Literal("utf-8")), null, create_row(null))
  checkEvaluation(Decode(Literal.create(null, BinaryType), Literal("utf-8")), null)
  // The binary input must be non-null here: Decode.eval returns as soon as its first
  // operand is null, so a null row value would never reach the null-charset branch.
  checkEvaluation(
    Decode(b, Literal.create(null, StringType)), null, create_row("大千世界".getBytes("utf-8")))
}
}
93 changes: 93 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1543,18 +1543,111 @@ object functions {

/**
* Computes the length of the given string expression.
*
* @param e the string expression to measure
* @return a column built from a `StringLength` expression over `e`
* @group string_funcs
* @since 1.5.0
*/
def strlen(e: Column): Column = StringLength(e.expr)

/**
* Computes the length of the string column with the given name.
* Delegates to `strlen(Column)` after wrapping the name in a `Column`.
*
* @param columnName name of the string column to measure
* @group string_funcs
* @since 1.5.0
*/
def strlen(columnName: String): Column = strlen(Column(columnName))

/**
* Computes the numeric value of the first character of the given string expression.
*
* @param e the string expression to inspect
* @return a column built from an `Ascii` expression over `e`
* @group string_funcs
* @since 1.5.0
*/
def ascii(e: Column): Column = Ascii(e.expr)

/**
* Computes the numeric value of the first character of the string column with the given name.
* Delegates to `ascii(Column)` after wrapping the name in a `Column`.
*
* @param columnName name of the string column to inspect
* @group string_funcs
* @since 1.5.0
*/
def ascii(columnName: String): Column = ascii(Column(columnName))

/**
* Encodes the given binary expression as a base64 string.
*
* @param e the binary expression to encode
* @return a column built from a `Base64` expression over `e`
* @group string_funcs
* @since 1.5.0
*/
def base64(e: Column): Column = Base64(e.expr)

/**
* Encodes the binary column with the given name as a base64 string.
* Delegates to `base64(Column)` after wrapping the name in a `Column`.
*
* @param columnName name of the binary column to encode
* @group string_funcs
* @since 1.5.0
*/
def base64(columnName: String): Column = base64(Column(columnName))

/**
* Decodes the specified base64 string value to binary.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Base64 is a single word

*
* @group string_funcs
* @since 1.5.0
*/
def unbase64(e: Column): Column = UnBase64(e.expr)

/**
* Decodes the base64 string column with the given name to binary.
* Delegates to `unbase64(Column)` after wrapping the name in a `Column`.
*
* @param columnName name of the base64 string column to decode
* @group string_funcs
* @since 1.5.0
*/
def unbase64(columnName: String): Column = unbase64(Column(columnName))

/**
* Encodes the first argument from a string into binary using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null.
*
* @param value the string expression to encode
* @param charset expression yielding the character set name
* @return a column built from an `Encode` expression over both arguments
* @group string_funcs
* @since 1.5.0
*/
def encode(value: Column, charset: Column): Column = Encode(value.expr, charset.expr)

/**
* Encodes a string column into binary using the character set read from another column
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null.
*
* Note that both arguments are column names: the second one names the column that holds
* the character set, it is not the character set itself.
*
* @param columnName name of the string column to encode
* @param charsetColumnName name of the column holding the character set name
* @group string_funcs
* @since 1.5.0
*/
def encode(columnName: String, charsetColumnName: String): Column =
encode(Column(columnName), Column(charsetColumnName))

/**
* Decodes the first argument from binary into a string using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null.
*
* @param value the binary expression to decode
* @param charset expression yielding the character set name
* @return a column built from a `Decode` expression over both arguments
* @group string_funcs
* @since 1.5.0
*/
def decode(value: Column, charset: Column): Column = Decode(value.expr, charset.expr)

/**
* Decodes the first argument from binary into a string using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null.
*
* @group string_funcs
* @since 1.5.0
*/
def decode(columnName: String, charsetColumnName: String): Column =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i am not sure if this makes sense -- since it is more likely users want to decode by typing in the charset, rather than using a column for that...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, in most of existed DF api, we take the string as the column name, should we break this pattern? Actually, it seems redundant for most of DF functions, which take the string columns as parameters, as well as the Column types. Of course this is a big change to the existed user code, we probably don't want to do the clean up right now, but we can stop adding the string (column name) version of DF functions during the Hive UDF rewriting, what do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's just change this one to take charset: String, rather than a column.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

basically two decode:

def decode(column: Column, charset: String): Column
def decode(columnName: String, charset: String): Column

same for encode

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I will update it soon.

decode(Column(columnName), Column(charsetColumnName))


//////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,4 +225,39 @@ class DataFrameFunctionsSuite extends QueryTest {
Row(l)
})
}

test("string ascii function") {
  val df = Seq(("abc", "")).toDF("a", "b")
  // The DSL and the SQL expression form must give the same answer.
  val expected = Row(97, 0)

  checkAnswer(df.select(ascii($"a"), ascii("b")), expected)
  checkAnswer(df.selectExpr("ascii(a)", "ascii(b)"), expected)
}

test("string base64/unbase64 function") {
  val bytes = Array[Byte](1, 2, 3, 4)
  val encoded = "AQIDBA=="
  val df = Seq((bytes, encoded)).toDF("a", "b")

  // DSL form: column-name and Column overloads of both functions.
  val viaDsl = df.select(base64("a"), base64($"a"), unbase64("b"), unbase64($"b"))
  checkAnswer(viaDsl, Row(encoded, encoded, bytes, bytes))

  // SQL expression form.
  val viaExpr = df.selectExpr("base64(a)", "unbase64(b)")
  checkAnswer(viaExpr, Row(encoded, bytes))
}

test("string encode/decode function") {
  val text = "大千世界"
  val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116)
  val df = Seq((text, "utf-8", bytes)).toDF("a", "b", "c")

  // DSL form: Column and column-name overloads of both functions.
  val viaDsl =
    df.select(encode($"a", $"b"), encode("a", "b"), decode($"c", $"b"), decode("c", "b"))
  checkAnswer(viaDsl, Row(bytes, bytes, text, text))

  // SQL expression form.
  val viaExpr = df.selectExpr("encode(a, b)", "decode(c, b)")
  checkAnswer(viaExpr, Row(bytes, text))
}
}