diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index c0b8b270bab1..7891d1c047ea 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,6 +5,7 @@ org.apache.spark.sql.execution.datasources.noop.NoopDataSource org.apache.spark.sql.execution.datasources.orc.OrcFileFormat org.apache.spark.sql.execution.datasources.v2.parquet.ParquetDataSourceV2 org.apache.spark.sql.execution.datasources.v2.text.TextDataSourceV2 +org.apache.spark.sql.execution.datasources.v2.state.StateDataSourceV2 org.apache.spark.sql.execution.streaming.ConsoleSinkProvider org.apache.spark.sql.execution.streaming.sources.RateStreamProvider org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/SchemaUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/SchemaUtil.scala new file mode 100644 index 000000000000..5dcbdbb89e90 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/SchemaUtil.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.state + +import org.apache.spark.sql.types.{DataType, StructType} + +object SchemaUtil { + def getSchemaAsDataType(schema: StructType, fieldName: String): DataType = { + schema(schema.getFieldIndex(fieldName).get).dataType + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceV2.scala new file mode 100644 index 000000000000..082724a4ebcd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceV2.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.state + +import java.util + +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.connector.catalog.{Table, TableProvider} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.execution.streaming.CommitLog +import org.apache.spark.sql.execution.streaming.state.{StateSchemaFileManager, StateStore, StateStoreId} +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class StateDataSourceV2 extends TableProvider with DataSourceRegister { + + import StateDataSourceV2._ + + lazy val session: SparkSession = SparkSession.active + + override def shortName(): String = "state" + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + val checkpointLocation = Option(properties.get(PARAM_CHECKPOINT_LOCATION)).orElse { + throw new AnalysisException(s"'$PARAM_CHECKPOINT_LOCATION' must be specified.") + }.get + + val version = Option(properties.get(PARAM_VERSION)).map(_.toLong).orElse { + Some(getLastCommittedBatch(checkpointLocation)) + }.get + + val operatorId = Option(properties.get(PARAM_OPERATOR_ID)).map(_.toInt) + .orElse(Some(0)).get + + val storeName = Option(properties.get(PARAM_STORE_NAME)) + .orElse(Some(StateStoreId.DEFAULT_STORE_NAME)).get + + val stateCheckpointLocation = new Path(checkpointLocation, "state") + new StateTable(session, schema, stateCheckpointLocation.toString, version, operatorId, + storeName) + } + + override def inferSchema(options: CaseInsensitiveStringMap): StructType = { + val checkpointLocation = Option(options.get(PARAM_CHECKPOINT_LOCATION)).orElse { + throw new AnalysisException(s"'$PARAM_CHECKPOINT_LOCATION' must be specified.") + }.get + + val operatorId = Option(options.get(PARAM_OPERATOR_ID)).map(_.toInt) + .orElse(Some(0)).get + + val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA + val storeName = Option(options.get(PARAM_STORE_NAME)) + .orElse(Some(StateStoreId.DEFAULT_STORE_NAME)).get + + val stateCheckpointLocation = new Path(checkpointLocation, "state") + val storeId = new StateStoreId(stateCheckpointLocation.toString, operatorId, partitionId, + storeName) + val manager = new StateSchemaFileManager(storeId, session.sessionState.newHadoopConf()) + if (manager.fileExist()) { + val (keySchema, valueSchema) = manager.readSchema() + new StructType() + .add("key", keySchema) + .add("value", valueSchema) + } else { + throw new UnsupportedOperationException("Schema information file doesn't exist - schema " + + "should be explicitly specified.") + } + } + + private def getLastCommittedBatch(checkpointLocation: String): Long = { + val commitLog = new CommitLog(session, new Path(checkpointLocation, "commits").toString) + val lastCommittedBatchId = commitLog.getLatest() match { + case Some((lastId, _)) => lastId + case None => -1 + } + + if (lastCommittedBatchId < 0) { + throw new 
AnalysisException("No committed batch found.") + } + + lastCommittedBatchId.toLong + 1 + } + + override def supportsExternalMetadata(): Boolean = true +} + +object StateDataSourceV2 { + val PARAM_CHECKPOINT_LOCATION = "checkpointLocation" + val PARAM_VERSION = "version" + val PARAM_OPERATOR_ID = "operatorId" + val PARAM_STORE_NAME = "storeName" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala new file mode 100644 index 000000000000..9c20d8e0ca2e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.state + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProviderId} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +class StatePartitionReader( + storeConf: StateStoreConf, + hadoopConf: SerializableConfiguration, + partition: StateStoreInputPartition, + schema: StructType) extends PartitionReader[InternalRow] { + + private val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType] + private val valueSchema = SchemaUtil.getSchemaAsDataType(schema, "value").asInstanceOf[StructType] + + private lazy val iter = { + val stateStoreId = StateStoreId(partition.stateCheckpointRootLocation, + partition.operatorId, partition.partition, partition.storeName) + val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId) + + val store = StateStore.getReadOnly(stateStoreProviderId, keySchema, valueSchema, + indexOrdinal = None, version = partition.version, storeConf = storeConf, + hadoopConf = hadoopConf.value) + + store.iterator().map(pair => unifyStateRowPair((pair.key, pair.value))) + } + + private var current: InternalRow = _ + + override def next(): Boolean = { + if (iter.hasNext) { + current = iter.next() + true + } else { + current = null + false + } + } + + override def get(): InternalRow = current + + override def close(): Unit = { + current = null + } + + private def unifyStateRowPair(pair: (UnsafeRow, UnsafeRow)): InternalRow = { + val row = new GenericInternalRow(2) + row.update(0, pair._1) + row.update(1, pair._2) + row + } +} diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderFactory.scala
new file mode 100644
index 000000000000..f4cab378aa8e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderFactory.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
+import org.apache.spark.sql.execution.streaming.state.StateStoreConf
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.SerializableConfiguration
+
+class StatePartitionReaderFactory(
+    storeConf: StateStoreConf,
+    hadoopConf: SerializableConfiguration,
+    schema: StructType) extends PartitionReaderFactory {
+
+  override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
+    val part = partition match {
+      case p: StateStoreInputPartition => p
+      case e => throw new IllegalStateException(
+        s"Expected StateStoreInputPartition but another type of partition was passed: $e")
+    }
+
+    new StatePartitionReader(storeConf, hadoopConf, part, schema)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScan.scala
new file mode 100644
index 000000000000..9f4c9b45ba5f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScan.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import java.util.UUID
+
+import scala.util.Try
+
+import org.apache.hadoop.fs.{Path, PathFilter}
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan, ScanBuilder}
+import org.apache.spark.sql.execution.streaming.state.StateStoreConf
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.SerializableConfiguration
+
+class StateScanBuilder(
+    session: SparkSession,
+    schema: StructType,
+    stateCheckpointRootLocation: String,
+    version: Long,
+    operatorId: Long,
+    storeName: String) extends ScanBuilder {
+  override def build(): Scan = new StateScan(session, schema, stateCheckpointRootLocation,
+    version, operatorId, storeName)
+}
+
+class StateStoreInputPartition(
+    val partition: Int,
+    val queryId: UUID,
+    val stateCheckpointRootLocation: String,
+    val version: Long,
+    val operatorId: Long,
+    val storeName: String) extends InputPartition
+
+class StateScan(
+    session: SparkSession,
+    schema: StructType,
+    stateCheckpointRootLocation: String,
+    version: Long,
+    operatorId: Long,
+    storeName: String) extends Scan with Batch {
+
+  // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
+  private val hadoopConfBroadcast = session.sparkContext.broadcast(
+    new SerializableConfiguration(session.sessionState.newHadoopConf()))
+
+  private val resolvedCpLocation = {
+    val checkpointPath = new Path(stateCheckpointRootLocation)
+    val fs = checkpointPath.getFileSystem(session.sessionState.newHadoopConf())
+    checkpointPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toUri.toString
+  }
+
+  override def readSchema(): StructType = schema
+
+  override def planInputPartitions(): Array[InputPartition] = {
+    val fs = stateCheckpointPartitionsLocation.getFileSystem(hadoopConfBroadcast.value.value)
+    val partitions = fs.listStatus(stateCheckpointPartitionsLocation, new PathFilter() {
+      override def accept(path: Path): Boolean = {
+        fs.isDirectory(path) && Try(path.getName.toInt).isSuccess && path.getName.toInt >= 0
+      }
+    })
+
+    if (partitions.isEmpty) {
+      Array.empty[InputPartition]
+    } else {
+      // Use a dummy query id, since we are not actually running a streaming query here.
+      val queryId = UUID.randomUUID()
+
+      val partitionsSorted = partitions.sortBy(status => status.getPath.getName.toInt)
+      val partitionNums = partitionsSorted.map(_.getPath.getName.toInt)
+      // Partition directories have unique names, so the partition numbers contain no duplicates.
+      val head = partitionNums.head
+      val tail = partitionNums.last
+      assert(head == 0, "Partition should start with 0")
+      assert((tail - head + 1) == partitionNums.length,
+        s"Found non-contiguous partition numbers in state: $partitionNums")
+
+      partitionNums.map {
+        pn => new StateStoreInputPartition(pn, queryId, stateCheckpointRootLocation,
+          version, operatorId, storeName)
+      }.toArray
+    }
+  }
+
+  override def createReaderFactory(): PartitionReaderFactory =
+    new StatePartitionReaderFactory(new StateStoreConf(session.sessionState.conf),
+      hadoopConfBroadcast.value, schema)
+
+  override def toBatch: Batch = this
+
+  override def description(): String = s"StateScan [stateCpLocation=$stateCheckpointRootLocation]" +
+    s"[version=$version][operatorId=$operatorId][storeName=$storeName]"
+
+  private def stateCheckpointPartitionsLocation: Path = {
+    new Path(resolvedCpLocation, s"$operatorId")
+  }
+}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
new file mode 100644
index 000000000000..97534625f776
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import java.util
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability}
+import org.apache.spark.sql.connector.read.ScanBuilder
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class StateTable(
+    session: SparkSession,
+    override val schema: StructType,
+    stateCheckpointLocation: String,
+    version: Long,
+    operatorId: Int,
+    storeName: String)
+  extends Table with SupportsRead {
+
+  import StateTable._
+
+  if (!isValidSchema(schema)) {
+    throw new AnalysisException("The schema should have exactly two fields, 'key' and " +
+      "'value', and each of them should be a StructType.")
+  }
+
+  override def name(): String =
+    s"state-table-cp-$stateCheckpointLocation-ver-$version-operator-$operatorId-store-$storeName"
+
+  override def capabilities(): util.Set[TableCapability] = CAPABILITY
+
+  override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
+    new StateScanBuilder(session, schema, stateCheckpointLocation, version, operatorId, storeName)
+
+  override def properties(): util.Map[String, String] = Map(
+    "stateCheckpointLocation" -> stateCheckpointLocation,
+    "version" -> version.toString,
+    "operatorId" -> operatorId.toString,
+    "storeName" -> storeName).asJava
+
+  private def isValidSchema(schema: StructType): Boolean = {
+    if (schema.fieldNames.toSeq != Seq("key", "value")) {
+      false
+    } else if (!SchemaUtil.getSchemaAsDataType(schema, "key").isInstanceOf[StructType]) {
+      false
+    } else if (!SchemaUtil.getSchemaAsDataType(schema, "value").isInstanceOf[StructType]) {
+      false
+    } else {
+      true
+    }
+  }
+}
+
+object StateTable {
+  private val CAPABILITY = Set(TableCapability.BATCH_READ).asJava
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala
index 4ac12c089c0d..4fde8a1d16ba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala @@ -18,10 +18,8 @@ package org.apache.spark.sql.execution.streaming.state import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging -import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, MetadataVersionUtil} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} @@ -31,16 +29,12 @@ class StateSchemaCompatibilityChecker( providerId: StateStoreProviderId, hadoopConf: Configuration) extends Logging { - private val storeCpLocation = providerId.storeId.storeCheckpointLocation() - private val fm = CheckpointFileManager.create(storeCpLocation, hadoopConf) - private val schemaFileLocation = schemaFile(storeCpLocation) - - fm.mkdirs(schemaFileLocation.getParent) + private val stateFileManager = new StateSchemaFileManager(providerId.storeId, hadoopConf) def check(keySchema: StructType, valueSchema: StructType): Unit = { - if (fm.exists(schemaFileLocation)) { + if (stateFileManager.fileExist()) { logDebug(s"Schema file for provider $providerId exists. Comparing with provided schema.") - val (storedKeySchema, storedValueSchema) = readSchemaFile() + val (storedKeySchema, storedValueSchema) = stateFileManager.readSchema() if (storedKeySchema.equals(keySchema) && storedValueSchema.equals(valueSchema)) { // schema is exactly same } else if (!schemasCompatible(storedKeySchema, keySchema) || @@ -64,55 +58,10 @@ class StateSchemaCompatibilityChecker( } else { // schema doesn't exist, create one now logDebug(s"Schema file for provider $providerId doesn't exist. Creating one.") - createSchemaFile(keySchema, valueSchema) + stateFileManager.writeSchema(keySchema, valueSchema) } } private def schemasCompatible(storedSchema: StructType, schema: StructType): Boolean = DataType.equalsIgnoreNameAndCompatibleNullability(storedSchema, schema) - - private def readSchemaFile(): (StructType, StructType) = { - val inStream = fm.open(schemaFileLocation) - try { - val versionStr = inStream.readUTF() - // Currently we only support version 1, which we can simplify the version validation and - // the parse logic. 
- val version = MetadataVersionUtil.validateVersion(versionStr, - StateSchemaCompatibilityChecker.VERSION) - require(version == 1) - - val keySchemaStr = inStream.readUTF() - val valueSchemaStr = inStream.readUTF() - - (StructType.fromString(keySchemaStr), StructType.fromString(valueSchemaStr)) - } catch { - case e: Throwable => - logError(s"Fail to read schema file from $schemaFileLocation", e) - throw e - } finally { - inStream.close() - } - } - - private def createSchemaFile(keySchema: StructType, valueSchema: StructType): Unit = { - val outStream = fm.createAtomic(schemaFileLocation, overwriteIfPossible = false) - try { - outStream.writeUTF(s"v${StateSchemaCompatibilityChecker.VERSION}") - outStream.writeUTF(keySchema.json) - outStream.writeUTF(valueSchema.json) - outStream.close() - } catch { - case e: Throwable => - logError(s"Fail to write schema file to $schemaFileLocation", e) - outStream.cancel() - throw e - } - } - - private def schemaFile(storeCpLocation: Path): Path = - new Path(new Path(storeCpLocation, "_metadata"), "schema") -} - -object StateSchemaCompatibilityChecker { - val VERSION = 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaFileManager.scala new file mode 100644 index 000000000000..712fc24b5af5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaFileManager.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming.state + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, MetadataVersionUtil} +import org.apache.spark.sql.types.StructType + +class StateSchemaFileManager(storeId: StateStoreId, hadoopConf: Configuration) extends Logging { + private val storeCpLocation = storeId.storeCheckpointLocation() + private val fm = CheckpointFileManager.create(storeCpLocation, hadoopConf) + private val schemaFileLocation = schemaFile(storeCpLocation) + + def fileExist(): Boolean = fm.exists(schemaFileLocation) + + def readSchema(): (StructType, StructType) = { + val inStream = fm.open(schemaFileLocation) + try { + val versionStr = inStream.readUTF() + // Currently we only support version 1, which we can simplify the version validation and + // the parse logic. 
+ val version = MetadataVersionUtil.validateVersion(versionStr, + StateSchemaFileManager.VERSION) + require(version == 1) + + val keySchemaStr = inStream.readUTF() + val valueSchemaStr = inStream.readUTF() + + (StructType.fromString(keySchemaStr), StructType.fromString(valueSchemaStr)) + } catch { + case e: Throwable => + logError(s"Fail to read schema file from $schemaFileLocation", e) + throw e + } finally { + inStream.close() + } + } + + def writeSchema(keySchema: StructType, valueSchema: StructType): Unit = { + if (!fm.exists(schemaFileLocation.getParent)) { + fm.mkdirs(schemaFileLocation.getParent) + } + + val outStream = fm.createAtomic(schemaFileLocation, overwriteIfPossible = false) + try { + outStream.writeUTF(s"v${StateSchemaFileManager.VERSION}") + outStream.writeUTF(keySchema.json) + outStream.writeUTF(valueSchema.json) + outStream.close() + } catch { + case e: Throwable => + logError(s"Fail to write schema file to $schemaFileLocation", e) + outStream.cancel() + throw e + } + } + + private def schemaFile(storeCpLocation: Path): Path = + new Path(new Path(storeCpLocation, "_metadata"), "schema") +} + +object StateSchemaFileManager { + val VERSION = 1 +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateDataSourceV2ReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateDataSourceV2ReadSuite.scala new file mode 100644 index 000000000000..d4bdc61cfd95 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateDataSourceV2ReadSuite.scala @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources.v2.state + +import org.scalatest.{Assertions, BeforeAndAfterAll} + +import org.apache.spark.sql.{Encoders, Row} +import org.apache.spark.sql.execution.datasources.v2.state.StateDataSourceV2 +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.internal.SQLConf + +class StateDataSourceV2ReadSuite + extends StateStoreTestBase + with BeforeAndAfterAll + with Assertions { + + override def afterAll(): Unit = { + super.afterAll() + StateStore.stop() + } + + test("simple aggregation, state ver 1, infer schema = false") { + testStreamingAggregation(1, inferSchema = false) + } + + test("simple aggregation, state ver 1, infer schema = true") { + testStreamingAggregation(1, inferSchema = true) + } + + test("simple aggregation, state ver 2, infer schema = false") { + testStreamingAggregation(2, inferSchema = false) + } + + test("simple aggregation, state ver 2, infer schema = true") { + testStreamingAggregation(2, inferSchema = true) + } + + test("composite key aggregation, state ver 1, infer schema = false") { + testStreamingAggregationWithCompositeKey(1, inferSchema = false) + } + + test("composite key aggregation, state ver 1, infer schema = true") { + testStreamingAggregationWithCompositeKey(1, inferSchema = true) + } + + test("composite key aggregation, state ver 2, infer schema = false") { + testStreamingAggregationWithCompositeKey(2, inferSchema = false) + } + + test("composite key aggregation, ver 2, infer schema = true") { + testStreamingAggregationWithCompositeKey(2, inferSchema = true) + } + + private def testStreamingAggregation(stateVersion: Int, inferSchema: Boolean): Unit = { + withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> stateVersion.toString) { + withTempDir { tempDir => + runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath) + + val operatorId = 0 + val batchId = 2 + + val stateReader = spark.read + .format("state") + .option(StateDataSourceV2.PARAM_CHECKPOINT_LOCATION, tempDir.getAbsolutePath) + // explicitly specifying version and operator ID to test out the functionality + .option(StateDataSourceV2.PARAM_VERSION, batchId + 1) + .option(StateDataSourceV2.PARAM_OPERATOR_ID, operatorId) + + val stateReadDf = if (inferSchema) { + stateReader.load() + } else { + val stateSchema = getSchemaForLargeDataStreamingAggregationQuery(stateVersion) + stateReader.schema(stateSchema).load() + } + + logInfo(s"Schema: ${stateReadDf.schema.treeString}") + + val resultDf = if (inferSchema) { + stateReadDf + .selectExpr("key.groupKey AS key_groupKey", "value.count AS value_cnt", + "value.sum AS value_sum", "value.max AS value_max", "value.min AS value_min") + } else { + stateReadDf + .selectExpr("key.groupKey AS key_groupKey", "value.cnt AS value_cnt", + "value.sum AS value_sum", "value.max AS value_max", "value.min AS value_min") + } + + checkAnswer( + resultDf, + Seq( + Row(0, 5, 60, 30, 0), // 0, 10, 20, 30 + Row(1, 5, 65, 31, 1), // 1, 11, 21, 31 + Row(2, 5, 70, 32, 2), // 2, 12, 22, 32 + Row(3, 4, 72, 33, 3), // 3, 13, 23, 33 + Row(4, 4, 76, 34, 4), // 4, 14, 24, 34 + Row(5, 4, 80, 35, 5), // 5, 15, 25, 35 + Row(6, 4, 84, 36, 6), // 6, 16, 26, 36 + Row(7, 4, 88, 37, 7), // 7, 17, 27, 37 + Row(8, 4, 92, 38, 8), // 8, 18, 28, 38 + Row(9, 4, 96, 39, 9) // 9, 19, 29, 39 + ) + ) + } + } + } + + private def testStreamingAggregationWithCompositeKey( + stateVersion: Int, + inferSchema: Boolean): Unit = { + withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> 
stateVersion.toString) { + withTempDir { tempDir => + runCompositeKeyStreamingAggregationQuery(tempDir.getAbsolutePath) + + val stateReader = spark.read + .format("state") + .option(StateDataSourceV2.PARAM_CHECKPOINT_LOCATION, tempDir.getAbsolutePath) + // skip version and operator ID to test out functionalities + + val stateReadDf = if (inferSchema) { + stateReader.load() + } else { + val stateSchema = getSchemaForCompositeKeyStreamingAggregationQuery(stateVersion) + stateReader.schema(stateSchema).load() + } + + logInfo(s"Schema: ${stateReadDf.schema.treeString}") + + val resultDf = if (inferSchema) { + stateReadDf + .selectExpr("key.groupKey AS key_groupKey", "key.fruit AS key_fruit", + "value.count AS value_cnt", "value.sum AS value_sum", "value.max AS value_max", + "value.min AS value_min") + } else { + stateReadDf + .selectExpr("key.groupKey AS key_groupKey", "key.fruit AS key_fruit", + "value.cnt AS value_cnt", "value.sum AS value_sum", "value.max AS value_max", + "value.min AS value_min") + } + + checkAnswer( + resultDf, + Seq( + Row(0, "Apple", 2, 6, 6, 0), + Row(1, "Banana", 3, 9, 7, 1), + Row(0, "Strawberry", 3, 12, 8, 2), + Row(1, "Apple", 3, 15, 9, 3), + Row(0, "Banana", 2, 14, 10, 4), + Row(1, "Strawberry", 1, 5, 5, 5) + ) + ) + } + } + } + + test("flatMapGroupsWithState, state ver 1, infer schema = false") { + testFlatMapGroupsWithState(1, inferSchema = false) + } + + test("flatMapGroupsWithState, state ver 1, infer schema = true") { + testFlatMapGroupsWithState(1, inferSchema = true) + } + + test("flatMapGroupsWithState, state ver 2, infer schema = false") { + testFlatMapGroupsWithState(2, inferSchema = false) + } + + test("flatMapGroupsWithState, state ver 2, infer schema = true") { + testFlatMapGroupsWithState(2, inferSchema = true) + } + + private def testFlatMapGroupsWithState(stateVersion: Int, inferSchema: Boolean): Unit = { + withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> stateVersion.toString) { + withTempDir { tempDir => + runFlatMapGroupsWithStateQuery(tempDir.getAbsolutePath) + + val stateReader = spark.read + .format("state") + .option(StateDataSourceV2.PARAM_CHECKPOINT_LOCATION, tempDir.getAbsolutePath) + + val stateReadDf = if (inferSchema) { + stateReader.load() + } else { + val stateSchema = getSchemaForFlatMapGroupsWithStateQuery(stateVersion) + stateReader.schema(stateSchema).load() + } + + val resultDf = if (stateVersion == 1) { + stateReadDf + .selectExpr("key.value AS key_value", "value.numEvents AS value_numEvents", + "value.startTimestampMs AS value_startTimestampMs", + "value.endTimestampMs AS value_endTimestampMs", + "value.timeoutTimestamp AS value_timeoutTimestamp") + } else { // stateVersion == 2 + stateReadDf + .selectExpr("key.value AS key_value", "value.groupState.numEvents AS value_numEvents", + "value.groupState.startTimestampMs AS value_startTimestampMs", + "value.groupState.endTimestampMs AS value_endTimestampMs", + "value.timeoutTimestamp AS value_timeoutTimestamp") + } + + checkAnswer( + resultDf, + Seq( + Row("hello", 4, 1000, 4000, 12000), + Row("world", 2, 1000, 3000, 12000), + Row("scala", 2, 2000, 4000, 12000) + ) + ) + + // try to read the value via case class provided in actual query + implicit val encoder = Encoders.product[SessionInfo] + val df = if (stateVersion == 1) { + stateReadDf.selectExpr("value.*").drop("timeoutTimestamp").as[SessionInfo] + } else { // state version == 2 + stateReadDf.selectExpr("value.groupState.*").as[SessionInfo] + } + + val expected = Array( + SessionInfo(4, 1000, 4000), + 
SessionInfo(2, 1000, 3000), + SessionInfo(2, 2000, 4000) + ) + assert(df.collect().toSet === expected.toSet) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateStoreTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateStoreTestBase.scala new file mode 100644 index 000000000000..ddc10bff8ddb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/state/StateStoreTestBase.scala @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.state + +import java.io.File +import java.sql.Timestamp + +import org.apache.spark.sql.{Dataset, Encoders} +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.streaming.util.StreamManualClock +import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} +import org.apache.spark.util.Utils + +trait StateStoreTestBase extends StreamTest { + import testImplicits._ + + override def afterAll(): Unit = { + super.afterAll() + StateStore.stop() + } + + protected def withTempCheckpoints(body: (File, File) => Unit) { + val src = Utils.createTempDir(namePrefix = "streaming.old") + val tmp = Utils.createTempDir(namePrefix = "streaming.new") + try { + body(src, tmp) + } finally { + Utils.deleteRecursively(src) + Utils.deleteRecursively(tmp) + } + } + + protected def runCompositeKeyStreamingAggregationQuery( + checkpointRoot: String): Unit = { + val inputData = MemoryStream[Int] + val aggregated = getCompositeKeyStreamingAggregationQuery(inputData) + + testStream(aggregated, OutputMode.Update)( + StartStream(checkpointLocation = checkpointRoot), + // batch 0 + AddData(inputData, 0 to 5: _*), + CheckLastBatch( + (0, "Apple", 1, 0, 0, 0), + (1, "Banana", 1, 1, 1, 1), + (0, "Strawberry", 1, 2, 2, 2), + (1, "Apple", 1, 3, 3, 3), + (0, "Banana", 1, 4, 4, 4), + (1, "Strawberry", 1, 5, 5, 5) + ), + // batch 1 + AddData(inputData, 6 to 10: _*), + // state also contains (1, "Strawberry", 1, 5, 5, 5) but not updated here + CheckLastBatch( + (0, "Apple", 2, 6, 6, 0), // 0, 6 + (1, "Banana", 2, 8, 7, 1), // 1, 7 + (0, "Strawberry", 2, 10, 8, 2), // 2, 8 + (1, "Apple", 2, 12, 9, 3), // 3, 9 + (0, "Banana", 2, 14, 10, 4) // 4, 10 + ), + StopStream, + StartStream(checkpointLocation = checkpointRoot), + // batch 2 + AddData(inputData, 3, 2, 1), + CheckLastBatch( + (1, "Banana", 3, 9, 7, 1), // 1, 7, 1 + (0, "Strawberry", 3, 12, 8, 2), // 2, 8, 2 + (1, "Apple", 3, 15, 9, 3) // 3, 9, 3 + ) + ) + } + + protected def getCompositeKeyStreamingAggregationQuery + : Dataset[(Int, String, Long, Long, 
Int, Int)] = { + getCompositeKeyStreamingAggregationQuery(MemoryStream[Int]) + } + + protected def getCompositeKeyStreamingAggregationQuery( + inputData: MemoryStream[Int]): Dataset[(Int, String, Long, Long, Int, Int)] = { + inputData.toDF() + .selectExpr("value", "value % 2 AS groupKey", + "(CASE value % 3 WHEN 0 THEN 'Apple' WHEN 1 THEN 'Banana' ELSE 'Strawberry' END) AS fruit") + .groupBy($"groupKey", $"fruit") + .agg( + count("*").as("cnt"), + sum("value").as("sum"), + max("value").as("max"), + min("value").as("min") + ) + .as[(Int, String, Long, Long, Int, Int)] + } + + protected def getSchemaForCompositeKeyStreamingAggregationQuery( + formatVersion: Int): StructType = { + val stateKeySchema = new StructType() + .add("groupKey", IntegerType) + .add("fruit", StringType, nullable = false) + + var stateValueSchema = formatVersion match { + case 1 => + new StructType().add("groupKey", IntegerType).add("fruit", StringType, nullable = false) + case 2 => new StructType() + case v => throw new IllegalArgumentException(s"Not valid format version $v") + } + + stateValueSchema = stateValueSchema + .add("cnt", LongType, nullable = false) + .add("sum", LongType) + .add("max", IntegerType) + .add("min", IntegerType) + + new StructType() + .add("key", stateKeySchema) + .add("value", stateValueSchema) + } + + protected def runLargeDataStreamingAggregationQuery( + checkpointRoot: String): Unit = { + val inputData = MemoryStream[Int] + val aggregated = getLargeDataStreamingAggregationQuery(inputData) + + // check with more data - leverage full partitions + testStream(aggregated, OutputMode.Update)( + StartStream(checkpointLocation = checkpointRoot), + // batch 0 + AddData(inputData, 0 until 20: _*), + CheckLastBatch( + (0, 2, 10, 10, 0), // 0, 10 + (1, 2, 12, 11, 1), // 1, 11 + (2, 2, 14, 12, 2), // 2, 12 + (3, 2, 16, 13, 3), // 3, 13 + (4, 2, 18, 14, 4), // 4, 14 + (5, 2, 20, 15, 5), // 5, 15 + (6, 2, 22, 16, 6), // 6, 16 + (7, 2, 24, 17, 7), // 7, 17 + (8, 2, 26, 18, 8), // 8, 18 + (9, 2, 28, 19, 9) // 9, 19 + ), + // batch 1 + AddData(inputData, 20 until 40: _*), + CheckLastBatch( + (0, 4, 60, 30, 0), // 0, 10, 20, 30 + (1, 4, 64, 31, 1), // 1, 11, 21, 31 + (2, 4, 68, 32, 2), // 2, 12, 22, 32 + (3, 4, 72, 33, 3), // 3, 13, 23, 33 + (4, 4, 76, 34, 4), // 4, 14, 24, 34 + (5, 4, 80, 35, 5), // 5, 15, 25, 35 + (6, 4, 84, 36, 6), // 6, 16, 26, 36 + (7, 4, 88, 37, 7), // 7, 17, 27, 37 + (8, 4, 92, 38, 8), // 8, 18, 28, 38 + (9, 4, 96, 39, 9) // 9, 19, 29, 39 + ), + StopStream, + StartStream(checkpointLocation = checkpointRoot), + // batch 2 + AddData(inputData, 0, 1, 2), + CheckLastBatch( + (0, 5, 60, 30, 0), // 0, 10, 20, 30, 0 + (1, 5, 65, 31, 1), // 1, 11, 21, 31, 1 + (2, 5, 70, 32, 2) // 2, 12, 22, 32, 2 + ) + ) + } + + protected def getLargeDataStreamingAggregationQuery: Dataset[(Int, Long, Long, Int, Int)] = { + getLargeDataStreamingAggregationQuery(MemoryStream[Int]) + } + + protected def getLargeDataStreamingAggregationQuery( + inputData: MemoryStream[Int]): Dataset[(Int, Long, Long, Int, Int)] = { + inputData.toDF() + .selectExpr("value", "value % 10 AS groupKey") + .groupBy($"groupKey") + .agg( + count("*").as("cnt"), + sum("value").as("sum"), + max("value").as("max"), + min("value").as("min") + ) + .as[(Int, Long, Long, Int, Int)] + } + + protected def getSchemaForLargeDataStreamingAggregationQuery(formatVersion: Int): StructType = { + val stateKeySchema = new StructType() + .add("groupKey", IntegerType) + + var stateValueSchema = formatVersion match { + case 1 => new 
StructType().add("groupKey", IntegerType) + case 2 => new StructType() + case v => throw new IllegalArgumentException(s"Not valid format version $v") + } + + stateValueSchema = stateValueSchema + .add("cnt", LongType) + .add("sum", LongType) + .add("max", IntegerType) + .add("min", IntegerType) + + new StructType() + .add("key", stateKeySchema) + .add("value", stateValueSchema) + } + + protected def runFlatMapGroupsWithStateQuery(checkpointRoot: String): Unit = { + val clock = new StreamManualClock + + val inputData = MemoryStream[(String, Long)] + val remapped = getFlatMapGroupsWithStateQuery(inputData) + + testStream(remapped, OutputMode.Update)( + // batch 0 + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock, + checkpointLocation = checkpointRoot), + AddData(inputData, ("hello world", 1L), ("hello scala", 2L)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer( + ("hello", 2, 1000, false), + ("world", 1, 0, false), + ("scala", 1, 0, false) + ), + // batch 1 + AddData(inputData, ("hello world", 3L), ("hello scala", 4L)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer( + ("hello", 4, 3000, false), + ("world", 2, 2000, false), + ("scala", 2, 2000, false) + ) + ) + } + + protected def getFlatMapGroupsWithStateQuery: Dataset[(String, Int, Long, Boolean)] = { + getFlatMapGroupsWithStateQuery(MemoryStream[(String, Long)]) + } + + protected def getFlatMapGroupsWithStateQuery( + inputData: MemoryStream[(String, Long)]): Dataset[(String, Int, Long, Boolean)] = { + // scalastyle:off line.size.limit + // This test code is borrowed from Sessionization example, with modification a bit to run with testStream + // https://github.com/apache/spark/blob/v2.4.1/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala + // scalastyle:on + + val events = inputData.toDF() + .as[(String, Timestamp)] + .flatMap { case (line, timestamp) => + line.split(" ").map(word => Event(sessionId = word, timestamp)) + } + + val sessionUpdates = events + .groupByKey(event => event.sessionId) + .mapGroupsWithState[SessionInfo, SessionUpdate](GroupStateTimeout.ProcessingTimeTimeout) { + + case (sessionId: String, events: Iterator[Event], state: GroupState[SessionInfo]) => + if (state.hasTimedOut) { + val finalUpdate = + SessionUpdate(sessionId, state.get.durationMs, state.get.numEvents, expired = true) + state.remove() + finalUpdate + } else { + val timestamps = events.map(_.timestamp.getTime).toSeq + val updatedSession = if (state.exists) { + val oldSession = state.get + SessionInfo( + oldSession.numEvents + timestamps.size, + oldSession.startTimestampMs, + math.max(oldSession.endTimestampMs, timestamps.max)) + } else { + SessionInfo(timestamps.size, timestamps.min, timestamps.max) + } + state.update(updatedSession) + + state.setTimeoutDuration("10 seconds") + SessionUpdate(sessionId, state.get.durationMs, state.get.numEvents, expired = false) + } + } + + sessionUpdates.map(si => (si.id, si.numEvents, si.durationMs, si.expired)) + } + + protected def getSchemaForFlatMapGroupsWithStateQuery(stateVersion: Int): StructType = { + val keySchema = new StructType().add("value", StringType, nullable = true) + val valueSchema = if (stateVersion == 1) { + Encoders.product[SessionInfo].schema.add("timeoutTimestamp", IntegerType, nullable = false) + } else { // stateVersion == 2 + new StructType() + .add("groupState", Encoders.product[SessionInfo].schema) + .add("timeoutTimestamp", LongType, nullable = false) + } + + new StructType().add("key", 
keySchema).add("value", valueSchema) + } +} + +case class Event(sessionId: String, timestamp: Timestamp) + +case class SessionInfo( + numEvents: Int, + startTimestampMs: Long, + endTimestampMs: Long) { + def durationMs: Long = endTimestampMs - startTimestampMs +} + +case class SessionUpdate( + id: String, + durationMs: Long, + numEvents: Int, + expired: Boolean)
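
Usage note (illustrative): the snippet below is a minimal sketch of how the new "state" source is meant to be consumed, based on the options defined in StateDataSourceV2 and the key/value schema exposed by StateTable. The checkpoint path and the selected column names are hypothetical; the actual value columns depend on the stateful operator that wrote the state.

    // Read the state of operator 0 for the last committed batch of a query whose
    // checkpoint is at /tmp/checkpoint (hypothetical path). "version", "operatorId"
    // and "storeName" are optional; they default to the last committed batch id + 1,
    // 0, and the default store name, respectively.
    val stateDf = spark.read
      .format("state")
      .option("checkpointLocation", "/tmp/checkpoint")
      .load()

    // Keys and values are exposed as nested structs, so individual state columns
    // can be pulled out with ordinary struct field access.
    stateDf.selectExpr("key.groupKey", "value.cnt", "value.sum").show()

If the checkpoint has no committed batch yet, or if the schema file is missing and no schema is supplied via .schema(...), the source fails with the AnalysisException / UnsupportedOperationException raised in StateDataSourceV2 above.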