diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala
index 63d118cf857b..99c6df738920 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.benchmark
 
 import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
@@ -29,9 +30,9 @@ import org.apache.spark.util.Benchmark
 
 /**
  * Benchmark to measure TPCDS query performance.
  * To run this:
- *  spark-submit --class <this class> <spark sql test jar>
+ *  spark-submit --class <this class> <spark sql test jar> --data-location <TPCDS data location>
  */
-object TPCDSQueryBenchmark {
+object TPCDSQueryBenchmark extends Logging {
   val conf = new SparkConf()
     .setMaster("local[1]")
@@ -90,7 +91,9 @@ object TPCDSQueryBenchmark {
       benchmark.addCase(name) { i =>
         spark.sql(queryString).collect()
       }
+      logInfo(s"\n\n===== TPCDS QUERY BENCHMARK OUTPUT FOR $name =====\n")
       benchmark.run()
+      logInfo(s"\n\n===== FINISHED $name =====\n")
     }
   }
 
@@ -110,6 +113,20 @@ object TPCDSQueryBenchmark {
       "q81", "q82", "q83", "q84", "q85", "q86", "q87", "q88", "q89", "q90",
       "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q98", "q99")
 
-    tpcdsAll(benchmarkArgs.dataLocation, queries = tpcdsQueries)
+    // If `--query-filter` is defined, run only the queries that this option selects
+    val queriesToRun = if (benchmarkArgs.queryFilter.nonEmpty) {
+      val queries = tpcdsQueries.filter { case queryName =>
+        benchmarkArgs.queryFilter.contains(queryName)
+      }
+      if (queries.isEmpty) {
+        throw new RuntimeException(
+          s"Empty queries to run. Bad query name filter: ${benchmarkArgs.queryFilter}")
+      }
+      queries
+    } else {
+      tpcdsQueries
+    }
+
+    tpcdsAll(benchmarkArgs.dataLocation, queries = queriesToRun)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmarkArguments.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmarkArguments.scala
index 8edc77bd0ec6..184ffff94298 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmarkArguments.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmarkArguments.scala
@@ -17,21 +17,33 @@
 
 package org.apache.spark.sql.execution.benchmark
 
+import java.util.Locale
+
+
 class TPCDSQueryBenchmarkArguments(val args: Array[String]) {
   var dataLocation: String = null
+  var queryFilter: Set[String] = Set.empty
 
   parseArgs(args.toList)
   validateArguments()
 
+  private def optionMatch(optionName: String, s: String): Boolean = {
+    optionName == s.toLowerCase(Locale.ROOT)
+  }
+
   private def parseArgs(inputArgs: List[String]): Unit = {
     var args = inputArgs
-    while(args.nonEmpty) {
+    while (args.nonEmpty) {
       args match {
-        case ("--data-location") :: value :: tail =>
+        case optName :: value :: tail if optionMatch("--data-location", optName) =>
           dataLocation = value
           args = tail
 
+        case optName :: value :: tail if optionMatch("--query-filter", optName) =>
+          queryFilter = value.toLowerCase(Locale.ROOT).split(",").map(_.trim).toSet
+          args = tail
+
         case _ =>
           // scalastyle:off println
           System.err.println("Unknown/unsupported param " + args)
@@ -47,6 +59,7 @@ class TPCDSQueryBenchmarkArguments(val args: Array[String]) {
       |Usage: spark-submit --class <this class> <spark sql test jar> [Options]
      |Options:
      |  --data-location      Path to TPCDS data
+      |  --query-filter       Queries to filter, e.g., q3,q5,q13
      |
      |------------------------------------------------------------------------------------------------------------------
      |In order to run this benchmark, please follow the instructions at
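
A minimal sketch (not part of the patch) of how the new flag behaves, assuming the patched TPCDSQueryBenchmarkArguments above is on the classpath; the object name and the data path are illustrative. Option names are matched case-insensitively via optionMatch, and the filter value is lower-cased, split on commas, and trimmed, so "Q3, q5 ,q13" selects q3, q5, and q13:

package org.apache.spark.sql.execution.benchmark

object QueryFilterParsingSketch {
  def main(args: Array[String]): Unit = {
    // --DATA-LOCATION still matches --data-location because optionMatch
    // lower-cases the supplied option name before comparing.
    val parsed = new TPCDSQueryBenchmarkArguments(
      Array("--DATA-LOCATION", "/tmp/tpcds-sf5", "--query-filter", "Q3, q5 ,q13"))
    assert(parsed.dataLocation == "/tmp/tpcds-sf5")
    // The filter value is normalized to a lower-cased, trimmed set of query names.
    assert(parsed.queryFilter == Set("q3", "q5", "q13"))
  }
}

With the patch applied, a run restricted to those three queries would follow the usage string above: spark-submit --class <this class> <spark sql test jar> --data-location <TPCDS data location> --query-filter q3,q5,q13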