@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.internal.io

import org.apache.hadoop.mapreduce.TaskAttemptContext

abstract class FileCommitProtocolV2 extends FileCommitProtocol {

@deprecated("use newTaskTempFileV2", "3.1.0")
override def newTaskTempFile(
taskContext: TaskAttemptContext, dir: Option[String], ext: String): String

@deprecated("use newTaskTempFileAbsPathV2", "3.1.0")
override def newTaskTempFileAbsPath(
taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String

/**
* Notifies the commit protocol to add a new file, and gets back the full path that should be
* used. Must be called on the executors when running tasks.
*
* Note that the returned temp file may have an arbitrary path. The commit protocol only
* promises that the file will be at the location specified by the arguments after job commit.
*
* A full file path consists of the following parts:
* 1. the base path
* 2. the relative file path
*
* The "relativeFilePath" parameter specifies 2. The base path is left to the commit protocol
* implementation to decide.
*
* Important: it is the caller's responsibility to add uniquely identifying content to
* "relativeFilePath" if a task is going to write out multiple files to the same dir. The file
* commit protocol only guarantees that files written by different tasks will not conflict.
*/
def newTaskTempFileV2(taskContext: TaskAttemptContext, relativeFilePath: String): String

/**
* Similar to newTaskTempFileV2(), but allows files to be committed to an absolute output location.
* Depending on the implementation, there may be weaker guarantees around adding files this way.
*
* Important: it is the caller's responsibility to add uniquely identifying content to
* "absoluteFilePath" if a task is going to write out multiple files to the same dir. The file
* commit protocol only guarantees that files written by different tasks will not conflict.
*/
def newTaskTempFileAbsPathV2(taskContext: TaskAttemptContext, absoluteFilePath: String): String
}

object FileCommitProtocolV2 {

final def getFilename(
taskContext: TaskAttemptContext, jobId: String, prefix: String, ext: String): String = {
// The file name looks like part-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003-c000.parquet
// Note that %05d does not truncate the split number, so if we have more than 100000 tasks,
// the file name is fine and won't overflow.
val split = taskContext.getTaskAttemptID.getTaskID.getId
f"${prefix}part-$split%05d-$jobId$ext"
}
}
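
A minimal usage sketch (not part of the patch) of the new API: a task builds a relative path out of its partition directory and a file name from `getFilename`, and lets the commit protocol pick the base (staging) path. The `committer`, `taskContext`, job id, partition directory and extension below are assumed/invented for illustration.

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptContext

import org.apache.spark.internal.io.FileCommitProtocolV2

// Sketch only: `committer` and `taskContext` are assumed to come from the write task.
def newPartitionedOutputFile(
    committer: FileCommitProtocolV2,
    taskContext: TaskAttemptContext): String = {
  // Produces e.g. "part-00003-<jobId>.c000.snappy.parquet" (empty prefix = non-bucketed case).
  val fileName = FileCommitProtocolV2.getFilename(
    taskContext, jobId = "e7a1f6c0-job", prefix = "", ext = ".c000.snappy.parquet")

  // The caller owns the relative part: partition directory plus a unique file name.
  val relativeFilePath = new Path("year=2020/month=8", fileName).toString

  // The protocol decides the base path and returns the temp path to write to; after
  // job commit the file is promised to be at <final output path>/<relativeFilePath>.
  committer.newTaskTempFileV2(taskContext, relativeFilePath)
}
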
@@ -53,9 +53,10 @@ class HadoopMapReduceCommitProtocol(
jobId: String,
path: String,
dynamicPartitionOverwrite: Boolean = false)
extends FileCommitProtocol with Serializable with Logging {
extends FileCommitProtocolV2 with Serializable with Logging {

import FileCommitProtocol._
import FileCommitProtocolV2._

/** OutputCommitter from Hadoop is not serializable so marking it transient. */
@transient private var committer: OutputCommitter = _
@@ -101,9 +102,40 @@ class HadoopMapReduceCommitProtocol(
format.getOutputCommitter(context)
}

override def newTaskTempFileV2(
taskContext: TaskAttemptContext, relativeFilePath: String): String = {
val stagingDir: Path = committer match {
case _ if dynamicPartitionOverwrite =>
val dir = new Path(relativeFilePath).getParent
assert(dir != null,
"The dataset to be written must be partitioned when dynamicPartitionOverwrite is true.")
partitionPaths += dir.toString
this.stagingDir
// For FileOutputCommitter it has its own staging path called "work path".
case f: FileOutputCommitter =>
new Path(Option(f.getWorkPath).map(_.toString).getOrElse(path))
case _ => new Path(path)
}

new Path(stagingDir, relativeFilePath).toString
}

override def newTaskTempFileAbsPathV2(
taskContext: TaskAttemptContext, absoluteFilePath: String): String = {
val filename = new Path(absoluteFilePath).getName
val absOutputPath = new Path(absoluteFilePath).toString

// Include a UUID here to prevent file collisions for one task writing to different dirs.
// In principle we could include hash(absoluteDir) instead but this is simpler.
val tmpOutputPath = new Path(stagingDir, UUID.randomUUID().toString + "-" + filename).toString

addedAbsPathFiles(tmpOutputPath) = absOutputPath
tmpOutputPath
}

override def newTaskTempFile(
taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = {
val filename = getFilename(taskContext, ext)
val filename = getFilename(taskContext, jobId, "", ext)

val stagingDir: Path = committer match {
case _ if dynamicPartitionOverwrite =>
@@ -126,7 +158,7 @@

override def newTaskTempFileAbsPath(
taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = {
val filename = getFilename(taskContext, ext)
val filename = getFilename(taskContext, jobId, "", ext)
val absOutputPath = new Path(absoluteDir, filename).toString

// Include a UUID here to prevent file collisions for one task writing to different dirs.
@@ -137,14 +169,6 @@
tmpOutputPath
}

protected def getFilename(taskContext: TaskAttemptContext, ext: String): String = {
// The file name looks like part-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003-c000.parquet
// Note that %05d does not truncate the split number, so if we have more than 100000 tasks,
// the file name is fine and won't overflow.
val split = taskContext.getTaskAttemptID.getTaskID.getId
f"part-$split%05d-$jobId$ext"
}

override def setupJob(jobContext: JobContext): Unit = {
// Setup IDs
val jobId = SparkHadoopWriterUtils.createJobID(new Date, 0)
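
The base-path choice made inside newTaskTempFileV2 above, condensed into a standalone sketch for readability (not part of the patch); `stagingDir`, `path` and `dynamicPartitionOverwrite` stand in for the protocol's fields of the same names.

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.OutputCommitter
import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter

// Where a relative temp file ends up before commit, depending on the committer.
def resolveBasePath(
    committer: OutputCommitter,
    dynamicPartitionOverwrite: Boolean,
    stagingDir: Path,     // protocol-owned staging dir used for dynamic partition overwrite
    path: String): Path = // final output path of the job
  committer match {
    // Dynamic partition overwrite: stage everything, rename whole partitions at commit.
    case _ if dynamicPartitionOverwrite => stagingDir
    // FileOutputCommitter already stages task output under its own "work path".
    case f: FileOutputCommitter => new Path(Option(f.getWorkPath).map(_.toString).getOrElse(path))
    // Other committers: write directly under the job's output path.
    case _ => new Path(path)
  }
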
@@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, PathOutputCommitter, PathOutputCommitterFactory}

import org.apache.spark.internal.io.FileCommitProtocolV2._
import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol

/**
@@ -134,7 +135,7 @@ class PathOutputCommitProtocol(
val parent = dir.map {
d => new Path(workDir, d)
}.getOrElse(workDir)
val file = new Path(parent, getFilename(taskContext, ext))
val file = new Path(parent, getFilename(taskContext, jobId, "", ext))
logTrace(s"Creating task file $file for dir $dir and ext $ext")
file.toString
}
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -260,7 +261,7 @@ case class FileSourceScanExec(
// exposed for testing
lazy val bucketedScan: Boolean = {
if (relation.sparkSession.sessionState.conf.bucketingEnabled && relation.bucketSpec.isDefined
&& !disableBucketedScan) {
&& !disableBucketedScan && !DDLUtils.isHiveTable(relation.options.get(DDLUtils.PROVIDER))) {
val spec = relation.bucketSpec.get
val bucketColumns = spec.bucketColumnNames.flatMap(n => toAttribute(n))
bucketColumns.size == spec.bucketColumnNames.size
@@ -837,7 +837,9 @@ case class AlterTableSetLocationCommand(


object DDLUtils {
val PROVIDER = "provider"
val HIVE_PROVIDER = "hive"
val HIVE_VERSION = "hive_version"

def isHiveTable(table: CatalogTable): Boolean = {
isHiveTable(table.provider)
@@ -21,7 +21,7 @@ import scala.collection.mutable
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptContext

import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.internal.io.{FileCommitProtocol, FileCommitProtocolV2}
import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
@@ -165,6 +165,10 @@ class DynamicPartitionDataWriter(
|WriteJobDescription: $description
""".stripMargin)

/** Whether to create task temp files through the [[FileCommitProtocolV2]] API. */
private val isFileCommitProtocolV2 = committer.isInstanceOf[FileCommitProtocolV2] &&
description.bucketFileNamePrefix.isDefined

private var fileCounter: Int = _
private var recordsInFile: Long = _
private var currentPartionValues: Option[UnsafeRow] = None
@@ -229,11 +233,39 @@
val customPath = partDir.flatMap { dir =>
description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
}
val currentPath = if (customPath.isDefined) {
committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext)
} else {
committer.newTaskTempFile(taskAttemptContext, partDir, ext)
}

val currentPath =
if (isFileCommitProtocolV2) {
val fileNamePrefix = (description.bucketFileNamePrefix, bucketId) match {
case (Some(prefix), Some(id)) => Some(prefix(id))
case _ => None
}

(committer, fileNamePrefix) match {
case (c: FileCommitProtocolV2, Some(prefix)) =>
val fileName = FileCommitProtocolV2.getFilename(
taskAttemptContext, description.uuid, prefix, ext)
if (customPath.isDefined) {
val absoluteFilePath = new Path(customPath.get, fileName).toString
c.newTaskTempFileAbsPathV2(taskAttemptContext, absoluteFilePath)
} else {
val relativeFilePath = partDir match {
case Some(dir) => new Path(dir, fileName).toString
case None => fileName
}
c.newTaskTempFileV2(taskAttemptContext, relativeFilePath)
}
case c =>
throw new IllegalArgumentException(
s"DynamicPartitionDataWriter should not take $c as the file commit protocol")
}
} else {
if (customPath.isDefined) {
committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext)
} else {
committer.newTaskTempFile(taskAttemptContext, partDir, ext)
}
}

currentWriter = description.outputWriterFactory.newInstance(
path = currentPath,
@@ -286,6 +318,7 @@ class WriteJobDescription(
val dataColumns: Seq[Attribute],
val partitionColumns: Seq[Attribute],
val bucketIdExpression: Option[Expression],
val bucketFileNamePrefix: Option[Int => String],
val path: String,
val customPartitionLocations: Map[TablePartitionSpec, String],
val maxRecordsPerFile: Long,
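
A worked example (not part of the patch) of the relative path this writer hands to newTaskTempFileV2 for a partitioned, Hive-bucketed write. The partition values, task split number, job UUID and extension are invented; the prefix format comes from compatibleBucketFileNamePrefix defined in FileFormatWriter below.

import org.apache.hadoop.fs.Path

// Invented inputs mirroring the variables used in the writer above.
val partDir  = Some("year=2020/month=8")                   // from getPartitionPath(...)
val bucketId = 3
val prefix   = f"$bucketId%05d_0_"                         // "00003_0_", the Hive/Presto-style prefix
val uuid     = "2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb"      // description.uuid
val ext      = ".c000.snappy.parquet"

// FileCommitProtocolV2.getFilename(taskAttemptContext, uuid, prefix, ext) yields
// "00003_0_part-00007-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb.c000.snappy.parquet"
// for task split 00007, so the relative path becomes:
val fileName = s"${prefix}part-00007-$uuid$ext"
val relativeFilePath = new Path(partDir.get, fileName).toString
// "year=2020/month=8/00003_0_part-00007-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb.c000.snappy.parquet"
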
@@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String
@@ -69,6 +70,17 @@ object FileFormatWriter extends Logging {
}
}

/**
* A function that returns the bucket file name prefix for a given bucket id.
* The new bucket file name follows the Hive and Presto naming convention, which makes sure that
* Hive bucketed tables written by Spark can be read by other SQL engines such as Hive and Presto.
*
* Hive bucketing naming: `org.apache.hadoop.hive.ql.exec.Utilities#getBucketIdFromFile()`.
* Presto bucketing naming (prestosql here):
* `io.prestosql.plugin.hive.BackgroundHiveSplitLoader#BUCKET_PATTERNS`.
*/
def compatibleBucketFileNamePrefix(bucketId: Int): String = f"$bucketId%05d_0_"

/**
* Basic work flow of this command is:
* 1. Driver side setup, including output committer initialization and data source specific
@@ -113,12 +125,32 @@
}
val empty2NullPlan = if (needConvert) ProjectExec(projectList, plan) else plan

val bucketIdExpression = bucketSpec.map { spec =>
var bucketFileNamePrefix: Option[Int => String] = None
val bucketIdExpression = bucketSpec.flatMap { spec =>
val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
// Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
// guarantee the data distribution is same between shuffle and bucketed data source, which
// enables us to only shuffle one side when join a bucketed table and a normal one.
HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
if (DDLUtils.isHiveTable(options.get(DDLUtils.PROVIDER))) {
val hiveVersion = options.getOrElse(DDLUtils.HIVE_VERSION, "")
val hiveVersion012 = Seq("0.", "1.", "2.")
if (hiveVersion012.exists(hiveVersion.startsWith)) {
bucketFileNamePrefix = Some(compatibleBucketFileNamePrefix)
// For Hive bucketed tables, use `HiveHash` and bitwise-and as our bucket id expression.
// Without the extra bitwise-and operation, we can get a wrong bucket id when the hash value
// of the bucket columns is negative. See the Hive implementation in
// `org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#getBucketNumber()`.
val bucketId = HiveHash(bucketColumns)
val bucketIdAfterAnd = BitwiseAnd(bucketId, Literal(Int.MaxValue))
Some(Pmod(bucketIdAfterAnd, Literal(spec.numBuckets)))
} else {
// TODO(SPARK-32710/32711): Write Hive 3.x ORC/Parquet bucketed table
None
}
} else {
// For Spark data source bucketed tables, use `HashPartitioning.partitionIdExpression`
// as our bucket id expression, so that we can guarantee the data distribution is the same
// between shuffle and bucketed data source, which enables us to only shuffle one side
// when joining a bucketed table and a normal one.
Some(HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression)
}
}
val sortColumns = bucketSpec.toSeq.flatMap {
spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
@@ -140,6 +172,7 @@
dataColumns = dataColumns,
partitionColumns = partitionColumns,
bucketIdExpression = bucketIdExpression,
bucketFileNamePrefix = bucketFileNamePrefix,
path = outputSpec.outputPath,
customPartitionLocations = outputSpec.customPartitionLocations,
maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong)
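
For clarity (not part of the patch), the Hive-compatible bucket id built above from HiveHash, BitwiseAnd and Pmod reduces to the scalar formula Hive itself uses in ObjectInspectorUtils#getBucketNumber; a plain sketch:

// (hash & Integer.MAX_VALUE) % numBuckets -- the bitwise-and clears the sign bit so the
// modulo stays in [0, numBuckets) and matches Hive even for negative hash values.
def hiveBucketId(hiveHashOfBucketColumns: Int, numBuckets: Int): Int =
  (hiveHashOfBucketColumns & Int.MaxValue) % numBuckets

// Example: hiveBucketId(-7, 4) == 1, whereas a plain -7 % 4 would give -3 and thus
// disagree with the bucket file Hive expects the row to live in.
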
@@ -130,6 +130,7 @@ abstract class FileWriteBuilder(
dataColumns = allColumns,
partitionColumns = Seq.empty,
bucketIdExpression = None,
bucketFileNamePrefix = None,
path = pathName,
customPartitionLocations = Map.empty,
maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong)