Added an expression test
MaxGekk committed Sep 16, 2018
commit 7206483623563d3b71a2722f54c0cb7547080f34
@@ -15,19 +15,18 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.execution.datasources.csv
+package org.apache.spark.sql.catalyst.csv
 
 import java.math.BigDecimal
 
-import scala.util.control.Exception._
+import scala.util.control.Exception.allCatch
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.analysis.TypeCoercion
-import org.apache.spark.sql.catalyst.csv.CSVOptions
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 
-private[csv] object CSVInferSchema {
+object CSVInferSchema {
 
   /**
    * Similar to the JSON schema inference
@@ -44,13 +43,7 @@ private[csv] object CSVInferSchema {
       val rootTypes: Array[DataType] =
         tokenRDD.aggregate(startType)(inferRowType(options), mergeRowTypes)
 
-      header.zip(rootTypes).map { case (thisHeader, rootType) =>
-        val dType = rootType match {
-          case _: NullType => StringType
-          case other => other
-        }
-        StructField(thisHeader, dType, nullable = true)
-      }
+      toStructFields(rootTypes, header, options)
     } else {
       // By default fields are assumed to be StringType
       header.map(fieldName => StructField(fieldName, StringType, nullable = true))
@@ -59,7 +52,20 @@ private[csv] object CSVInferSchema {
     StructType(fields)
   }
 
-  private def inferRowType(options: CSVOptions)
+  def toStructFields(
+      fieldTypes: Array[DataType],
+      header: Array[String],
+      options: CSVOptions): Array[StructField] = {
+    header.zip(fieldTypes).map { case (thisHeader, rootType) =>
+      val dType = rootType match {
+        case _: NullType => StringType
+        case other => other
+      }
+      StructField(thisHeader, dType, nullable = true)
+    }
+  }
+
+  def inferRowType(options: CSVOptions)
       (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
     var i = 0
     while (i < math.min(rowSoFar.length, next.length)) { // May have columns on right missing.
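
The change above splits driver-side inference into two reusable steps: inferRowType folds per-token types into a running row type, and toStructFields maps the result onto the header, replacing NullType with StringType. A minimal sketch of how the two now-public helpers compose, assuming a spark-catalyst build that contains this patch on the classpath (the object name InferSchemaSketch is made up for illustration):

// Sketch only: drive the two helpers by hand on a single CSV row.
import org.apache.spark.sql.catalyst.csv.{CSVInferSchema, CSVOptions}
import org.apache.spark.sql.types.{DataType, NullType, StructType}

object InferSchemaSketch {
  def main(args: Array[String]): Unit = {
    // Same options the new SchemaOfCsv expression uses: defaults, UTC timezone.
    val options = new CSVOptions(Map.empty, true, "UTC")
    val row = Array("1", "abc")
    val header = row.indices.map(i => s"_c$i").toArray

    // Start every column at NullType and refine it from the row's tokens.
    val startType: Array[DataType] = Array.fill[DataType](row.length)(NullType)
    val fieldTypes = CSVInferSchema.inferRowType(options)(startType, row)

    // Map header names onto the inferred types; NullType falls back to StringType.
    val schema = StructType(CSVInferSchema.toStructFields(fieldTypes, header, options))
    println(schema.catalogString)  // struct<_c0:int,_c1:string>
  }
}

Factoring toStructFields out of infer is what lets the new expression below reuse the inference logic on a single parsed line, without building an RDD.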
@@ -17,13 +17,19 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import com.fasterxml.jackson.core.JsonFactory
+import com.univocity.parsers.csv.CsvParser
+
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.csv._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JSONOptions}
+import org.apache.spark.sql.catalyst.json.JsonInferSchema.inferField
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.Utils
 
 /**
  * Converts a CSV input string to a [[StructType]] with the specified schema.
@@ -107,3 +113,34 @@ case class CsvToStructs(
 
   override def inputTypes: Seq[AbstractDataType] = StringType :: Nil
 }
+
+/**
+ * A function that infers the schema of a CSV string.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(csv[, options]) - Returns the schema of a CSV string in DDL format.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_('1,abc');
+       struct<_c0:int,_c1:string>
+  """,
+  since = "3.0.0")
+case class SchemaOfCsv(child: Expression)
+  extends UnaryExpression with String2StringExpression with CodegenFallback {
+
+  override def convert(v: UTF8String): UTF8String = {
+    val parsedOptions = new CSVOptions(Map.empty, true, "UTC")
+    val parser = new CsvParser(parsedOptions.asParserSettings)
+    val row = parser.parseLine(v.toString)
+
+    if (row != null) {
+      val header = row.zipWithIndex.map { case (_, index) => s"_c$index" }
+      val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
+      val fieldTypes = CSVInferSchema.inferRowType(parsedOptions)(startType, row)
+      val st = StructType(CSVInferSchema.toStructFields(fieldTypes, header, parsedOptions))
+      UTF8String.fromString(st.catalogString)
+    } else {
+      null
+    }
+  }
+}
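
Since the new expression is a plain UnaryExpression over a literal child, it can also be evaluated by hand, which is essentially what checkEvaluation does in the test below. A minimal sketch under the same classpath assumption; SchemaOfCsvSketch is a made-up driver name:

import org.apache.spark.sql.catalyst.expressions.{Literal, SchemaOfCsv}

object SchemaOfCsvSketch {
  def main(args: Array[String]): Unit = {
    // The child is a string literal, so no input row is needed for eval().
    val ddl = SchemaOfCsv(Literal.create("1,abc")).eval()
    println(ddl)  // struct<_c0:int,_c1:string> (as a UTF8String)
  }
}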
@@ -131,4 +131,8 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P
     val schemaToCompare = csvSchema.asNullable
     assert(schemaToCompare == schema)
   }
+
+  test("infer schema of CSV strings") {
+    checkEvaluation(SchemaOfCsv(Literal.create("1,abc")), "struct<_c0:int,_c1:string>")
+  }
 }
@@ -27,14 +27,14 @@ import org.apache.hadoop.io.{LongWritable, Text}
 import org.apache.hadoop.mapred.TextInputFormat
 import org.apache.hadoop.mapreduce.Job
 import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
 
 import org.apache.spark.TaskContext
 import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.{BinaryFileRDD, RDD}
 import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.csv.{CSVOptions, UnivocityParser}
+import org.apache.spark.sql.catalyst.csv.{CSVInferSchema, CSVOptions, UnivocityParser}
 import org.apache.spark.sql.catalyst.csv.CSVUtils.filterCommentAndEmpty
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
@@ -18,7 +18,8 @@
 package org.apache.spark.sql.execution.datasources.csv
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.csv.CSVOptions
+
+import org.apache.spark.sql.catalyst.csv.{CSVInferSchema, CSVOptions}
 import org.apache.spark.sql.types._
 
 class CSVInferSchemaSuite extends SparkFunSuite {