Skip to content
Closed
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -2988,6 +2988,12 @@
],
"sqlState" : "42710"
},
"MALFORMED_CHARACTER_CODING" : {
"message" : [
"Invalid value found when performing <function> with <charset>"
],
"sqlState" : "22000"
},
"MALFORMED_CSV_RECORD" : {
"message" : [
"Malformed CSV record: <badRecord>"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [decode(cast(g#0 as binary), UTF-8, false) AS decode(g, UTF-8)#0]
Project [decode(cast(g#0 as binary), UTF-8, false, false) AS decode(g, UTF-8)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [encode(g#0, UTF-8, false) AS encode(g, UTF-8)#0]
Project [encode(g#0, UTF-8, false, false) AS encode(g, UTF-8)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [encode(g#0, UTF-8, false) AS to_binary(g, utf-8)#0]
Project [encode(g#0, UTF-8, false, false) AS to_binary(g, utf-8)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@

package org.apache.spark.sql.catalyst.expressions

import java.io.UnsupportedEncodingException
import java.nio.{ByteBuffer, CharBuffer}
import java.nio.charset.{CharacterCodingException, Charset, CodingErrorAction, IllegalCharsetNameException, UnsupportedCharsetException}
import java.text.{BreakIterator, DecimalFormat, DecimalFormatSymbols}
import java.util.{Base64 => JBase64}
import java.util.{HashMap, Locale, Map => JMap}

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.QueryContext
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
Expand Down Expand Up @@ -2708,11 +2710,14 @@ case class Decode(params: Seq[Expression], replacement: Expression)
since = "1.5.0",
group = "string_funcs")
// scalastyle:on line.size.limit
case class StringDecode(bin: Expression, charset: Expression, legacyCharsets: Boolean)
case class StringDecode(
bin: Expression,
charset: Expression,
legacyCharsets: Boolean, legacyErrorAction: Boolean)
extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {

def this(bin: Expression, charset: Expression) =
this(bin, charset, SQLConf.get.legacyJavaCharsets)
this(bin, charset, SQLConf.get.legacyJavaCharsets, SQLConf.get.legacyCodingErrorAction)

override def left: Expression = bin
override def right: Expression = charset
Expand All @@ -2724,35 +2729,38 @@ case class StringDecode(bin: Expression, charset: Expression, legacyCharsets: Bo

/**
 * Decodes the binary value `input1` into a string using the charset named by `input2`.
 *
 * The charset name is accepted when `legacyCharsets` is set or when its upper-cased form
 * (Locale.ROOT) is in `supportedCharsets`; otherwise `invalidCharsetError` is thrown.
 * In the decoder-based branch, malformed input and unmappable characters are replaced when
 * `legacyErrorAction` is true and surface as `malformedCharacterCoding` when it is false.
 *
 * NOTE(review): this listing appears to carry both the removed `new String(...)` try/catch
 * and the added decoder-based branch of the diff hunk, so it does not compile as shown.
 */
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
val fromCharset = input2.asInstanceOf[UTF8String].toString
try {
if (legacyCharsets || supportedCharsets.contains(fromCharset.toUpperCase(Locale.ROOT))) {
UTF8String.fromString(new String(input1.asInstanceOf[Array[Byte]], fromCharset))
} else throw new UnsupportedEncodingException
} catch {
case _: UnsupportedEncodingException =>
throw QueryExecutionErrors.invalidCharsetError(prettyName, fromCharset)
if (legacyCharsets || supportedCharsets.contains(fromCharset.toUpperCase(Locale.ROOT))) {
val decoder = try {
val codingErrorAction = if (legacyErrorAction) {
CodingErrorAction.REPLACE
} else {
CodingErrorAction.REPORT
}
Charset.forName(fromCharset)
.newDecoder()
.onMalformedInput(codingErrorAction)
.onUnmappableCharacter(codingErrorAction)
} catch {
case _: IllegalCharsetNameException |
_: UnsupportedCharsetException |
_: IllegalArgumentException =>
throw QueryExecutionErrors.invalidCharsetError(prettyName, fromCharset)
}
try {
val cb = decoder.decode(ByteBuffer.wrap(input1.asInstanceOf[Array[Byte]]))
UTF8String.fromString(cb.toString)
} catch {
case _: CharacterCodingException =>
throw QueryExecutionErrors.malformedCharacterCoding(prettyName, fromCharset)
}
} else {
throw QueryExecutionErrors.invalidCharsetError(prettyName, fromCharset)
}
}

/**
 * Produces the generated-code form of this decode expression.
 *
 * The last two statements register this expression as a codegen reference object and emit a
 * call to its `nullSafeEval`, casting the result to `UTF8String`.
 *
 * NOTE(review): the leading `nullSafeCodeGen(...)` call (inline charset check plus
 * `new String(bytes, charset)`) looks like the removed side of the diff hunk shown together
 * with the added lines -- confirm against the real file.
 */
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (bytes, charset) => {
val fromCharset = ctx.freshName("fromCharset")
val sc = JavaCode.global(
ctx.addReferenceObj("supportedCharsets", supportedCharsets),
supportedCharsets.getClass)
s"""
String $fromCharset = $charset.toString();
try {
if ($legacyCharsets || $sc.contains($fromCharset.toUpperCase(java.util.Locale.ROOT))) {
${ev.value} = UTF8String.fromString(new String($bytes, $fromCharset));
} else {
throw new java.io.UnsupportedEncodingException();
}
} catch (java.io.UnsupportedEncodingException e) {
throw QueryExecutionErrors.invalidCharsetError("$prettyName", $fromCharset);
}
"""
})
val expr = ctx.addReferenceObj("this", this)
defineCodeGen(ctx, ev, (bin, charset) => s"(UTF8String) $expr.nullSafeEval($bin, $charset)")
}

override protected def withNewChildrenInternal(
Expand Down Expand Up @@ -2785,11 +2793,15 @@ object StringDecode {
since = "1.5.0",
group = "string_funcs")
// scalastyle:on line.size.limit
case class Encode(str: Expression, charset: Expression, legacyCharsets: Boolean)
case class Encode(
str: Expression,
charset: Expression,
legacyCharsets: Boolean,
legacyErrorAction: Boolean)
extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {

def this(value: Expression, charset: Expression) =
this(value, charset, SQLConf.get.legacyJavaCharsets)
this(value, charset, SQLConf.get.legacyJavaCharsets, SQLConf.get.legacyCodingErrorAction)

override def left: Expression = str
override def right: Expression = charset
Expand All @@ -2800,36 +2812,41 @@ case class Encode(str: Expression, charset: Expression, legacyCharsets: Boolean)
private val supportedCharsets = Set(
"US-ASCII", "ISO-8859-1", "UTF-8", "UTF-16BE", "UTF-16LE", "UTF-16", "UTF-32")


/**
 * Encodes the string value `input1` into bytes using the charset named by `input2`.
 *
 * The charset name is accepted when `legacyCharsets` is set or when its upper-cased form
 * (Locale.ROOT) is in `supportedCharsets`; otherwise `invalidCharsetError` is thrown.
 * In the encoder-based branch, malformed input and unmappable characters are replaced when
 * `legacyErrorAction` is true and surface as `malformedCharacterCoding` when it is false;
 * the encoded buffer is returned through `JavaUtils.bufferToArray`.
 *
 * NOTE(review): this listing appears to carry both the removed `getBytes(...)` try/catch
 * and the added encoder-based branch of the diff hunk, so it does not compile as shown.
 */
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
val toCharset = input2.asInstanceOf[UTF8String].toString
try {
if (legacyCharsets || supportedCharsets.contains(toCharset.toUpperCase(Locale.ROOT))) {
input1.asInstanceOf[UTF8String].toString.getBytes(toCharset)
} else throw new UnsupportedEncodingException
} catch {
case _: UnsupportedEncodingException =>
throw QueryExecutionErrors.invalidCharsetError(prettyName, toCharset)
if (legacyCharsets || supportedCharsets.contains(toCharset.toUpperCase(Locale.ROOT))) {
val encoder = try {
val codingErrorAction = if (legacyErrorAction) {
CodingErrorAction.REPLACE
} else {
CodingErrorAction.REPORT
}
Charset.forName(toCharset)
.newEncoder()
.onMalformedInput(codingErrorAction)
.onUnmappableCharacter(codingErrorAction)
} catch {
case _: IllegalCharsetNameException |
_: UnsupportedCharsetException |
_: IllegalArgumentException =>
throw QueryExecutionErrors.invalidCharsetError(prettyName, toCharset)
}
try {
val bb = encoder.encode(CharBuffer.wrap(input1.asInstanceOf[UTF8String].toString))
JavaUtils.bufferToArray(bb)
} catch {
case _: CharacterCodingException =>
throw QueryExecutionErrors.malformedCharacterCoding(prettyName, toCharset)
}
} else {
throw QueryExecutionErrors.invalidCharsetError(prettyName, toCharset)
}
}

/**
 * Produces the generated-code form of this encode expression.
 *
 * The last two statements register this expression as a codegen reference object and emit a
 * call to its `nullSafeEval`, casting the result to `byte[]`.
 *
 * NOTE(review): the leading `nullSafeCodeGen(...)` call (inline charset check plus
 * `getBytes(charset)`) looks like the removed side of the diff hunk shown together with the
 * added lines -- confirm against the real file.
 */
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (string, charset) => {
val toCharset = ctx.freshName("toCharset")
val sc = JavaCode.global(
ctx.addReferenceObj("supportedCharsets", supportedCharsets),
supportedCharsets.getClass)
s"""
String $toCharset = $charset.toString();
try {
if ($legacyCharsets || $sc.contains($toCharset.toUpperCase(java.util.Locale.ROOT))) {
${ev.value} = $string.toString().getBytes($toCharset);
} else {
throw new java.io.UnsupportedEncodingException();
}
} catch (java.io.UnsupportedEncodingException e) {
throw QueryExecutionErrors.invalidCharsetError("$prettyName", $toCharset);
}"""
})
val expr = ctx.addReferenceObj("this", this)
defineCodeGen(ctx, ev, (str, charset) => s"(byte[]) $expr.nullSafeEval($str, $charset)")
}

override protected def withNewChildrenInternal(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2741,6 +2741,14 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
"charset" -> charset))
}

/**
 * Builds the exception reported when a function hits a value it cannot encode or decode
 * with the given charset.
 *
 * @param functionName name of the SQL function; rendered through `toSQLId` in the message
 * @param charset charset name, placed into the message unchanged
 * @return a `SparkRuntimeException` carrying error class `MALFORMED_CHARACTER_CODING`
 */
def malformedCharacterCoding(functionName: String, charset: String): RuntimeException = {
  // Keys line up with the <function> and <charset> placeholders of the error message.
  val params: Map[String, String] = Map(
    "function" -> toSQLId(functionName),
    "charset" -> charset)
  new SparkRuntimeException(
    errorClass = "MALFORMED_CHARACTER_CODING",
    messageParameters = params)
}

def invalidWriterCommitMessageError(details: String): Throwable = {
new SparkRuntimeException(
errorClass = "INVALID_WRITER_COMMIT_MESSAGE",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5002,6 +5002,14 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

// Internal legacy switch, off by default, for the encode/decode string functions:
// true replaces characters that cannot be mapped, false reports a coding error instead.
// Read through the `legacyCodingErrorAction` accessor on SQLConf.
val LEGACY_CODING_ERROR_ACTION = buildConf("spark.sql.legacy.codingErrorAction")
.internal()
.doc("When set to true, encode/decode functions replace unmappable characters with mojibake " +
"instead of reporting coding errors.")
.version("4.0.0")
.booleanConf
.createWithDefault(false)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it should be a fallback conf to ANSI.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reasons I'd like to make it independent of ANSI are:

  • Part of the implication of ANSI is Hive incompatibility.
  • Hive also reports coding errors, so the current behavior was a mistake made when we ported this from Hive.
  • These functions are not defined by ANSI.
  • Their error behaviors are not specified by ANSI either.

The reasons mentioned above indicate that this behavior is more of a legacy trait of Spark itself.


val LEGACY_EVAL_CURRENT_TIME = buildConf("spark.sql.legacy.earlyEvalCurrentTime")
.internal()
.doc("When set to true, evaluation and constant folding will happen for now() and " +
Expand Down Expand Up @@ -5976,6 +5984,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {

/** Returns the current setting of `SQLConf.LEGACY_JAVA_CHARSETS`. */
def legacyJavaCharsets: Boolean = getConf(SQLConf.LEGACY_JAVA_CHARSETS)

/**
 * Returns the current setting of `SQLConf.LEGACY_CODING_ERROR_ACTION`
 * (`spark.sql.legacy.codingErrorAction`).
 */
def legacyCodingErrorAction: Boolean = getConf(SQLConf.LEGACY_CODING_ERROR_ACTION)

/**
 * Returns the current setting of `SQLConf.LEGACY_EVAL_CURRENT_TIME`
 * (`spark.sql.legacy.earlyEvalCurrentTime`).
 */
def legacyEvalCurrentTime: Boolean = getConf(SQLConf.LEGACY_EVAL_CURRENT_TIME)

/** ********************** SQLConf functionality methods ************ */
Expand Down
Loading