6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -2988,6 +2988,12 @@
],
"sqlState" : "42710"
},
"MALFORMED_CHARACTER_CODING" : {
"message" : [
"Invalid value found when performing <function> with <charset>"
],
"sqlState" : "22000"
},
"MALFORMED_CSV_RECORD" : {
"message" : [
"Malformed CSV record: <badRecord>"
@@ -1,2 +1,2 @@
Project [decode(cast(g#0 as binary), UTF-8, false) AS decode(g, UTF-8)#0]
Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.StringDecode, StringType, decode, cast(g#0 as binary), UTF-8, false, false, BinaryType, StringTypeAnyCollation, BooleanType, BooleanType, true, true, true) AS decode(g, UTF-8)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -1,2 +1,2 @@
Project [encode(g#0, UTF-8, false) AS encode(g, UTF-8)#0]
Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.Encode, BinaryType, encode, g#0, UTF-8, false, false, StringTypeAnyCollation, StringTypeAnyCollation, BooleanType, BooleanType, true, true, true) AS encode(g, UTF-8)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -1,2 +1,2 @@
Project [encode(g#0, UTF-8, false) AS to_binary(g, utf-8)#0]
Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.Encode, BinaryType, encode, g#0, UTF-8, false, false, StringTypeAnyCollation, StringTypeAnyCollation, BooleanType, BooleanType, true, true, true) AS to_binary(g, utf-8)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -17,14 +17,16 @@

package org.apache.spark.sql.catalyst.expressions

import java.io.UnsupportedEncodingException
import java.nio.{ByteBuffer, CharBuffer}
import java.nio.charset.{CharacterCodingException, Charset, CodingErrorAction, IllegalCharsetNameException, UnsupportedCharsetException}
import java.text.{BreakIterator, DecimalFormat, DecimalFormatSymbols}
import java.util.{Base64 => JBase64}
import java.util.{HashMap, Locale, Map => JMap}

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.QueryContext
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
@@ -2708,62 +2710,69 @@ case class Decode(params: Seq[Expression], replacement: Expression)
since = "1.5.0",
group = "string_funcs")
// scalastyle:on line.size.limit
case class StringDecode(bin: Expression, charset: Expression, legacyCharsets: Boolean)
extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
case class StringDecode(
bin: Expression,
charset: Expression,
legacyCharsets: Boolean,
legacyErrorAction: Boolean)
extends RuntimeReplaceable with ImplicitCastInputTypes {

def this(bin: Expression, charset: Expression) =
this(bin, charset, SQLConf.get.legacyJavaCharsets)
this(bin, charset, SQLConf.get.legacyJavaCharsets, SQLConf.get.legacyCodingErrorAction)

override def left: Expression = bin
override def right: Expression = charset
override def dataType: DataType = SQLConf.get.defaultStringType
override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType, StringTypeAnyCollation)
override def prettyName: String = "decode"
override def toString: String = s"$prettyName($bin, $charset)"

private val supportedCharsets = Set(
"US-ASCII", "ISO-8859-1", "UTF-8", "UTF-16BE", "UTF-16LE", "UTF-16", "UTF-32")

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
val fromCharset = input2.asInstanceOf[UTF8String].toString
try {
if (legacyCharsets || supportedCharsets.contains(fromCharset.toUpperCase(Locale.ROOT))) {
UTF8String.fromString(new String(input1.asInstanceOf[Array[Byte]], fromCharset))
} else throw new UnsupportedEncodingException
} catch {
case _: UnsupportedEncodingException =>
throw QueryExecutionErrors.invalidCharsetError(prettyName, fromCharset)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (bytes, charset) => {
val fromCharset = ctx.freshName("fromCharset")
val sc = JavaCode.global(
ctx.addReferenceObj("supportedCharsets", supportedCharsets),
supportedCharsets.getClass)
s"""
String $fromCharset = $charset.toString();
try {
if ($legacyCharsets || $sc.contains($fromCharset.toUpperCase(java.util.Locale.ROOT))) {
${ev.value} = UTF8String.fromString(new String($bytes, $fromCharset));
} else {
throw new java.io.UnsupportedEncodingException();
}
} catch (java.io.UnsupportedEncodingException e) {
throw QueryExecutionErrors.invalidCharsetError("$prettyName", $fromCharset);
}
"""
})
}

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): StringDecode =
copy(bin = newLeft, charset = newRight)
override def replacement: Expression = StaticInvoke(
classOf[StringDecode],
SQLConf.get.defaultStringType,
"decode",
Seq(bin, charset, Literal(legacyCharsets), Literal(legacyErrorAction)),
Seq(BinaryType, StringTypeAnyCollation, BooleanType, BooleanType))

override def prettyName: String = "decode"
override def children: Seq[Expression] = Seq(bin, charset)
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
copy(bin = newChildren(0), charset = newChildren(1))
}

object StringDecode {
def apply(bin: Expression, charset: Expression): StringDecode = new StringDecode(bin, charset)
def decode(
input: Array[Byte],
charset: UTF8String,
legacyCharsets: Boolean,
legacyErrorAction: Boolean): UTF8String = {
val fromCharset = charset.toString
if (legacyCharsets || Encode.VALID_CHARSETS.contains(fromCharset.toUpperCase(Locale.ROOT))) {
val decoder = try {
val codingErrorAction = if (legacyErrorAction) {
CodingErrorAction.REPLACE
} else {
CodingErrorAction.REPORT
}
Charset.forName(fromCharset)
.newDecoder()
.onMalformedInput(codingErrorAction)
.onUnmappableCharacter(codingErrorAction)
} catch {
case _: IllegalCharsetNameException |
_: UnsupportedCharsetException |
_: IllegalArgumentException =>
throw QueryExecutionErrors.invalidCharsetError("decode", fromCharset)
}
try {
val cb = decoder.decode(ByteBuffer.wrap(input))
UTF8String.fromString(cb.toString)
} catch {
case _: CharacterCodingException =>
throw QueryExecutionErrors.malformedCharacterCoding("decode", fromCharset)
}
} else {
throw QueryExecutionErrors.invalidCharsetError("decode", fromCharset)
}
}
}

/**
@@ -2785,59 +2794,76 @@ object StringDecode {
since = "1.5.0",
group = "string_funcs")
// scalastyle:on line.size.limit
case class Encode(str: Expression, charset: Expression, legacyCharsets: Boolean)
extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
case class Encode(
str: Expression,
charset: Expression,
legacyCharsets: Boolean,
legacyErrorAction: Boolean)
extends RuntimeReplaceable with ImplicitCastInputTypes {

def this(value: Expression, charset: Expression) =
this(value, charset, SQLConf.get.legacyJavaCharsets)
this(value, charset, SQLConf.get.legacyJavaCharsets, SQLConf.get.legacyCodingErrorAction)

override def left: Expression = str
override def right: Expression = charset
override def dataType: DataType = BinaryType
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeAnyCollation, StringTypeAnyCollation)

private val supportedCharsets = Set(
"US-ASCII", "ISO-8859-1", "UTF-8", "UTF-16BE", "UTF-16LE", "UTF-16", "UTF-32")

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
val toCharset = input2.asInstanceOf[UTF8String].toString
try {
if (legacyCharsets || supportedCharsets.contains(toCharset.toUpperCase(Locale.ROOT))) {
input1.asInstanceOf[UTF8String].toString.getBytes(toCharset)
} else throw new UnsupportedEncodingException
} catch {
case _: UnsupportedEncodingException =>
throw QueryExecutionErrors.invalidCharsetError(prettyName, toCharset)
}
}
override val replacement: Expression = StaticInvoke(
classOf[Encode],
BinaryType,
"encode",
Seq(
str, charset, Literal(legacyCharsets, BooleanType), Literal(legacyErrorAction, BooleanType)),
Seq(StringTypeAnyCollation, StringTypeAnyCollation, BooleanType, BooleanType))

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (string, charset) => {
val toCharset = ctx.freshName("toCharset")
val sc = JavaCode.global(
ctx.addReferenceObj("supportedCharsets", supportedCharsets),
supportedCharsets.getClass)
s"""
String $toCharset = $charset.toString();
try {
if ($legacyCharsets || $sc.contains($toCharset.toUpperCase(java.util.Locale.ROOT))) {
${ev.value} = $string.toString().getBytes($toCharset);
} else {
throw new java.io.UnsupportedEncodingException();
}
} catch (java.io.UnsupportedEncodingException e) {
throw QueryExecutionErrors.invalidCharsetError("$prettyName", $toCharset);
}"""
})
}
override def toString: String = s"$prettyName($str, $charset)"

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): Encode = copy(str = newLeft, charset = newRight)
override def children: Seq[Expression] = Seq(str, charset)

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
copy(str = newChildren.head, charset = newChildren(1))
}

object Encode {
def apply(value: Expression, charset: Expression): Encode = new Encode(value, charset)

private[expressions] final lazy val VALID_CHARSETS =
Set("US-ASCII", "ISO-8859-1", "UTF-8", "UTF-16BE", "UTF-16LE", "UTF-16", "UTF-32")

def encode(
input: UTF8String,
charset: UTF8String,
legacyCharsets: Boolean,
legacyErrorAction: Boolean): Array[Byte] = {
val toCharset = charset.toString
if (legacyCharsets || VALID_CHARSETS.contains(toCharset.toUpperCase(Locale.ROOT))) {
val encoder = try {
val codingErrorAction = if (legacyErrorAction) {
CodingErrorAction.REPLACE
} else {
CodingErrorAction.REPORT
}
Charset.forName(toCharset)
.newEncoder()
.onMalformedInput(codingErrorAction)
.onUnmappableCharacter(codingErrorAction)
} catch {
case _: IllegalCharsetNameException |
_: UnsupportedCharsetException |
_: IllegalArgumentException =>
throw QueryExecutionErrors.invalidCharsetError("encode", toCharset)
}
try {
val bb = encoder.encode(CharBuffer.wrap(input.toString))
JavaUtils.bufferToArray(bb)
} catch {
case _: CharacterCodingException =>
throw QueryExecutionErrors.malformedCharacterCoding("encode", toCharset)
}
} else {
throw QueryExecutionErrors.invalidCharsetError("encode", toCharset)
}
}
}

/**
@@ -2741,6 +2741,14 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
"charset" -> charset))
}

def malformedCharacterCoding(functionName: String, charset: String): RuntimeException = {
new SparkRuntimeException(
errorClass = "MALFORMED_CHARACTER_CODING",
messageParameters = Map(
"function" -> toSQLId(functionName),
"charset" -> charset))
}

def invalidWriterCommitMessageError(details: String): Throwable = {
new SparkRuntimeException(
errorClass = "INVALID_WRITER_COMMIT_MESSAGE",
@@ -5010,6 +5010,14 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val LEGACY_CODING_ERROR_ACTION = buildConf("spark.sql.legacy.codingErrorAction")
.internal()
.doc("When set to true, encode/decode functions replace unmappable characters with mojibake " +
"instead of reporting coding errors.")
.version("4.0.0")
.booleanConf
.createWithDefault(false)
Member:
I wonder if it should be a fallback conf to ANSI.

Member Author:
The reasons I'd like to make it independent of ANSI are:

  • Part of what ANSI mode implies is Hive incompatibility.
  • Hive also reports coding errors, so the replace-on-error behavior was a mistake introduced when we ported this from Hive.
  • These functions are not ANSI-defined.
  • The error behaviors are also not found in ANSI.

The reasons mentioned above indicate that this behavior is more of a legacy trait of Spark itself.
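
For illustration, a minimal sketch of the behavior difference this conf controls (hypothetical spark-shell session; the exact error-message wording may differ):

```scala
// Hypothetical session illustrating spark.sql.legacy.codingErrorAction.
// 0x80 is not a valid single-byte UTF-8 sequence, so decoding it either
// reports an error (CodingErrorAction.REPORT) or substitutes U+FFFD (REPLACE).

// New default (false): malformed input raises MALFORMED_CHARACTER_CODING.
spark.conf.set("spark.sql.legacy.codingErrorAction", false)
spark.sql("SELECT decode(X'80', 'UTF-8')").collect()
// org.apache.spark.SparkRuntimeException: [MALFORMED_CHARACTER_CODING]
// Invalid value found when performing `decode` with UTF-8

// Legacy behavior (true): malformed input is silently replaced (mojibake).
spark.conf.set("spark.sql.legacy.codingErrorAction", true)
spark.sql("SELECT decode(X'80', 'UTF-8')").collect()
// Array([�])  -- U+FFFD replacement character instead of an error
```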


val LEGACY_EVAL_CURRENT_TIME = buildConf("spark.sql.legacy.earlyEvalCurrentTime")
.internal()
.doc("When set to true, evaluation and constant folding will happen for now() and " +
@@ -5986,6 +5994,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {

def legacyJavaCharsets: Boolean = getConf(SQLConf.LEGACY_JAVA_CHARSETS)

def legacyCodingErrorAction: Boolean = getConf(SQLConf.LEGACY_CODING_ERROR_ACTION)

def legacyEvalCurrentTime: Boolean = getConf(SQLConf.LEGACY_EVAL_CURRENT_TIME)

/** ********************** SQLConf functionality methods ************ */
@@ -104,7 +104,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
test("SPARK-22543: split large if expressions into blocks due to JVM code size limit") {
var strExpr: Expression = Literal("abc")
for (_ <- 1 to 150) {
strExpr = StringDecode(Encode(strExpr, "utf-8"), "utf-8")
strExpr = StringTrimRight(StringTrimLeft(strExpr))
}

val expressions = Seq(If(EqualTo(strExpr, strExpr), strExpr, strExpr))
@@ -71,10 +71,15 @@ trait ExpressionEvalHelper extends ScalaCheckDrivenPropertyChecks with PlanTestB
new ArrayBasedMapData(keyArray, valueArray)
}

protected def replace(expr: Expression): Expression = expr match {
case r: RuntimeReplaceable => replace(r.replacement)
case _ => expr.mapChildren(replace)
}

private def prepareEvaluation(expression: Expression): Expression = {
val serializer = new JavaSerializer(new SparkConf()).newInstance()
val resolver = ResolveTimeZone
val expr = resolver.resolveTimeZones(expression)
val expr = resolver.resolveTimeZones(replace(expression))
assert(expr.resolved)
serializer.deserialize(serializer.serialize(expr))
}
@@ -505,8 +505,8 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(StringDecode(b, Literal.create(null, StringType)), null, create_row(null))

// Test escaping of charset
GenerateUnsafeProjection.generate(Encode(a, Literal("\"quote")) :: Nil)
GenerateUnsafeProjection.generate(StringDecode(b, Literal("\"quote")) :: Nil)
GenerateUnsafeProjection.generate(Encode(a, Literal("\"quote")).replacement :: Nil)
GenerateUnsafeProjection.generate(StringDecode(b, Literal("\"quote")).replacement :: Nil)
}

test("initcap unit test") {