Support decimal precision/scale in Hive metastore
mateiz committed Nov 1, 2014
commit b28933d1ce69ee6ec531ec8569020a90d2598e5b
@@ -359,8 +359,8 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType
   private[sql] val ordering = Decimal.DecimalIsFractional
   private[sql] val asIntegral = Decimal.DecimalAsIfIntegral
 
-  def simpleString: String = precisionInfo match {
-    case Some(PrecisionInfo(precision, scale)) => s"decimal($precision, $scale})"
+  override def typeName: String = precisionInfo match {
+    case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)"
     case None => "decimal"
   }
 
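As a quick illustration of the rename above: typeName now renders both decimal flavors. A minimal sketch, assuming Catalyst's DecimalType companion with the apply(precision, scale) constructor and Unlimited value used elsewhere in this commit:

import org.apache.spark.sql.catalyst.types.DecimalType

DecimalType(10, 2).typeName    // "decimal(10,2)" -- fixed precision and scale
DecimalType.Unlimited.typeName // "decimal"       -- precisionInfo is None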
@@ -21,7 +21,9 @@ import java.io.{BufferedReader, File, InputStreamReader, PrintStream}
 import java.sql.{Date, Timestamp}
 import java.util.{ArrayList => JArrayList}
 
 import org.apache.hadoop.hive.common.`type`.HiveDecimal
+import org.apache.spark.sql.catalyst.types.DecimalType
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
 
 import scala.collection.JavaConversions._
 import scala.language.implicitConversions
@@ -374,7 +376,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
     case (d: Date, DateType) => new DateWritable(d).toString
     case (t: Timestamp, TimestampType) => new TimestampWritable(t).toString
     case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8")
-    case (decimal, DecimalType()) => decimal.toString
+    case (decimal: Decimal, DecimalType()) => // Hive strips trailing zeros so use its toString
+      HiveShim.createDecimal(decimal.toBigDecimal.underlying()).toString
     case (other, tpe) if primitiveTypes contains tpe => other.toString
   }
 
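The new comment is the crux of this change: Hive's decimal toString strips trailing zeros, while BigDecimal's keeps them, so the two string forms can differ. A minimal sketch of the difference, assuming a value of 1.2300 (HiveShim.createDecimal is the shim factory invoked above):

val d = Decimal(BigDecimal("1.2300"))
d.toBigDecimal.toString                                       // "1.2300"
HiveShim.createDecimal(d.toBigDecimal.underlying()).toString  // "1.23"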
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.hive
 
 import org.apache.hadoop.hive.common.`type`.HiveDecimal
-import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory
+import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfoFactory}
 import org.apache.hadoop.hive.serde2.objectinspector._
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector._
 import org.apache.hadoop.hive.serde2.objectinspector.primitive._
@@ -91,7 +91,7 @@ private[hive] trait HiveInspectors {
     case hvoi: HiveVarcharObjectInspector =>
       if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue
     case hdoi: HiveDecimalObjectInspector =>
-      if (data == null) null else Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue())
+      if (data == null) null else HiveShim.toCatalystDecimal(hdoi, data)
     // org.apache.hadoop.hive.serde2.io.TimestampWritable.set will reset current time object
     // if next timestamp is null, so Timestamp object is cloned
     case ti: TimestampObjectInspector => ti.getPrimitiveJavaObject(data).clone()
@@ -281,8 +283,10 @@ private[hive] trait HiveInspectors {
     case _: JavaFloatObjectInspector => FloatType
     case _: WritableBinaryObjectInspector => BinaryType
     case _: JavaBinaryObjectInspector => BinaryType
-    case _: WritableHiveDecimalObjectInspector => DecimalType.Unlimited // TODO: fixed precision
-    case _: JavaHiveDecimalObjectInspector => DecimalType.Unlimited // TODO: fixed precision
+    case w: WritableHiveDecimalObjectInspector =>
+      HiveShim.decimalTypeInfoToCatalyst(w.getTypeInfo.asInstanceOf[DecimalTypeInfo])
+    case j: JavaHiveDecimalObjectInspector =>
+      HiveShim.decimalTypeInfoToCatalyst(j.getTypeInfo.asInstanceOf[DecimalTypeInfo])
     case _: WritableDateObjectInspector => DateType
     case _: JavaDateObjectInspector => DateType
     case _: WritableTimestampObjectInspector => TimestampType
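With this change, a decimal column's declared type parameters flow from the Hive object inspector into Catalyst instead of collapsing to Unlimited. A sketch of the 0.13 path; the factory lookup is an assumption about Hive's serde2 API, and only decimalTypeInfoToCatalyst is defined by this commit:

val ti = new DecimalTypeInfo(10, 2)
val oi = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(ti)
  .asInstanceOf[WritableHiveDecimalObjectInspector]
HiveShim.decimalTypeInfoToCatalyst(oi.getTypeInfo.asInstanceOf[DecimalTypeInfo])
// => DecimalType(10,2) on Hive 0.13.1; the 0.12 shim always returns DecimalType.Unlimited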
@@ -311,7 +313,7 @@ private[hive] trait HiveInspectors {
     case LongType => longTypeInfo
     case ShortType => shortTypeInfo
     case StringType => stringTypeInfo
-    case DecimalType() => decimalTypeInfo // TODO: fixed precision
+    case d: DecimalType => HiveShim.decimalTypeInfo(d)
     case DateType => dateTypeInfo
     case TimestampType => timestampTypeInfo
     case NullType => voidTypeInfo
@@ -20,6 +20,7 @@ package org.apache.spark.sql.hive
 import java.io.IOException
 import java.util.{List => JList}
 
+import scala.util.matching.Regex
 import scala.util.parsing.combinator.RegexParsers
 
 import org.apache.hadoop.util.ReflectionUtils
@@ -321,11 +322,18 @@ object HiveMetastoreTypes extends RegexParsers {
     "bigint" ^^^ LongType |
     "binary" ^^^ BinaryType |
     "boolean" ^^^ BooleanType |
-    HiveShim.metastoreDecimal ^^^ DecimalType.Unlimited | // TODO: fixed precision
+    fixedDecimalType | // Hive 0.13+ decimal with precision/scale
+    "decimal" ^^^ DecimalType.Unlimited | // Hive 0.12 decimal with no precision/scale
     "date" ^^^ DateType |
     "timestamp" ^^^ TimestampType |
     "varchar\\((\\d+)\\)".r ^^^ StringType
 
+  protected lazy val fixedDecimalType: Parser[DataType] =
+    ("decimal" ~> "(" ~> "\\d+".r) ~ ("," ~> "\\d+".r <~ ")") ^^ {
+      case precision ~ scale =>
+        DecimalType(precision.toInt, scale.toInt)
+    }
+
   protected lazy val arrayType: Parser[DataType] =
     "array" ~> "<" ~> dataType <~ ">" ^^ {
       case tpe => ArrayType(tpe)
@@ -373,7 +381,7 @@ object HiveMetastoreTypes extends RegexParsers {
     case BinaryType => "binary"
     case BooleanType => "boolean"
     case DateType => "date"
-    case DecimalType.Unlimited => "decimal" // TODO: fixed precision
+    case d: DecimalType => HiveShim.decimalMetastoreString(d)
     case TimestampType => "timestamp"
     case NullType => "void"
   }
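Together, the parser and printer above round-trip fixed-precision decimals through the metastore. A sketch of both directions, assuming toDataType is the HiveMetastoreTypes entry point that runs these parsers; the per-version outputs follow from the two HiveShim implementations later in this commit:

HiveMetastoreTypes.toDataType("decimal(10,2)")          // DecimalType(10,2)
HiveMetastoreTypes.toDataType("decimal")                // DecimalType.Unlimited
HiveShim.decimalMetastoreString(DecimalType(10, 2))     // "decimal(10,2)" on 0.13.1; "decimal" on 0.12.0
HiveShim.decimalMetastoreString(DecimalType.Unlimited)  // "decimal(38,18)" on 0.13.1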
@@ -326,6 +326,10 @@ private[hive] object HiveQl {
   }
 
   protected def nodeToDataType(node: Node): DataType = node match {
+    case Token("TOK_DECIMAL", precision :: scale :: Nil) =>
+      DecimalType(precision.getText.toInt, scale.getText.toInt)
+    case Token("TOK_DECIMAL", precision :: Nil) =>
+      DecimalType(precision.getText.toInt, 0)
     case Token("TOK_DECIMAL", Nil) => DecimalType.Unlimited
     case Token("TOK_BIGINT", Nil) => LongType
     case Token("TOK_INT", Nil) => IntegerType
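The three TOK_DECIMAL shapes correspond to the three DDL spellings of the type (illustrative HiveQL column types):

// x DECIMAL(10,2) -> Token("TOK_DECIMAL", precision :: scale :: Nil) -> DecimalType(10, 2)
// x DECIMAL(10)   -> Token("TOK_DECIMAL", precision :: Nil)          -> DecimalType(10, 0)
// x DECIMAL       -> Token("TOK_DECIMAL", Nil)                       -> DecimalType.Unlimited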
@@ -943,6 +947,10 @@ private[hive] object HiveQl {
       Cast(nodeToExpr(arg), BinaryType)
     case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) =>
       Cast(nodeToExpr(arg), BooleanType)
+    case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: scale :: Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, scale.getText.toInt))
+    case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, 0))
     case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) =>
       Cast(nodeToExpr(arg), DecimalType.Unlimited)
     case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) =>
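The same three shapes are now handled for casts; a sketch of the resulting Catalyst expressions:

// CAST(c AS DECIMAL(10,2)) -> Cast(c, DecimalType(10, 2))
// CAST(c AS DECIMAL(10))   -> Cast(c, DecimalType(10, 0))
// CAST(c AS DECIMAL)       -> Cast(c, DecimalType.Unlimited)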
@@ -30,21 +30,24 @@ import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc}
 import org.apache.hadoop.hive.ql.processors._
 import org.apache.hadoop.hive.ql.stats.StatsSetupConst
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory}
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
+import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo
 import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils}
 import org.apache.hadoop.hive.serde2.{io => hiveIo}
 import org.apache.hadoop.{io => hadoopIo}
 import org.apache.hadoop.mapred.InputFormat
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
 import scala.collection.JavaConversions._
 import scala.language.implicitConversions
 
+import org.apache.spark.sql.catalyst.types.DecimalType
+
 /**
  * A compatibility layer for interacting with Hive version 0.12.0.
  */
 private[hive] object HiveShim {
   val version = "0.12.0"
-  val metastoreDecimal = "decimal"
 
   def getTableDesc(
       serdeClass: Class[_ <: Deserializer],
@@ -149,6 +152,16 @@ private[hive] object HiveShim {
   def setLocation(tbl: Table, crtTbl: CreateTableDesc): Unit = {
     tbl.setDataLocation(new Path(crtTbl.getLocation()).toUri())
   }
+
+  def decimalMetastoreString(decimalType: DecimalType): String = "decimal"
+
+  def decimalTypeInfo(decimalType: DecimalType): DecimalTypeInfo = new DecimalTypeInfo()
+
+  def decimalTypeInfoToCatalyst(info: DecimalTypeInfo): DecimalType = DecimalType.Unlimited
+
+  def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = {
+    Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue())
+  }
 }
 
 class ShimFileSinkDesc(var dir: String, var tableInfo: TableDesc, var compressed: Boolean)
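On the 0.12 shim every decimal deliberately degrades to the unlimited type, because a 0.12 metastore cannot record precision and scale. A sketch of the observable behavior, restating the methods above:

HiveShim.decimalMetastoreString(DecimalType(10, 2))        // "decimal" -- precision/scale dropped
HiveShim.decimalTypeInfoToCatalyst(new DecimalTypeInfo())  // DecimalType.Unlimited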
@@ -29,15 +29,17 @@ import org.apache.hadoop.hive.ql.Context
 import org.apache.hadoop.hive.ql.metadata.{Table, Hive, Partition}
 import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc}
 import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory
-import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory
+import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfoFactory}
 import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer}
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory}
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
 import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils}
 import org.apache.hadoop.hive.serde2.{io => hiveIo}
 import org.apache.hadoop.{io => hadoopIo}
 import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.types.DecimalType
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
 
 import scala.collection.JavaConversions._
 import scala.language.implicitConversions
@@ -47,11 +49,6 @@ import scala.language.implicitConversions
  */
 private[hive] object HiveShim {
   val version = "0.13.1"
-  /*
-   * TODO: hive-0.13 support DECIMAL(precision, scale), DECIMAL in hive-0.12 is actually DECIMAL(38,unbounded)
-   * Full support of new decimal feature need to be fixed in seperate PR.
-   */
-  val metastoreDecimal = "decimal\\((\\d+),(\\d+)\\)".r
 
   def getTableDesc(
       serdeClass: Class[_ <: Deserializer],
@@ -197,6 +194,29 @@ private[hive] object HiveShim {
     f.setDestTableId(w.destTableId)
     f
   }
+
+  // Precision and scale to pass for unlimited decimals; these are the same as the precision and
+  // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs)
+  private val UNLIMITED_DECIMAL_PRECISION = 38
+  private val UNLIMITED_DECIMAL_SCALE = 18
+
+  def decimalMetastoreString(decimalType: DecimalType): String = decimalType match {
+    case DecimalType.Fixed(precision, scale) => s"decimal($precision,$scale)"
+    case _ => s"decimal($UNLIMITED_DECIMAL_PRECISION,$UNLIMITED_DECIMAL_SCALE)"
+  }
+
+  def decimalTypeInfo(decimalType: DecimalType): DecimalTypeInfo = decimalType match {
+    case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale)
+    case _ => new DecimalTypeInfo(UNLIMITED_DECIMAL_PRECISION, UNLIMITED_DECIMAL_SCALE)
+  }
+
+  def decimalTypeInfoToCatalyst(info: DecimalTypeInfo): DecimalType = {
+    DecimalType(info.precision(), info.scale())
+  }
+
+  def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = {
+    Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale())
+  }
 }
 
 /*
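Unlike the 0.12 shim, these 0.13.1 implementations preserve precision and scale end to end. A sketch; oi and writable are hypothetical stand-ins for a decimal inspector and its value, and Decimal's precision/scale accessors are assumed to expose what the three-argument factory above stores:

val d: Decimal = HiveShim.toCatalystDecimal(oi, writable)
d.precision  // == oi.precision(), e.g. 10
d.scale      // == oi.scale(),     e.g. 2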