diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala
index d2d013682cd2..63d118cf857b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.execution.benchmark
 
-import java.io.File
-
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.TableIdentifier
@@ -31,7 +29,7 @@ import org.apache.spark.util.Benchmark
 /**
  * Benchmark to measure TPCDS query performance.
  * To run this:
- *  spark-submit --class <this class> --jars <spark sql test jar>
+ *  spark-submit --class <this class> <spark sql test jar>
  */
 object TPCDSQueryBenchmark {
   val conf =
@@ -61,12 +59,10 @@ object TPCDSQueryBenchmark {
   }
 
   def tpcdsAll(dataLocation: String, queries: Seq[String]): Unit = {
-    require(dataLocation.nonEmpty,
-      "please modify the value of dataLocation to point to your local TPCDS data")
     val tableSizes = setupTables(dataLocation)
     queries.foreach { name =>
-      val queryString = fileToString(new File(Thread.currentThread().getContextClassLoader
-        .getResource(s"tpcds/$name.sql").getFile))
+      val queryString = resourceToString(s"tpcds/$name.sql",
+        classLoader = Thread.currentThread().getContextClassLoader)
 
       // This is an indirect hack to estimate the size of each query's input by traversing the
       // logical plan and adding up the sizes of all tables that appear in the plan. Note that this
@@ -99,6 +95,7 @@ object TPCDSQueryBenchmark {
   }
 
   def main(args: Array[String]): Unit = {
+    val benchmarkArgs = new TPCDSQueryBenchmarkArguments(args)
 
     // List of all TPC-DS queries
    val tpcdsQueries = Seq(
@@ -113,12 +110,6 @@ object TPCDSQueryBenchmark {
      "q81", "q82", "q83", "q84", "q85", "q86", "q87", "q88", "q89", "q90",
      "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q98", "q99")
 
-    // In order to run this benchmark, please follow the instructions at
-    // https://github.com/databricks/spark-sql-perf/blob/master/README.md to generate the TPCDS data
-    // locally (preferably with a scale factor of 5 for benchmarking). Thereafter, the value of
-    // dataLocation below needs to be set to the location where the generated data is stored.
-    val dataLocation = ""
-
-    tpcdsAll(dataLocation, queries = tpcdsQueries)
+    tpcdsAll(benchmarkArgs.dataLocation, queries = tpcdsQueries)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmarkArguments.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmarkArguments.scala
new file mode 100644
index 000000000000..8edc77bd0ec6
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmarkArguments.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+class TPCDSQueryBenchmarkArguments(val args: Array[String]) {
+  var dataLocation: String = null
+
+  parseArgs(args.toList)
+  validateArguments()
+
+  private def parseArgs(inputArgs: List[String]): Unit = {
+    var args = inputArgs
+
+    while(args.nonEmpty) {
+      args match {
+        case ("--data-location") :: value :: tail =>
+          dataLocation = value
+          args = tail
+
+        case _ =>
+          // scalastyle:off println
+          System.err.println("Unknown/unsupported param " + args)
+          // scalastyle:on println
+          printUsageAndExit(1)
+      }
+    }
+  }
+
+  private def printUsageAndExit(exitCode: Int): Unit = {
+    // scalastyle:off
+    System.err.println("""
+      |Usage: spark-submit --class <this class> <spark sql test jar> [Options]
+      |Options:
+      |  --data-location      Path to TPCDS data
+      |
+      |------------------------------------------------------------------------------------------------------------------
+      |In order to run this benchmark, please follow the instructions at
+      |https://github.com/databricks/spark-sql-perf/blob/master/README.md
+      |to generate the TPCDS data locally (preferably with a scale factor of 5 for benchmarking).
+      |Thereafter, the value of <TPCDS data location> needs to be set to the location where the generated data is stored.
+      """.stripMargin)
+    // scalastyle:on
+    System.exit(exitCode)
+  }
+
+  private def validateArguments(): Unit = {
+    if (dataLocation == null) {
+      // scalastyle:off println
+      System.err.println("Must specify a data location")
+      // scalastyle:on println
+      printUsageAndExit(-1)
+    }
+  }
+}